介绍
Tiny ImageNet Challenge 来源于斯坦福大学的 CS231n 课程,数据集压缩包大小约为 237 MB。
Tiny ImageNet 共有 200 个类别,每个类别有 500 张训练图像、50 张验证图像和 50 张测试图像。
下载链接:http://cs231n.stanford.edu/tiny-imagenet-200.zip
数据集使用
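下载得到的 train 和 val 文件夹中图片的存放结构不同:训练集图片位于 train/<wnid>/images/*.JPEG,类别由所在目录名决定;验证集图片全部位于 val/images/*.JPEG,类别需要从标注文件中查询。因此两者的读取路径和取标签方式需要分别处理。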
wnids.txt:存放所有类别的标签(每行一个 wnid),下面的代码按行号把它们映射为整数类别编号。
words.txt:存放标签与对应的文字描述,可用于 few-shot 或 zero-shot 任务(下面的加载代码没有使用该文件,只做简单的分类任务)。
train/<wnid>/<wnid>_boxes.txt 与 val/val_annotations.txt:包含标签与 bounding box 标注,可用于目标检测任务。下面的加载代码不使用 bounding box;对于验证集,只从 val/val_annotations.txt 中读取每张图像对应的标签,用于简单的分类任务。
下面附上代码:
import argparse
import glob
import os
from pathlib import Path
from typing import Any

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
class TrainTinyImageNet(Dataset):
    """Tiny ImageNet training split for image classification.

    Expects the layout ``<root>/train/<wnid>/images/*.JPEG``; each image's
    label is looked up in ``id`` using its ``<wnid>`` directory name.

    Args:
        root: Path to the extracted ``tiny-imagenet-200`` directory.
        id: Mapping from class wnid (directory name) to integer label.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, root, id, transform=None) -> None:
        super().__init__()
        # os.path.join keeps the pattern portable; the original hard-coded
        # Windows backslashes and used invalid "\*" escape sequences.
        pattern = os.path.join(root, "train", "*", "*", "*.JPEG")
        # glob order is filesystem dependent; sort for reproducible indexing.
        self.filenames = sorted(glob.glob(pattern))
        self.transform = transform
        self.id_dict = id
        # Resolve every label up front (parts[-3] is the <wnid> directory) so
        # an unknown wnid fails immediately instead of mid-epoch in a worker.
        self.labels = [self.id_dict[Path(p).parts[-3]] for p in self.filenames]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx: Any) -> Any:
        img_path = self.filenames[idx]
        # The context manager releases the file handle right away.
        # convert("RGB") loads the pixels before close, and it normalizes
        # grayscale ("L") and any other non-RGB mode.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
class ValTinyImageNet(Dataset):
    """Tiny ImageNet validation split for image classification.

    Images live flat in ``<root>/val/images``. Their labels come from
    ``<root>/val/val_annotations.txt``, a tab-separated file whose first two
    columns are the image file name and the class wnid. The remaining
    bounding-box columns are ignored here.

    Args:
        root: Path to the extracted ``tiny-imagenet-200`` directory.
        id: Mapping from class wnid to integer label.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, root, id, transform=None) -> None:
        super().__init__()
        pattern = os.path.join(root, "val", "images", "*.JPEG")
        # Sorted for a deterministic sample order.
        self.filenames = sorted(glob.glob(pattern))
        self.transform = transform
        self.id_dict = id
        # Image file name -> integer label.
        self.cls_dic = {}
        ann_path = os.path.join(root, "val", "val_annotations.txt")
        with open(ann_path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.rstrip("\r\n").split("\t")
                if len(fields) < 2:
                    continue  # tolerate blank / trailing lines
                img, cls_id = fields[0], fields[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        # Close the file promptly; convert("RGB") loads pixels and normalizes
        # grayscale and other non-RGB modes.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def load_tinyimagenet(args, root=None):
    """Build the Tiny ImageNet train/val DataLoaders.

    Args:
        args: Namespace providing ``batch_size`` and ``workers``. An optional
            ``root`` attribute overrides the default dataset location.
        root: Optional path to the extracted ``tiny-imagenet-200`` directory.
            Takes precedence over ``args.root``. When neither is given, the
            original hard-coded Windows location is used.

    Returns:
        ``(train_loader, val_loader, num_classes)``.
    """
    batch_size = args.batch_size
    # Number of DataLoader worker processes used for background loading.
    nw = args.workers
    if root is None:
        # A raw string avoids the invalid "\P" / "\d" escapes of the original
        # literal while keeping the same path value.
        root = getattr(args, "root", None) or r"E:\PythonProjects\dataset\tiny-imagenet-200"

    # One class wnid per line; the line index becomes the integer label.
    with open(os.path.join(root, "wnids.txt"), "r", encoding="utf-8") as f:
        wnids = [line.strip() for line in f if line.strip()]
    id_dic = {wnid: i for i, wnid in enumerate(wnids)}
    num_classes = len(id_dic)

    # Per-channel mean/std (the values commonly used for ImageNet).
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        "val": transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    train_dataset = TrainTinyImageNet(root, id=id_dic, transform=data_transform["train"])
    val_dataset = ValTinyImageNet(root, id=id_dic, transform=data_transform["val"])

    # Pinned host memory only speeds up host->GPU copies; enabling it without
    # CUDA merely triggers a warning.
    pin = torch.cuda.is_available()
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=pin,
                                               num_workers=nw)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=pin,
                                             num_workers=nw)
    print(f"TinyImageNet Loading SUCCESS"
          f"\nlen of train dataset: {len(train_dataset)}"
          f"\nlen of val dataset: {len(val_dataset)}")
    return train_loader, val_loader, num_classes
if __name__ == '__main__':
    # Command-line entry point: parse loader options and build the loaders.
    parser = argparse.ArgumentParser("parameters")
    parser.add_argument("--batch-size", type=int, default=120,
                        help="number of samples per batch (default: 120)")
    parser.add_argument('--workers', type=int, default=7,
                        help="number of DataLoader worker processes (default: 7)")
    # Single int: the previous nargs='+' turned a passed value into a list
    # while the default stayed an int.
    parser.add_argument('--seed', default=42, type=int,
                        help='seed for initializing training (default: 42)')
    parser.add_argument('--root', default=None,
                        help='path to the extracted tiny-imagenet-200 directory')
    args = parser.parse_args()
    train, val, num_classes = load_tinyimagenet(args)
其中 --workers 是 DataLoader 用于数据预加载的进程数,可以根据 CPU 核数自行调整。