Deep Learning中的Large Batch Training相关理论与实践

时间:2022-08-28 18:58:24

背景

[作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor]
在分布式训练时,提高计算通信占比是提高计算加速比的有效手段,当网络通信优化到一定程度时,只有通过增加每个worker上的batch size来提升计算量,进而提高计算通信占比。然而一直以来Deep Learning模型在训练时对Batch Size的选择都是异常敏感的,通常的经验是Large Batch Size会使收敛性变差,而相对小一点的Batch Size才能收敛的更好。当前学术界和工业界已经有一些论文来论证Large Batch Size对收敛性的影响,甚至提出了一些如何使用Large Batch去提高收敛性的方法,本文将对这些论文的重点和脉络做一个梳理。

论文脉络梳理

Large Batch Training是目前学术界和工业界研究的热点,其理论发展非常迅速。但由于非凸优化和Deep Learning的理论研究本身还处于并将长期处于初级阶段,所以即使存在各种各样的理论解释和证明,Large Batch Training相关的理论也尚未得到彻底的解释。为了能够让读者能够更容易理解Large Batch Training当前的学术发展,也为了让论文的阅读更有脉络,我们把学术界中的相关论文按照观点的提出顺序作为梳理如下。下面列出的每篇论文后面都有其要点,便于读者阅读时有个大概的感觉。因为本篇主要梳理Large Batch Training的理论部分,所以会对重点的论文进行分析解释。
  • Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour》:这是FaceBook提出的一篇极具争议性的论文,从实践上来说它的的复现难度也是比较大的。该论文从实践的角度出发,在ResNet上提出了一种针对Large batch training的训练方法,即learning rate scaling rule。当batch size相对于baseline增加N倍时,learning rate也要相应的增加N倍,但也指出batch size的提升有一个upper bound,超过这个值,泛化性依然会变得很差。这篇论文对learning rate scaling rule有一些公式推导,但并不本质,更多的是做了较强的假设。总体来说,这是一篇实验做得比较solid,但理论基础并不丰满的实践论文。
  • A BAYESIAN PERSPECTIVE ON GENERALIZATION AND STOCHASTIC GRADIENT DESCENT》:这是Google发在ICLR 2018上的一篇理论和实验都比较完善的论文。因为在ResNet上已经有了Learning Rate Scaling Rule的成功经验,因此该论文从贝叶斯的角度解释了泛化性和SGD。论文的核心观点是指出了Batch Training相对于Full Batch Training来说引入了Noise,而Noise具有波动的效果,这在论文里被称为Flucturate,它可以在更新时在一定程度上偏离Sharp Minima,从而进入Broad Minima,进而有了较好的泛化性,所以Noise起了较大的作用。进一步的,论文中将SGD的更新公式进行进行分析,等价为一个微分方程的定积分结果,通过将SGD更新公式与微分方程进行等价,导出了Flucturate的表达式,确定了影响其值的变动因素,即和Learning Rate与Batch size有关。若把Flucturate看做常量,那么Learning Rate与Batch Size可以近似看做是线性关系,这与论文2中的Learning Rate Scaling Rule一致。总体来说,这篇论文数学理论相对丰满的解释了Learning Rate Scaling Rule。
  • Don't Decay the Learning Rate, Increase the Batch Size》:这是Google发在ICLR 2018上的第二篇论文,这篇论文的实验和结论非常简单,但是理论基础依然来自于论文3,所以阅读此篇论文之前一定要精度论文3。该论文从推导出的Mini Batch SGD的Flucturate公式出发,提出了一种使用Large Batch Training的加速方法。因为在一个完整的模型训练过程中,通常会随着轮数的增加而适当对Learning Rate做Decay。通过论文3中给出的公式,即Flucturate固定时,Learning Rate与Batch Size成正比关系,引发了思考:究竟是Learning Rate本身需要Decay才能使训练过程继续,还是Learning Rate的Decay间接影响了Noise的Flucturate才能使训练过程继续?通过实验验证,真正影响训练过程的本质是Noise的Flucturate。因此我们考虑到Learning Rate与Batch Size的正比例关系,我们可以固定Learning Rate不变,而将Batch Size增加N倍来缩小Noise的Flucturate。定时增加Batch Size不但可以维持原有方式的Flucturate,还可以加速训练过程,减少Update的更新频次,增加计算通信占比,提高加速比。总体来说,该论文基于论文3为理论基础,提出了一种逐渐增加Batch Size提高计算加速比和收敛加速比的方法。

