FitNets: Hints for thin deep nets论文笔记

时间:2024-04-03 18:48:06

论文地址:https://arxiv.org/abs/1412.6550
github地址:https://github.com/adri-romsor/FitNets

这篇文章提出一种设置初始参数的算法,目前很多网络的训练需要使用预训练网络参数。对于一个thin但deeper的网络的训练,作者提出知识蒸馏的方式将另一个大网络的中间层输出蒸馏到该网络中作为预训练参数初始化网络。

Motivation

现有的top-performing的网络(论文2015年发表于ICLR)通常都很deep且wide,这使得参数参数量非常大且难训练,inference time也相对较长。但深度的确对网络的训练起到效果,对特征的拟合效果更好。因此,作者提出训练thin且deep的网络的方法。

Methods

首先,论文使用Hinton提出的基于softmax改造的知识蒸馏作为基础,引入中间层输出作为学生网络训练的引导,类似于基于feature map的知识蒸馏。其整体框架如下图所示:
FitNets: Hints for thin deep nets论文笔记
首先选择待蒸馏的中间层(即teacher的Hint layer和student的Guided layer),如图中绿框和红框所示。由于两者的输出尺寸可能不同,因此,在guided layer后另外接一层卷积层,使得输出尺寸与teacher的hint layer匹配。

接着通过知识蒸馏的方式训练student网络的guided layer之前的所有层,使得student网络的中间层学习到teacher的hint layer的输出,其损失函数为所加卷积层的输出与hint layer的输出的L2Norm:
FitNets: Hints for thin deep nets论文笔记
在选择中间层时作者提出应该选择较靠前的层,因为随着层数的增加,所含信息量越多,单纯地使得输出相同可能造成网络过拟合。

在训练好guided layer之前的层后,将当前的参数作为网络的初始参数,利用知识蒸馏的方式训练student网络的所有层参数,使student学习teacher的输出。由于teacher对于简单任务的预测非常准确,在分类任务中近乎one-hot输出,因此为了弱化预测输出,使所含信息更加丰富,作者使用Hinton等人提出的softmax改造方法,即在softmax前引入τ\tau缩放因子,将teacher和student的pre-softmax输出均除以τ\tau。此时的损失函数为:
FitNets: Hints for thin deep nets论文笔记
第一部分为student的输出与groundtruth的交叉熵损失,第二部分为student与teacher的softmax输出的交叉熵损失。λ\lambda用于调节两个交叉熵的权重比。

Experiments

数据集:CIFAR-10, CIFAR-100, SVHN, MNIST, AFLW
网络:Teacher: maxout convolutional networks, Student: FitNet
FitNets: Hints for thin deep nets论文笔记

Results

FitNets: Hints for thin deep nets论文笔记
FitNets: Hints for thin deep nets论文笔记
FitNets: Hints for thin deep nets论文笔记

Thoughts

这篇文章比较久远,但是从中学习到一些trick,比如feature map的蒸馏相当于是增加了正则化约束,为防止过拟合需要选择较靠前的输出进行蒸馏,另一方面,对于softmax的输出蒸馏的时候最好进行弱化。另外,多stage的训练方式训练网络对于复杂的训练方法应该会有效果。