【论文阅读】Generative Pretraining from Pixels

时间:2024-03-17 14:23:05

Generative Pretraining From Pixels

引用: Chen M, Radford A, Child R, et al. Generative pretraining from pixels[C]//International conference on machine learning. PMLR, 2020: 1691-1703.

论文链接: http://proceedings.mlr.press/v119/chen20s.html

简介

受自然语言中无监督表示学习进展的启发,作者研究了类似的模型是否能够学习图像的有用表示,训练了一个序列Transformer来自回归地预测像素,而不包含2D输入结构的知识。尽管是在低分辨率的ImageNet上进行训练,没有标签,但实验发现一个GPT-2规模的模型通过线性探测、微调和低数据分类学习,学习到了强大的图像表示。在CIFAR-10上,使用线性探测达到了96.3%的准确率,超过了监督的Wide ResNet,全微调达到了99.0%的准确率,与*监督预训练模型相匹配。同时,作者还在ImageNet上与自监督基准进行了比较,通过将像素替换为VQVAE编码,在线性探测特征时达到了69.0%的top-1准确率。

Method

论文的方法包括预训练阶段和微调阶段。在预训练中,探索了auto-regressive和BERT,还应用序列Transformer架构来预测像素,而不是语言标记。而测量表征质量的一种方法是对图像分类进行微调。微调为模型添加了一个小的分类头,用于优化分类目标并调整所有权重。当与早停结合使用时,预训练可以被视为一种有利的初始化或正则化。另一种方法则使用预先训练的模型作为特征提取器。特别地,给定标记的示例(X,Y),将模型应用于X以产生特征fx。然后,在(fx,Y)上训练线性分类器。线性探测源自一种直觉,即好的特征应该线性地分离转移任务的类别。此外,线性探测有助于将特征质量与模型架构区分开来:在微调中,一个模型可能优于另一个模型,因为它的架构更适合下游任务,而不是因为更好的预训练。

Pre-training

给定由高维数据 X = ( x 1 , . . . , x n ) X=(x_1,...,x_n) X=x1,...,xn组成的未标记数据集 X X X,可以选择集合 [ 1 , n ] [1,n] [1n]的排列π,并对密度 p ( x ) p(x) p(x)进行自回归建模:

当处理图像时,选择 1 ≤ i ≤ n 1≤i≤n 1in的单位置换 π i = i π_i=i πi=i,也称为光栅顺序。通过最小化数据的负对数似然来训练模型:

对于BERT目标,其采样为子序列 M ⊂ [ 1 , n ] M⊂[1,n] M[1n],使得每个索引 i i i独立地具有出现在 M M M中的概率为0.15。称 M M M为BERT掩码,并且通过最小化以“未掩码”为条件的“掩码”元素 x M x_M xM的负对数似然来训练模型:

Architecture

transformer decoder取一个输入序列 x 1 , . . . , x n x_1,...,x_n x1,...,xn,并为每个位置产生 d d d维嵌入。解码器被实现为 L L L个块的堆栈,其中第 l l l个产生中间嵌入 h l 1 , . . . , h l n h_l^1,...,h_l^n hl1,...,hln也是维数d。我们使用transformer decoder块的GPT-2公式,它作用于输入张量 h l h_l hl如下:

特别地,**层规范在注意力机制和MLP之前,并且所有运算都位于残差路径上。**这样的配置可以轻松地缩放transformer。

序列元素之间的唯一混合发生在注意力操作中,为了确保在训练AR目标时进行适当的调节,将标准的上三角掩码应用于注意力逻辑的n×n矩阵。当使用BERT目标时,不需要注意logit掩蔽:在将内容嵌入应用于输入序列之后,将M中的位置清零。

此外,由于学习了每个序列元素的独立位置嵌入,BERT模型没有位置归纳偏差(即它是置换不变的)。换句话说,位置之间的任何空间关系都必须由模型在训练时学习。对于AR模型来说,这并不完全正确,因为选择光栅顺序也会修复预先指定的条件顺序。然而,置换不变性是与卷积神经网络形成强烈对比的一种特性,卷积神经网络包含了特征应该从空间上接近的元素产生的归纳偏差。

Fine-tuning

当进行微调时,我们对序列的 n L n^L nL维度进行平均池化,以提取每个示例的特征的d维向量。然后,学习从 f L f_L fL到类别的logits的投影,使用它来最小化交叉熵损失。

Linear Probing

为线性探测提取固定特征遵循与微调类似的过程,只是平均池化并不总是在最后一层:

其中0≤l≤l。实验表明,最佳特征通常位于网络的中间。在微调中,投影这些中间特征以产生类logits。

实验

在这里插入图片描述

表征质量在很大程度上取决于提取特征的层。与监督模型相比,这些生成模型的最佳表征位于网络的中间层。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述