数据集下载地址:
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
提取码:2xq4
之前在:https://cloud.tencent.com/developer/article/1686281创建好了数据集,将它上传到谷歌colab
在colab上的目录如下:
在utils中的rdata.py定义了读取该数据集的代码:
代码语言:javascript复制from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torch
#预处理
transform = transforms.Compose([transforms.ToTensor()])
path = "/content/drive/My Drive/colab notebooks/data/dogcat"
train_path=path "/train"
test_path=path "/test"
#使用torchvision.datasets.ImageFolder读取数据集指定train和test文件夹
train_data = torchvision.datasets.ImageFolder(train_path, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=1)
test_data = torchvision.datasets.ImageFolder(test_path, transform=transform)
test_loader = DataLoader(test_data, batch_size=32, shuffle=True, num_workers=1)
print(train_data.classes) #根据分的文件夹的名字来确定的类别
print(train_data.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(train_data.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
print(test_data.classes) #根据分的文件夹的名字来确定的类别
print(test_data.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(test_data.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
ImageFolder可以读取我们的train或test下面的文件夹,并为每一个标签进行编码,同时将图片与标签进行对应。
在test.ipynb中运行rdata.py
说明我们创建的数据集是可以用的了。
有了数据集,接下来就是网络的搭建以及训练和测试了。