要点梳理

可以按顺序梳理成以下几个方面

理论基础

  • 从贝叶斯理论角度出发,论证Broad Minima相对于Sharp Minima具有更好的泛化性
  • 用贝叶斯理论解释泛化性是有效的
  • 贝叶斯理论与SGD
  • 随机偏微分方程的与Scaling Rule的推导

优化方法

  • 使用Large Batch Training提高训练速度
 

理论基础

理论基础来自于论文《A BAYESIAN PERSPECTIVE ON GENERALIZATION AND STOCHASTIC GRADIENT DESCENT》,这里只对重点内容进行记录。

从贝叶斯理论角度出发,论证broad minima相对于sharp minima具有更好的泛化性

内容

这部分公式较多,但确实是贝叶斯的理论基础,所以尽量以简单的形式展现出来。首先假设某模型M只有一个参数w,训练样本为x,Label为y,那么可以跟据贝叶斯公式直接写出下面的等式
其中等号右面分母上的第一项可以看做似然函数
Deep Learning中的Large Batch Training相关理论与实践
Deep Learning中的Large Batch Training相关理论与实践

一般情况下,我们对模型参数的分布会做高斯假设

Deep Learning中的Large Batch Training相关理论与实践

所以有

Deep Learning中的Large Batch Training相关理论与实践
Deep Learning中的Large Batch Training相关理论与实践

可以看出这个公式就是模型训练中Loss Function的主要部分,前面一项H(w;M)是Cost,而后面一项是正则项。我们要最小化Loss Function,本质上是最大化C(w;M)这一项。假设我们训练了两组模型参数,如何判断哪一个模型的泛化性更好?这里使用如下公式来判断。

Deep Learning中的Large Batch Training相关理论与实践

等式右面的第二项是对模型的偏好因子,在这里应该均设置为1,消除偏置的影响。右边第一项我们叫做Bayesian Evidence Ratio,它描述了训练样本改变了我们对模型先验偏好的程度。为了计算这个比值,我们需要计算分子和分母。

Deep Learning中的Large Batch Training相关理论与实践

使用泰勒展开式对C(w;M)在最优值w_0附近进行近似展开,得到如下式子。

Deep Learning中的Large Batch Training相关理论与实践
Deep Learning中的Large Batch Training相关理论与实践

至此,我们可以对上述公式的结果进行分析。上述公式中最后一项其实就是Occam Factor。通过分析我们也知道二阶导数正负衡量的是函数的凹凸性,而二阶导数的大小衡量和曲率相关。当C''(w_0)越大时,该位置附近就越弯曲,越接近sharp minima,进而导致P(y|x;M)的概率越低,这符合Occam Razor的原则,越简单的模型泛化性越好,这是因为简单的模型是在Broad Minima上。也可以提高正则系数对C''(w_0)进行惩罚,从而控制Occam factor,提高泛化性。当扩展到多个参数后,该公式如下所示。

Deep Learning中的Large Batch Training相关理论与实践
分析方法相同,不再赘述。

小结

这一部分作者从贝叶斯理论出发,从公式上推导出了Occam Razor的结论,并且论证了落入Sharp Minima的模型泛化性较差的原因,同时也得出了正则项对Sharp Minima具有惩罚作用。

用贝叶斯理论解释泛化性是有效的

内容

这里作者借鉴了论文《Understanding deep learning requires rethinking generalization》中的实验来从贝叶斯理论解释泛化性,与ICLR 2017的这篇Best Paper使用的Deep Learning Model不同,作者使用了最简单的线性模型进行实验,原因是线性模型在计算Bayesian Evidence的时候比Deep Learning简单很多。具体的实验配置可以参考论文,这里直接给出图表。
注:Bayesian Evidence实际上是Log Bayesian Evidence,对上面的结果取了对数。
Deep Learning中的Large Batch Training相关理论与实践

