【AutoML】强化学习如何用于模型蒸馏?

时间:2022-10-12 20:59:57


大家好,欢迎来到专栏《AutoML》,在这个专栏中我们会讲述AutoML技术在深度学习中的应用,这一期讲述在知识蒸馏中的应用。

作者&编辑 | 言有三

上一期给大家介绍了AutoML技术在模型量化中的应用,而知识蒸馏作为模型压缩的另外一个技术,也是学术界和工业界的重点研究方向,今天给大家介绍AutoML在模型蒸馏中的应用思路。

1 N2N Learning

N2N learning是一个知识蒸馏框架,它使用增强学习算法来将teacher模型转化为student模型,框架结构如下:

【AutoML】强化学习如何用于模型蒸馏?

N2N learning基于一个假设,即一个教师网络转化成学生网络的过程可以看作是马尔可夫决策过程Markov Decision Process(MDP),当前的步骤只和有限的之前几步有关系,使用增强学习来进行优化,其基本概念如下:

状态S:将网络的架构作为状态,对于任何一个大的网络,采样后比它小的网络都很多,所以状态空间非常大。

动作A与状态转换T:包括层的缩减以及移除操作。这一个过程通过双向LSTM来实现,它会观察某一层与前后层的关系,学习到是否进行约减或者删除。

奖励r:模型压缩的目标是保证精度的同时尽可能压缩模型,因此reward就是压缩率。

在上图结构中中包括了两个动作,实际进行优化时依次进行:

(1) 首先选择一组动作,用来判断当前层的去留;

(2) 然后使用另一个策略选择一组动作,用来决定剩下每一层的裁剪程度,伪代码流程如下:

【AutoML】强化学习如何用于模型蒸馏?

以上两种动作使用的LSTM也有差异,下图(a)是层的去留使用的LSTM,可以看出是一个双向的LSTM,要同时考虑前向和反向的中间状态。

下图(b)是层的裁剪使用的LSTM,可以看出是一个单向的LSTM,其中at的值范围是[0.1,0.2,...,1],表示压缩率。

【AutoML】强化学习如何用于模型蒸馏?

而奖励reward公式如下:

【AutoML】强化学习如何用于模型蒸馏?

其中C表示压缩率,A表示学生模型准确率,Ateacher表示教师模型准确率,教师网络的输出作为学生网络的真值。

损失函数包括两项,分别是学生网络的分类损失以及蒸馏损失,其中蒸馏损失中教师网络的输出作为学生网络的真值,使用L2距离。

【AutoML】强化学习如何用于模型蒸馏?

【AutoML】强化学习如何用于模型蒸馏?

实验结果如下:

【AutoML】强化学习如何用于模型蒸馏?

从结果来看,实现了非常高的压缩率,并且性能下降不明显。

2 关于模型优化

如果说要我在深度学习这么多方向里选择一个的话,我会毫不犹豫地选择模型优化。公众号写过很多的模型解读了,如下是一些文章总结和直播链接以及资源下载。