数据读入
Pytorch的数据读入是通过DataSet+DataLoader的方式完成的,DataSet定义好数据的格式和数据变换形式,DataLoader通过iterative的方式不断读入批次数据
读入已有的数据集
Pytorch自身支持很多的数据集,可以直接通过对应的函数得到对应的DataSet,然后传入DataLoader中等待处理:
例如读入MNIST数据集
from torchvision import datasets
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.RandomHorizontalFlip,
transforms.RandomCrop,
transforms.ToTensor])
train = datasets.MNIST(root="./datasets",
train=True,
transform=transform,
download=True)
val = datasets.MNIST(root="./datasets",
train=False,
transform=transform,
download=True)
读入自己的数据集
另外也可以通过实现DataSet类来读入自己的数据集,一般来说需要实现三个函数:
-
__init__
:用于向类中传入外部参数,同时定义样本集 -
__getitem__
:用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据 -
__len__
:用户返回数据集的样本数
下面的例子是所有的图片存储在一个文件夹下面,同时在一个csv文件中保存有图片名称及其对应的标签
from PIL import Image
class CustomDataSet(Dataset):
def __init__(self, image_path, image_class, transform=None, device="cpu"):
self.image_path = image_path
self.image_class = image_class
self.transform = transform
self.device = device
def show_img(self, index):
plt.subplots(1, 1)
img = Image.open(self.image_path[index])
plt.imshow(img[2])
plt.show()
def __getitem__(self, index):
img = Image.open(self.image_path[index])
if img.mode != 'RGB':
raise ValueError("image:{} isn't RGB mode.".format(self.image_path[index]))
label = np.argmax(self.image_class[index])
label = torch.tensor(label).to(self.device)
if self.transform is not None:
img = self.transform(img)
return img.to(self.device), label
def __len__(self):
return len(self.image_path)
构建好DataSet之后就可以通过DataLoader读取自己的数据了
train_loader = DataLoader(train, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val, batch_size, shuffle=True, drop_last=False)
- shuffle:表示在加载的时候打乱顺序
- drop_last:丢弃掉最后不够一个batch的数据
全部设置完成之后就可以通过下面的函数不断的读取数据集了
for X, y in train_loader:
pass