这个实验主要是为了证明Bayesian Evidence的曲线和Test Cross Entropy的变化趋势是一致的,并且也复现了《Understanding deep learning requires rethinking generalization》中呢Deep Learning Model的结果。

小结

这一节中的实验证明,使用贝叶斯理论解释泛化性是有效的,并且得出了预期一致的结果。

贝叶斯理论与SGD

内容

在得出Bayesian Evidence和泛化性是强相关关系的结论之后,作者再次对SGD产生了思考。因为无论是Large Batch还是Small Batch,他们都是Full Batch的近似结果,所以都会引入Noise。作者认为造成不同Batch Size产生不同泛化性的根本原因是Noise的Flucturate程度。一定程度的Noise可以逃离Sharp Minima,带领模型进入Bayesian Evidence较大的区域,即Broad Minima区域;而Batch Size越大,Noise的Flucturate就越小,就很容易陷入Sharp Minima。(这部分的公式推导在这里先不给出,因为这不是这篇文章的重点,有兴趣的同学可以关注这篇论文的附录A)这说明SGD的更新规则本身就带有了一些正则化效果,这个正则化的效果很大程度上来自于SGD本身引入的Noise。这与ICLR 2017 Best Paper《Understanding deep learning requires rethinking generalization》观察到的现象和得出的结论一致,该篇文章中主要思考的一个问题是,SGD在训练完全部样本之后,为什么不是记住所有的样本,而是还学到了一些泛化性?
回到这篇论文,作者认定一定存在一个最佳Batch Size,这个Batch Size既没有使模型进入Sharp Minima区域,又有一定的复杂性,使之让当前的模型效果最好。于是做了不同的实验,得到以下结果。
Deep Learning中的Large Batch Training相关理论与实践
Deep Learning中的Large Batch Training相关理论与实践

这些实验其实就是验证不同Batch Size训练出的模型在test集上的表现,并说明存在一个最佳的Batch Size,使用它训练出的模型,其泛化性优于其他Batch Size训练出的模型。

小结

这一部分从对贝叶斯与泛化性的思考入手,进而尝试解释SGD的特点,从而试图验证不同Batch Size对泛化性的影响。Batch Size的选取可以看成是Depth(Sharp)和Breadth(Broad)的Trade off,所以存在一个最佳的Batch Size,在其他超参数固定时使模型达到最好的泛化效果。

随机偏微分方程的与scaling rule的推导

内容

因为Batch Size的选取,从贝叶斯角度去理解,实际上就是Depth和Breadth的Trade off。所以可以更进一步的对SGD引入的Noise进行分析,进一步去探究这个Noise带来的Flucturate与哪些因素相关,这就需要和随机偏微分方程建立联系了。
首先,将SGD的update公式进行改写。
Deep Learning中的Large Batch Training相关理论与实践
Deep Learning中的Large Batch Training相关理论与实践
Deep Learning中的Large Batch Training相关理论与实践
其中N代表训练集的样本数,ε代表学习率。假设我们用<>代表期望的计算,那么我们有
Deep Learning中的Large Batch Training相关理论与实践

根据中心极限定理,我们可以得出以下结论

Deep Learning中的Large Batch Training相关理论与实践

所以标准的Stochastic Gradient Descent可以看成是标准梯度加上一个Noise,这个Noise就是α中的内容。下面进一步研究Noise的性质。

Deep Learning中的Large Batch Training相关理论与实践
Deep Learning中的Large Batch Training相关理论与实践

其中,F(w)为梯度的协方差项,δ_ij代表了Indicator,即当i=j时,δ_ij=1,否则等于0。这是因为样本和样本之间是相互独立的关系,所以协方差应该等于0。如果看不懂这个公式可以按照下面的原型推理,一目了然。

Deep Learning中的Large Batch Training相关理论与实践

根据协方差矩阵的可列可拆的性质,我们求得如下期望。

Deep Learning中的Large Batch Training相关理论与实践

