深度学习比赛入门——街景字符识别(二)

时间:2024-02-23 19:32:16

这是街景字符识别的第二阶段,数据的读取以及数据的扩增
在我们理解了赛题,知道自己所要完成的任务以及优化目标之后,我们就要对赛题的数据进行处理。本赛题的baseline本质来说,就是对每个字符进行单独的分类,因为只有数字,所以简单处理作为分类任务也是可以的,但是对于带有字符(包括英文或者中文),简单的作为分类问题就不是很适合了,这时候就得查找关于处理OCR字符识别的相关方法了。
回到数据处理上来说,根据指定比赛的数据形式特点,应该对Dataset做一下重载,简单来说就是对原始的Dataset类重新改写,以适应本赛题数据的形式。

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert(\'RGB\')

        if self.transform is not None:
            img = self.transform(img)

        # 设置最长的字符长度为5个
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl) + (5 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:8]))

    def __len__(self):
        return len(self.img_path)

这里,我想说我在这里碰到的坑,我们在读取图像路径以及图像标签的时候,读取出来我发现是非顺序的,即图片与标签是不对应的,这就导致,后面的训练的时候,train_loss以及val_loss一直在5.6或者6.2之间附近震荡,损失值始终一直下降不下去,先开始的时候以为是过拟合,网络太复杂,但是朝着这个方向进行改进,发现没有一丝变化;而且,更离谱的是最后预测测试集的时候,其结果全部是一样,这样进行分析的时候,可以知道网络并没有在数据中学习到东西,通过排查原始数据,发现图像和标签没有对应起来,这样学习最终是得不到好的结果的。所以以后碰到这种情况时,或许我们可以排查以下输入数据的问题,这也证明了一点——除了改造网络,数据的处理也是非常重要的。
数据扩增,可以增加训练数据的数量,提高网络的鲁棒性,在增加数据的同时,可以适当的增加网络的复杂性,这样可以提高训练的正确性;数据增强的常见方法包括:随机旋转,随机裁剪,随机颜色的改变等。
当然,对数据的处理,还可以增加图像的预处理——比如说,图像的增强,锐化,二值化处理,以此来保证识别的字符足够明显清晰,或许可以一点程度上增加训练的效果。数据的增强还是得根据不同的场景适当使用,不如本题就不可以进行翻转操作,这样反而使训练效果下降

train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.RandomCrop((60, 120)),
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])),
    batch_size=40,
    shuffle=True,
    num_workers=0,
)