深度学习(增量学习)——GAN在增量学习中的应用(文献综述)

时间:2024-03-22 07:10:36

前言

我将看过的增量学习论文建了一个github库,方便各位阅读地址

持续学习的目的是解决灾难性遗忘,当前持续学习(lifelong learning)的研究主要集中在图像分类这一基础任务上。图像分类任务出现灾难性遗忘(Catastrophic forgetting),其根源在于TT时刻的分类模型没有TT时刻之前的图像数据,意味着需要在没有输入分布的前提下对TT时刻之前的数据进行分类,为了还原出输入图像的分布,目前有研究开始使用生成对抗模型(Generative Adversarial Nets),原因在于GAN可以进行概率分布的变化,可以将隐空间中的概率分布变化为训练图像的概率分布

​如果仅仅利用TT时刻的数据finetuning GAN,则GAN也会出现灾难性遗忘,如下图所示,将MNIST数据集分为10个任务,每个任务GAN只学习生成一类数字,利用condition GAN在MNIST数据集上进行持续学习,condition GAN的输入由类别label、隐空间变量zz组成,可以依据类别label生成对应类别的图像,训练完毕后,生成的图片均为9,即出现灾难性遗忘。

深度学习(增量学习)——GAN在增量学习中的应用(文献综述)

​为了解决GAN上的灾难性遗忘,研究人员采取了一系列措施,大致分为两类:

  • 使用记忆重放(Memory replay)机制。
  • Regularization,即在损失函数中添加正则项,来防止灾难性遗忘。

Memory replay

​如[6],[1],[3],[5]所示,Memory replay在训练TT时刻的GAN时,让T1T-1时刻的GAN生成一批旧类别图片,与TT时刻的新类别图片混合在一起,训练TT时刻的GAN。[1]存储了部分旧类别的图片,与生成的旧类别图片一起训练TT时刻的GAN。为了确保每一个旧类别都具有Memory replay生成的图片,目前主要采用两类方式:

  • 使用Uncondition GAN,用T1T-1时刻的分类器对生成的图片进行分类。为了保证图像质量,在用分类器判断完图像属于AA类别后,AA类别得分高于一个阈值θ\theta,才会用于训练TT时刻的GAN。
  • 使用Condition GAN,可以依据label生成对应类别的图片。

Memory replay的缺陷

​ Memory replay的缺点很明显,若T1T-1时刻的GAN生成的图片质量极差,无法反映图像真实的概率分布,会影响TT时刻GAN的训练,如下图所示,Task 3训练完毕后,生成的iris图像质量较差,直接导致Task 4、Task 5生成的iris图像质量较差。
深度学习(增量学习)——GAN在增量学习中的应用(文献综述)

Regularization

​ Regularization即通过在损失函数中添加正则项,来寻求一个合适的解,如下图红线所示,通过合适的正则项,可以寻找到即可以较好生成AA任务图像,又可以生成较好BB任务图像的解。

深度学习(增量学习)——GAN在增量学习中的应用(文献综述)

​ 目前研究采用的Regularization大致可以分为两类:

  • EWC Regularization
  • L2、L1距离

Regularization 方式一:EWC Regularization

​ [2]在Generator的loss中添加了EWC Regularization,计算公式如下图,第一项为传统的Generator loss,第二项为EWC Regularization。其中FiF_i为Fisher information
深度学习(增量学习)——GAN在增量学习中的应用(文献综述)
深度学习(增量学习)——GAN在增量学习中的应用(文献综述)

​ EWC Regularization可以从概率角度判断Generator中,哪一部分参数对于生成旧类别图像更为重要,通过限制这类参数发生太大改变,来防止GAN发生灾难性遗忘。

Regularization 方式二:L2、L1距离

​ L2、L1距离类似于knowledge distillation,将T1T-1时刻Generator的知识蒸馏到TT时刻的Generator。如下图所示,[7] [6]将同一隐变量输入到T1T-1时刻与TT时刻的Generator,将得到的两张图片做L1距离或L2距离(对应图中的LRAL_{RA}),作为Generator的正则项,若LRAL_{RA}为L2距离,此时Generator的损失函数变为式12.0

深度学习(增量学习)——GAN在增量学习中的应用(文献综述)
深度学习(增量学习)——GAN在增量学习中的应用(文献综述)

Regularization存在的问题

​ Regularization要求TT时刻的GAN,对同一输入,在训练阶段生成的图片既要与新任务图像一致(否则无法欺骗Discriminator),又要与旧任务一致(否则Regularization的值会很高),这是矛盾的。

​ 现有的方案通过condition GAN解决上述问题,如下图所示,假设每次只学习一个新类别,cc表示类别的label,zz表示隐变量,StS_t表示第tt个类别的数据,下图左半部分表示只有旧类别的隐变量会参与Regularization的计算,右半部分表示只有新类别的数据会用于训练Discriminator,如此一来,为了让Regularization项变小,对于旧类别的隐变量,TT时刻Generator生成的图片应该与T1T-1时刻Generator生成的图片尽可能一致,对于新任务,TT时刻的Generator生成的图片需要尽可能与新任务图像一致,才可以欺骗Discriminator。

深度学习(增量学习)——GAN在增量学习中的应用(文献综述)

参考文献

[1] Amanda Rios ,Laurent Itti. Closed-Loop Memory GAN for Continual Learning. In IJCAI, 2019

[2] Ari Seff, Alex Beatson, Daniel Suo,,Han Liu. Continual Learning in Generative Adversarial Nets.2017

[3] Hanul Shin,Jung Kwon Lee,Jaehong Kim,Jiwon Kim. Continual Learning with Deep Generative Replay. In NIPS, 2017

[4] Yue Wu,Yinpeng Chen,Lijuan Wang,Yuancheng Ye,Zicheng Liu,Yandong Guo,Zhengyou Zhang2,Yun Fu.Incremental Classifier Learning with Generative Adversarial Networks.2018

[5] Ye Xiang,Ying Fu,Pan Ji,Hua Huang.Incremental Learning Using Conditional Adversarial Networks.In ICCV, 2019

[6] henshen Wu,Luis Herranz,Xialei Liu,Yaxing Wang,Joost van de Weijer, Bogdan Raducanu. Memory replay GANs- Learning to generate images from new categories without forgetting. In NIPS, 2018

[7]Mengyao Zhai, Lei Chen,Fred Tung,Jiawei He,Megha Nawhal, Greg Mori.Lifelong GAN: Continual Learning for Conditional Image Generation.In ICCV,2019

深度学习(增量学习)——GAN在增量学习中的应用(文献综述)