Deep Learning - PyTorch Framework Study: Data Processing

Preface

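Data is at the core of deep learning, and most papers mention the term "data-driven". Once the basic model is built, the data pipeline matters a great deal for the network's performance. That includes how the data is processed, how it is fed to the network, and how it is augmented. This post gives a brief walkthrough of the data processing workflow, and I will keep adding to it when I have time. Let's learn from each other and keep improving!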



Mean and standard deviation of a dataset

<code>import numpy as np

def compute_mean_and_std(dataset):
    # Input: a PyTorch dataset whose items are (PIL Image, label) pairs with 3 channels.
    # Output: the per-channel mean and std of the whole dataset, scaled to [0, 1].
    # np.asarray on a PIL Image gives an H x W x 3 array in RGB order
    # (OpenCV images would be BGR), so results follow the image's channel order.

    # First pass: per-channel mean over all pixels of all images.
    # Summing pixels (instead of averaging per-image means) stays correct
    # even when images have different sizes.
    channel_sum = np.zeros(3)
    N = 0  # total number of pixels per channel
    for img, _ in dataset:
        img = np.asarray(img, dtype=np.float64)  # PIL Image -> numpy array
        channel_sum += img.sum(axis=(0, 1))
        N += img.shape[0] * img.shape[1]
    mean = channel_sum / N

    # Second pass: per-channel sum of squared deviations from the mean
    diff = np.zeros(3)
    for img, _ in dataset:
        img = np.asarray(img, dtype=np.float64)
        diff += np.sum((img - mean) ** 2, axis=(0, 1))
    std = np.sqrt(diff / N)

    # Scale from [0, 255] to [0, 1] to match the output range of ToTensor()
    mean = tuple((mean / 255.0).tolist())
    std = tuple((std / 255.0).tolist())
    return mean, std</code>
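The function above reads the dataset twice. A one-pass variant accumulates per-channel sums of x and x² and then uses Var[x] = E[x²] - (E[x])². The sketch below assumes the same inputs: items are (image, label) pairs, and each image converts to an H×W×3 uint8 array. The function name `compute_mean_and_std_one_pass` and the random toy dataset are hypothetical, for illustration only.

```python
import numpy as np

def compute_mean_and_std_one_pass(dataset):
    # Assumes each item is (image, label), where np.asarray(image) is an
    # H x W x 3 uint8 array (e.g. an RGB PIL Image); sizes may differ.
    channel_sum = np.zeros(3)
    channel_sq_sum = np.zeros(3)
    num_pixels = 0
    for img, _ in dataset:
        arr = np.asarray(img, dtype=np.float64) / 255.0  # scale to [0, 1]
        channel_sum += arr.sum(axis=(0, 1))
        channel_sq_sum += (arr ** 2).sum(axis=(0, 1))
        num_pixels += arr.shape[0] * arr.shape[1]
    mean = channel_sum / num_pixels
    # Population variance: E[x^2] - (E[x])^2
    var = channel_sq_sum / num_pixels - mean ** 2
    std = np.sqrt(np.maximum(var, 0.0))  # clamp tiny negative rounding errors
    return tuple(mean.tolist()), tuple(std.tolist())

# Usage with a hypothetical toy dataset of random images of different sizes
rng = np.random.default_rng(0)
toy_dataset = [
    (rng.integers(0, 256, size=(h, w, 3)).astype(np.uint8), 0)
    for h, w in [(8, 8), (16, 4), (5, 7)]
]
mean, std = compute_mean_and_std_one_pass(toy_dataset)
print(mean, std)
```

Accumulating in float64 on values in [0, 1] keeps the cancellation error in E[x²] - (E[x])² small for typical image datasets. If in doubt, the two-pass version is the numerically safer choice.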

Common preprocessing for training and validation data

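ToTensor converts a PIL.Image, or a numpy array of shape H×W×D (D = number of channels) with values in [0, 255], into a torch.Tensor of shape D×H×W with values in [0.0, 1.0]. The rescaling to [0.0, 1.0] only happens for uint8 numpy arrays and the common PIL image modes (such as RGB or L). Other inputs are converted without scaling.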
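For a uint8 array, this conversion amounts to moving the channel axis to the front and dividing by 255. The numpy sketch below mimics that behavior for illustration only. `to_tensor_like` is a hypothetical name, it returns a numpy array rather than a `torch.Tensor`, and it is not torchvision's actual implementation.

```python
import numpy as np

def to_tensor_like(img_hwc):
    # Mimics ToTensor for uint8 input: H x W x D in [0, 255] -> D x H x W in [0.0, 1.0]
    arr = np.asarray(img_hwc)
    if arr.dtype != np.uint8:
        raise TypeError("this sketch only handles uint8 input")
    if arr.ndim == 2:  # grayscale H x W -> H x W x 1
        arr = arr[:, :, None]
    return arr.transpose(2, 0, 1).astype(np.float32) / 255.0

img = np.zeros((4, 6, 3), dtype=np.uint8)
img[..., 0] = 255  # channel 0 at full intensity (pure red for an RGB image)
x = to_tensor_like(img)
print(x.shape)  # (3, 4, 6)
```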

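<code>import torchvision

# Training: random crop + horizontal flip for augmentation, then normalization
train_transform = torchvision.transforms.Compose([
    # crop a random 8%-100% of the image area, then resize it to 224 x 224
    torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225)),
])

# Validation: deterministic resize + center crop (no random augmentation)
val_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),      # shorter side -> 256, aspect ratio kept
    torchvision.transforms.CenterCrop(224),  # central 224 x 224 patch
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225)),
])</code>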
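The Normalize values in these pipelines, mean=(0.485, 0.456, 0.406) and std=(0.229, 0.224, 0.225), are the widely used ImageNet RGB statistics. Normalize computes (x - mean) / std separately for each channel of a D×H×W tensor. The numpy sketch below illustrates that arithmetic. `normalize` is a hypothetical helper, not torchvision's code. For a dataset other than ImageNet, the mean and std returned by `compute_mean_and_std` could be passed instead.

```python
import numpy as np

def normalize(x_chw, mean, std):
    # Per-channel (x - mean) / std on a C x H x W array, as Normalize does
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (x_chw - mean) / std

imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

x = np.full((3, 2, 2), 0.5, dtype=np.float32)  # hypothetical input already in [0, 1]
y = normalize(x, imagenet_mean, imagenet_std)
print(y[:, 0, 0])
```

A pixel equal to the channel mean maps to 0, and a value one standard deviation above it maps to 1.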

Video data

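<code>import cv2

# mp4_path: path to the video file (defined elsewhere)
video = cv2.VideoCapture(mp4_path)
if not video.isOpened():
    raise IOError(f"Cannot open video: {mp4_path}")
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))     # frame height in pixels
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))       # frame width in pixels
num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))  # total frame count reported by the file
fps = video.get(cv2.CAP_PROP_FPS)  # keep as float: int() would turn 29.97 into 29
video.release()  # free the file handle</code>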
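Once `num_frames` is known, clip-based video models often use only a fixed number of frames per video. One common choice is to split the video into equal segments and take the middle frame of each. This strategy is an assumption here, not part of the snippet above. A minimal numpy sketch follows; `sample_frame_indices` is a hypothetical helper name.

```python
import numpy as np

def sample_frame_indices(num_frames, num_samples):
    # Split [0, num_frames) into num_samples equal segments and take the
    # (floored) centre of each one. Indices repeat when num_frames < num_samples.
    if num_frames <= 0 or num_samples <= 0:
        raise ValueError("num_frames and num_samples must be positive")
    centres = (np.arange(num_samples) + 0.5) * num_frames / num_samples
    return centres.astype(np.int64)

# Usage: e.g. a video whose CAP_PROP_FRAME_COUNT reported 100 frames
print(sample_frame_indices(100, 5))  # [10 30 50 70 90]
```

To read each frame, call `video.set(cv2.CAP_PROP_POS_FRAMES, idx)` and then `ok, frame = video.read()`, before `video.release()` is called. OpenCV returns frames in BGR order. Apply `cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)` before using RGB statistics such as the ImageNet mean and std.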


To be continued; more updates to come!