至此,Noise的统计特性已经全部计算出来,下面需要和随机偏微分方程进行等价。首先,SGD的Update规则是一个离散的过程,不是连续的过程。如果我们把SGD的每一步想象成为一个连续的可微分的过程,每次Update一个偏微分算子,那么可以将上述学习率为ε的Update公式看成是某个微分方程的定积分结果,下面先介绍这个偏微分方程(这个偏微分方程的产生来自于《Handbook of Stochastic Methods》)。

Deep Learning中的Large Batch Training相关理论与实践

这里t是连续的变量,η(t)代表了t时刻的Noise,具有如下性质。

Deep Learning中的Large Batch Training相关理论与实践

因为我们知道Noise的期望必定等于0,而方差会有个波动的Scale,且波动的大小是以F(w)有关,所以这个Scale我们用g来表示,即Flucturate。而SGD的Update规则可以改写如下所示。

Deep Learning中的Large Batch Training相关理论与实践

为了探求g的变化因素,我们需要将偏微分方程的最后一项的方差和SGD的α方差对应起来,得到

Deep Learning中的Large Batch Training相关理论与实践

上面最后的积分公式推导可能会有些迷惑,大概是会迷惑在积分的方差是如何化简到二重积分这一过程,其实积分符号只是个对连续变量的求和过程,所以依然可以使用协方差的可列可拆的性质,如果还是不习惯,将积分符合和dt换成求和符号再去使用协方差公式即可轻松得到结论。

所以,我们得到了相当重要的结论,这是在一定程度上能够解释Learning Rate Scaling Rule的结论。
Deep Learning中的Large Batch Training相关理论与实践

所以,我们得到了结论,SGD引入了一些Noise,这个Noise具有一定的Flucturate,它的大小是和Batch Size成反比,与Learning Rate成正比。

小结

这一节使用偏微分方程和SGD的更新规则,经过一系列的数学推导,得到了SGD引入的Noise对更新过程的Flucturation大小与Batch size和Learning rate的关系。这是这篇论文十分重要的结论,也是Learning Rate Scaling Rule的理论基石。

理论总结

至此,理论基础部分梳理完毕,虽然公式较多较为复杂,但是结论却非常简单。作者从贝叶斯理论的角度出发,推导出了Occam Razor的形式表达,并从公式上论证了Sharp Minima相对于Broad Minima泛化性差的原因。而后又验证了Bayesian Evidence和模型泛化性一致的结论,进而从贝叶斯理论的角度对SGD的更新过程进行了猜测:SGD会引入Noise,而正是Noise的Flucturate帮助模型在更新过程中逃离Sharp Minima,进入更高的Bayesian Evidence区域,即Broad Minima,所以指出Batch Size的选择实际上是Noise Flucturate的调整,本质上是Sharp Minima和Broad Minima的Trade off。最后作者通过将SGD更新公式进行改写,并联合偏微分方程,得出了Noise的Fluctruate的形式表达,它Batch Size成反比,和Learning Rate成正比。
之前FAIR发表的论文《Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour》中提出了Learning Rate Scaling Rule在ResNet上具有很好的效果,该论文在实验上做的比较充分,但是在理论上并没有特别Solid,而Google的这篇论文可以作为它的理论基石之一。

优化方法

优化方法来自论文《Don't Decay the Learning Rate, Increase the Batch Size》,这篇论文在理解完前一篇论文之后会显得非常简单,完全是一篇实验性论文,实验做得较为充分,这里只会对重要内容做个简单的梳理。

理论基础公式

对于SGD来说,Flucturation形式表达为
Deep Learning中的Large Batch Training相关理论与实践

对于Momentum-SGD来说,形式表达为(公式推导来自于langvein动力学)

Deep Learning中的Large Batch Training相关理论与实践

Large batch training的优化原理

无论是SGD还是Momentum-SGD,我们都可以发现g与Batch Size成反比,与Learning Rate成正比,而在一般的Deep Learning Model训练过程中,会在固定轮数对Learning Rate做Decay,这个过程让作者引发了思考,究竟在训练过程中,泛化性的提升是由于Learning Rate做Decay导致的,还是g发生变化导致的?如果是后者,那么定时增加Batch Size也应该会达到同样的效果,因此作者做了几组实验。
Deep Learning中的Large Batch Training相关理论与实践

