CNN实战(一):pytorch处理图像数据(Dataset和Dataloader)

2022-09-16 17:20:33 浏览数 (1)

链接:数据集[1] 提取码:onda

  pytorch给我们提供了很多已经封装好的数据集,但是我们经常得使用自己找到的数据集,因此,想要得到一个好的训练结果,合理的数据处理是必不可少的。我们以1400张猫狗图片来进行分析:

1.分析数据:

训练集包含500张狗的图片以及500张猫的图片,测试接包含200张狗的图片以及200张猫的图片。

2.数据预处理:得到一个包含所有图片文件名(包含路径)和标签(狗1猫0)的列表:

代码语言:javascript复制
def init_process(path, lens):
    data = []
    name = find_label(path)
    for i in range(lens[0], lens[1]):
        data.append([path % i, name])
        
    return data

  图片的命名都是带有编号的,训练集中数据编号为0-499,测试集中编号为1000-1200,因此我们可以根据这个规律来读取文件名,比如参数传入:

代码语言:javascript复制
path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
data1 = init_process(path1, [0, 500])

data1就是一个包含五百个文件名以及标签的列表。find_label来判断标签是dog还是cat:

代码语言:javascript复制
def find_label(str):
    first, last = 0, 0
    for i in range(len(str) - 1, -1, -1):
        if str[i] == '%' and str[i - 1] == '.':
            last = i - 1
        if (str[i] == 'c' or str[i] == 'd') and str[i - 1] == '/':
            first = i
            break

    name = str[first:last]
    if name == 'dog':
        return 1
    else:
        return 0

dog返回1,cat返回0。

  有了上面两个函数之后,我们经过四次操作,就可以得到四个列表:

代码语言:javascript复制
path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
data1 = init_process(path1, [0, 500])
path2 = 'cnn_data/data/training_data/dogs/dog.%d.jpg'
data2 = init_process(path2, [0, 500])
path3 = 'cnn_data/data/testing_data/cats/cat.%d.jpg'
data3 = init_process(path3, [1000, 1200])
path4 = 'cnn_data/data/testing_data/dogs/dog.%d.jpg'
data4 = init_process(path4, [1000, 1200])

随便输出一个列表的前五个:

代码语言:javascript复制
[['cnn_data/data/testing_data/dogs/dog.1000.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1001.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1002.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1003.jpg', 1], ['cnn_data/data/testing_data/dogs/dog.1004.jpg', 1]]

3.利用PIL包的Image库处理图片:

代码语言:javascript复制
def Myloader(path):
    return Image.open(path).convert('RGB')

4.重写pytorch的Dataset类:

代码语言:javascript复制
class MyDataset(Dataset):
    def __init__(self, data, transform, loder):
        self.data = data
        self.transform = transform
        self.loader = loder
    def __getitem__(self, item):
        img, label = self.data[item]
        img = self.loader(img)
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data)

里面有2个比较重要的函数:

__getitem__真正读取数据的地方,迭代器通过索引来读取数据集中的数据,因此只需要这一个方法中加入读取数据的相关功能即可。在这个函数里面,我们对第二步处理得到的列表进行索引,接着利用第三步定义的Myloader来对每一个路径进行处理,最后利用pytorch的transforms对RGB数据进行处理,将其变成Tensor数据。

transform为:

代码语言:javascript复制
transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # 归一化
    ])

对上面四个操作做一些解释:

1)transforms.CenterCrop(224),从图像中心开始裁剪图像,224为裁剪大小

2)transforms.Resize((224, 224)),重新定义图像大小

3)transforms.ToTensor(),很重要的一步,将图像数据转为Tensor

4)transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),归一化

__len__中提供迭代器索引的范围。

因此我们只需要:

代码语言:javascript复制
train_data = data1   data2   data3[0:150]   data4[0:150]
train = MyDataset(train_data, transform=transform, loder=Myloader)
test_data = data3[150:200]   data4[150:200]
test= MyDataset(test_data, transform=transform, loder=Myloader)

就可以得到处理好的Dataset,其中训练集我给了1300张图片,测试集只给了100张。

5.通过pytorch的DataLoader对第四步得到的Dataset进行shuffle以及mini-batch操作,分成一个个小的数据集:

代码语言:javascript复制
train_data = DataLoader(dataset=train, batch_size=5, shuffle=True, num_workers=0, pin_memory=True)
test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)

最后我们只要给定义好的神经网络模型喂数据就OK了!!!

完整代码:

代码语言:javascript复制
# -*- coding: utf-8 -*-
"""
@Time :2021/8/18 9:11
@Author :KI 
@File :CNN.py
@Motto:Hungry And Humble

"""
import torch
from torch import optim
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

def Myloader(path):
    return Image.open(path).convert('RGB')

#得到一个包含路径与标签的列表
def init_process(path, lens):
    data = []
    name = find_label(path)
    for i in range(lens[0], lens[1]):
        data.append([path % i, name])

    return data

class MyDataset(Dataset):
    def __init__(self, data, transform, loder):
        self.data = data
        self.transform = transform
        self.loader = loder
    def __getitem__(self, item):
        img, label = self.data[item]
        img = self.loader(img)
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data)


def find_label(str):
    first, last = 0, 0
    for i in range(len(str) - 1, -1, -1):
        if str[i] == '%' and str[i - 1] == '.':
            last = i - 1
        if (str[i] == 'c' or str[i] == 'd') and str[i - 1] == '/':
            first = i
            break

    name = str[first:last]
    if name == 'dog':
        return 1
    else:
        return 0

def load_data():
    transform = transforms.Compose([
        #transforms.RandomHorizontalFlip(p=0.5),
        #transforms.RandomVerticalFlip(p=0.5),
        transforms.CenterCrop(224),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # 归一化
    ])
    path1 = 'cnn_data/data/training_data/cats/cat.%d.jpg'
    data1 = init_process(path1, [0, 500])
    path2 = 'cnn_data/data/training_data/dogs/dog.%d.jpg'
    data2 = init_process(path2, [0, 500])
    path3 = 'cnn_data/data/testing_data/cats/cat.%d.jpg'
    data3 = init_process(path3, [1000, 1200])
    path4 = 'cnn_data/data/testing_data/dogs/dog.%d.jpg'
    data4 = init_process(path4, [1000, 1200])

    train_data = data1   data2   data3[0:150]   data4[0:150]

    train = MyDataset(train_data, transform=transform, loder=Myloader)

    test_data = data3[150:200]   data4[150:200]
    test= MyDataset(test_data, transform=transform, loder=Myloader)

    train_data = DataLoader(dataset=train, batch_size=5, shuffle=True, num_workers=0, pin_memory=True)
    test_data = DataLoader(dataset=test, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)

    return train_data, test_data

train_data以及test_data就是我们最终需要得到的数据。

References

[1] 数据集: https://pan.baidu.com/s/1_M1xZMBvu_wGYdXvq06sVQ

0 人点赞