作者做了三组实验,一组是标准的对Learning Rate做Decay,一组是固定Rearning Rate不变,在原来发生Learning Rate Decay的轮数将Batch Size扩大N倍(N是Learning Rate Decay的Factor,即与Learning Rate的Decay为相同力度)。另一组是二者的结合Hybrid,即先Learning Rate Decay,后变化Batch Size。实验证明三者的泛化性曲线相同,所以证明了Learning Rate Decay实际上是对g做了Scale down。然而增加Batch Size不但可以达到同样的效果,还能提高计算通信占比,并且在整体训练过程中减少Update的次数,这是Increase Batch Size Training的优化点。

关于Momentum-SGD

在Momentum-SGD的flucturation形式表达中,我们还看到了momentum的作用,即增加m的值可以增加g的值。但是实验证明,增加m同时扩大batch size得到的泛化性相对于改变learning rate和batch size要差一些。这是因为提高momentum会使Momentum-SGD中的accumulator需要更多的轮数才能到达稳定的状态,而在到达稳定状态之前,update的scale是会被supressed的,作者在论文附录中论证了这一观点,这里不再详细赘述。后续的实验也证明了这一点。
Deep Learning中的Large Batch Training相关理论与实践

更大Batch Size和消除Warm Up

在论文《Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour》中,作者实验的最大Batch Size为8192。然而在这篇论文中,作者使用更大的初始Batch Size(最大尝试到65536)对ImageNet进行训练,并且在固定的轮数对Noise做Decay(增加Batch Size)。作者消去了Warm Up的过程,但是引入了Mometum的超参调优,当使用更大Batch Size时,不仅调整初始Learning Rate,还增加m值来进一步放大Noise,帮助训练过程逃离Sharp Minima。实验效果如下。
Deep Learning中的Large Batch Training相关理论与实践

小结

此篇论文更像是《A BAYESIAN PERSPECTIVE ON GENERALIZATION AND STOCHASTIC GRADIENT DESCENT》工作的延续,以该篇论证的理论基础出发,得出了一种提高训练计算加速比和收敛加速比的方法。结论和实验比较简单,但背后的数学推导较为复杂。

总结

工业界的分布式算力提升对Large Batch Training提出了需求,因为增加Batch Size显然是提高计算通信占比的最佳方式,所以Large Batch Training固有的收敛性问题就成为了学术界研究的重点方向。本文通过梳理近些年来学术界对Large Batch Training的论文研究,从理论角度阐述了Large Batch Training造成收敛性较差的原因——容易陷入Broad Minima。而Google发表的论文从贝叶斯角度给出了另外的解释——不同Batch Size训练引入的Noise不同造成Fluctuate也不同,最终导致收敛性的不同。为了验证这一观点,Google又从实践角度给出了验证——通过固定Learning Rate,逐步增大Batch Size来稳定Fluctuate,达到使用大Batch Size加速训练的目的。截止到目前,这些理论方面的论证和解释依然处于蓬勃发展之中,未来还会有更深入研究在学术界中出现。
 

Deep Learning中的Large Batch Training相关理论与实践的更多相关文章

  1. ON LARGE BATCH TRAINING FOR DEEP LEARNING&colon; GENERALIZATION GAP AND SHARP MINIMA

    目录 概 主要内容 一些解决办法 Keskar N S, Mudigere D, Nocedal J, et al. On Large-Batch Training for Deep Learning ...

  2. Deep Learning and Shallow Learning

    Deep Learning and Shallow Learning 由于 Deep Learning 现在如火如荼的势头,在各种领域逐渐占据 state-of-the-art 的地位,上个学期在一门 ...

  3. AndrewNG Deep learning课程笔记 - CNN

    参考, An Intuitive Explanation of Convolutional Neural Networks http://www.hackcv.com/index.php/archiv ...

  4. Deep Learning in NLP (一)词向量和语言模型

    原文转载:http://licstar.net/archives/328 Deep Learning 算法已经在图像和音频领域取得了惊人的成果,但是在 NLP 领域中尚未见到如此激动人心的结果.关于这 ...

  5. (转)分布式深度学习系统构建 简介 Distributed Deep Learning

    HOME ABOUT CONTACT SUBSCRIBE VIA RSS   DEEP LEARNING FOR ENTERPRISE Distributed Deep Learning, Part ...

  6. Deep Learning In NLP 神经网络与词向量

    0. 词向量是什么 自然语言理解的问题要转化为机器学习的问题,第一步肯定是要找一种方法把这些符号数学化. NLP 中最直观,也是到目前为止最常用的词表示方法是 One-hot Representati ...

  7. Word2Vec之Deep Learning in NLP (一)词向量和语言模型

    转自licstar,真心觉得不错,可惜自己有些东西没有看懂 这篇博客是我看了半年的论文后,自己对 Deep Learning 在 NLP 领域中应用的理解和总结,在此分享.其中必然有局限性,欢迎各种交 ...

  8. deep learning深度学习之学习笔记基于吴恩达coursera课程

    feature study within neural network 在regression问题中,根据房子的size, #bedrooms原始特征可能演算出family size(可住家庭大小), ...

  9. 学习Data Science&sol;Deep Learning的一些材料

    原文发布于我的微信公众号: GeekArtT. 从CFA到如今的Data Science/Deep Learning的学习已经有一年的时间了.期间经历了自我的兴趣.擅长事务的探索和试验,有放弃了的项目 ...

随机推荐

  1. table表格中的内容溢出布局方式

    什么是内容溢出呢?其实就是当文字很多的时候,如果内容区域只有那么长,那么多出的部分以点点点代替. 这次做的案例是在table里面,我们知道当我们在table里输入过多的文字内容的时候会撑乱表格,例如一 ...

  2. getopts

    http://blog.sina.com.cn/s/blog_81c2cf020100v0wh.html http://www.cnblogs.com/xiangzi888/archive/2012/ ...

  3. &lbrack;转&rsqb; shell字符串操作方法,以及实例

    每一种语言都有他独自的字符串操作方法,shell也一样,下面以以例子的方式,简单介绍常用方法. 1,取得字符串长度 string=abc12342341 //等号二边不要有空格 echo ${#str ...

  4. winform 解决界面闪动、提升加载速度 分类: WinForm 2015-02-03 16&colon;34 161人阅读 评论&lpar;0&rpar; 收藏

    说明: 从一个技术交流群里获得,经验证效果不错. //作用 加快界面加载 protected override CreateParams CreateParams          {         ...

  5. Windows下Flume的安装

    flume(日志收集系统) Flume是Cloudera提供的一个高可用的,高可靠的,分布式的海量日志采集.聚合和传输的系统,Flume支持在日志系统中定制各类数据发送方,用于收集数据:同时,Flum ...

  6. hdu-2255(带权二分图)

    题解:板子题.... #include<iostream> #include<cstring> #include<cstdio> #include<queue ...

  7. redis键值操作

    1.1. redis键值操作 1.1.1. keys patten 查询相应的key 可以精确的查,也可以模糊的查 1.1.1.1. 通配符:* ? [] 在redis里,模糊查询key的时候有3个通 ...

  8. 关于static、内部类

    1.static不能修饰外部类的原因 static修饰的成员是属于某个类的.而外部类的上一级程序单元是包,所以static不能修饰外部类. 2.外部类,内部类有不同访问权限的原因 外部类的上一级程序单 ...

  9. javascript易混淆的split&lpar;&rpar;、splice&lpar;&rpar;、slice&lpar;&rpar;方法详解

    很多时候,一门语言总有那么些相似的方法,容易让人傻傻分不清楚,尤其在不经常用的时候.而本文主要简单总结了JavaScript中的关于字符串和数组中三个容易混淆的方法.旨在方便查阅,在容易混淆的时候有据 ...

  10. java规范之checkstyle

    1. 概述 随着中心的代码规范的建立和实施,项目组对代码规范要求,以及软件工程师们对自身代码的编写规范重要性的认知,“代码规范”已经成为了中心的一个“热词”.然后怎么才能写出有规范的代码,怎么才能养成 ...