【论文阅读】Segment Anything论文梳理

时间:2024-03-09 13:55:04

A. Segment Anything Model and Task Details

【图像编码器】

  • 一般来说,图像编码器可以是任何输出C×H×W图像嵌入的网络。基于不同规模的和强大的预训练,我们使用MAE 预训练视觉transformer(ViT),以最小的适应来处理高分辨率输入,特别是ViT-H/16,有14×14 的windowed attention和4个equally-spaced global attention blocks。图像编码器的输出是输入图像的16倍下采样的image embedding。由于我们的运行时目标是实时处理每个提示,因此我们可以提供大量的图像编码器片段,因为它们每幅图像只计算一次,而不是每个提示只计算一次。
    根据标准的实践(例如,[40]),我们使用了1024×1024的输入分辨率,这是通过重新缩放图像和填充较短的边而获得的。因此,图像嵌入值为64×64。为了减少通道维度,在[62]之后,我们在256通道的3×3卷积之后,使用256通道的1×1卷积。每个卷积之后都是一个层的归一化[4]。

【提示编码器】

  • 稀疏提示被映射到256维的向量嵌入如下。
    • 一个点被表示为:该点的位置编码(positional encoding)+ 两个网络学习到的嵌入(learned embedding)之一(该点是在前景中还是在背景中)。
    • box 由嵌入对表示: (1)其左上角的位置编码与表示“左上角”的learned embedding的和,(2)其右下角的位置编码与表示“右下角”的learned embedding的和。
    • 为了表示*形式的文本,我们使用了来自CLIP 的文本编码器(任何文本编码器通常都是可能的)。
  • 密集的提示(即掩码)与图像具有空间对应关系。
    相比与输入图片 输入的掩码为4X下采样,然后使用两个2×2,stride-2卷积进行下采样,输出通道分别为4和16。最后的1×1卷积将通道维度映射到256。每一层都会使用GELU激活[50]和层归一化。然后,将掩码与图像嵌入进行元素级相加。如果没有掩码提示,则在每个图像嵌入位置添加一个表示“无掩码”的学习嵌入。

【轻量级解码器】

  • 该模块有效地将图像嵌入和一组提示嵌入映射到一个输出掩码。为了结合这些输入,我们从 Transformer segmentation models[14,20]中获得灵感,并修改了一个标准的 Transformer decoder。在应用我们的解码器之前,我们首先在提示嵌入集合中 嵌入一个学习到的输出token embedding,该嵌入将用于解码器的输出,类似于[33]中的[class]token。为简单起见,我们将这些嵌入(不包括图像嵌入)统称为“tokens”。

    在这里插入图片描述
    我们的解码器设计如图14所示。每个解码器层执行4个步骤:
    • (1) 在 tokens 上进行自注意力(self-attention)
      (2)从 tokens (作为查询)到 image embedding 上进行交叉注意力(cross-attention)
      (3)点级MLP更新每个标记,
      (4)从 image embedding(作为查询)到tokens上进行交叉注意到标记。
    • 最后一步是使用prompt information更新 image embedding。
      在交叉注意过程中,将图像嵌入视为一组 6 4 2 64^2 642的 256维向量。每个自/交叉注意和MLP都有一个残差连接、层归一化和dropout=0.1。
      下一个解码器层从上一层中获取更新的tokens和更新的图像嵌入。我们使用了一个两层解码器。
      为了确保解码器能够访问关键的几何信息,当位置编码参与注意力层时,它们将被添加到图像嵌入中。此外,整个原始提示tokens(包括它们的位置编码)都会被重新添加到更新后的tokens中。这强烈地依赖于提示tokens的几何位置和类型。
      在运行解码器后,我们用两个转置的卷积层对更新后的图像嵌入上采样4×(现在它相对于输入图像缩小了4×)。然后,tokens再次关注图像嵌入,我们将更新后的输出tokens嵌入传递给一个小的3层MLP,该MLP输出一个与放大的图像嵌入的通道维数相匹配的向量。最后,我们预测了一个具有空间点级乘积的掩码,在方法的图像嵌入和MLP的输出之间。

      该transformer使用的嵌入尺寸为256。transformer MLP块有一个很大的内部尺寸为2048,但MLP只应用于有相对较少(基本小于20)的提示 tokens。然而,在交叉注意层中,我们有一个64×64的图像嵌入,为了提高计算效率,我们将查询、键和值的通道维数降低2×到128。所有的注意力层都使用8个头。
      用于放大的输出图像嵌入的转置卷积为2×2,步幅为2,输出通道尺寸分别为64和32,并具有GELU激活。每层后跟着归一化。

【使模型模糊感知】

  • 如前所述,单个输入提示符可能是模糊的,因为它对应于多个有效的掩码,并且模型学习到这些掩码的平均。我们通过一个简单的修改来消除这个问题:我们不是预测单个掩模,而是使用少量的输出tokens并同时预测多个掩模。默认情况下,我们预测三个掩码,因为我们观察到三个层(整个、部分和子部分)通常足以描述嵌套掩码。在训练过程中,我们计算标签和每个预测掩码之间的损失,但只对最小的损失反向传播。这是一种用于具有多个输出[15,45,64]的模型的常见技术。为了在应用程序中使用,我们希望对预测的掩码进行排序,因此我们添加一个小头(在一个额外的输出tokens上操作),它估计每个预测掩码和它所覆盖的对象之间的IoU。
  • 有多个提示的模糊性比较罕见,而且三个输出掩码通常会变得相似。为了减少训练时退化损失的计算,并确保单个明确的掩码接收到一个规则的梯度信号,当给出多个提示时,我们只预测单个掩模。为了预测额外的掩码,是通过添加第四个输出tokens来实现的。单个提示不会返回第四个提示码,它是多个提示返回的唯一掩码。

【损失】

  • 我们使用focal loss和dice loss的线性组合来监督掩码预测,其中focal loss和dice loss的比例为20:1,遵循[20,14]。与[20,14]不同,我们观察到在每个解码器层后的辅助深度监督是没有帮助的。[ IoU预测头输出]与[ 预测掩码和标签掩码的IOU ]的均方误差,作为损失。它被添加到掩码损失的一个恒定的比例因子为1.0。

【训练算法】
根据最近的方法[92,37],我们在训练期间模拟了一个交互式分割设置。

  • 首先,以等概率随机选择前景点或边界框。点从标签掩码中均匀采样。将框作为真实掩码的边界框,在每个坐标中添加随机噪声,标准差等于框边长的10%,最大为20像素。这种噪声轮廓是实例分割等应用程序之间的合理设置,实例分割会在目标对象周围产生一个紧密的box,而在交互式分割中,用户可能会画一个松散的box。
  • 在从第一个提示符进行预测后,从前一个掩码预测和标签掩码之间的误差区域中均匀地选择后续的点。如果误差区域是假阴性或假阳性,则每个新点分别是前景或背景。我们还提供了来自前一个迭代的掩码预测,作为我们的模型的附加提示。为了给下一次迭代提供最大的信息,我们提供了无阈值掩码日志,而不是二值化掩码。当返回多个掩码时,传递给下一个迭代并用于采样下一个点的掩码是预测有效值最高的掩码。
  • 我们发现在8个迭代采样点(我们已经进行了测试到16个)后,收益递减。此外,为了鼓励模型从所提供的掩码中获益,我们还使用了另外两个迭代,其中没有额外的点被采样。其中一个迭代被随机插入到8个迭代采样点中,另一个总是在最后。这给出了11次迭代:一个采样的初始输入提示,8个迭代采样点,两个迭代没有新的外部信息,因此它可以学习改进自己的掩码预测。我们注意到,使用相对较多的迭代是可能的,因为我们的轻量级掩码解码器需要的计算量不到图像编码器的1%,因此,每次迭代只增加了很小的开销。这与以前的交互式方法不同,即每个优化器更新[70,9,37,92]只执行一个或几个交互式步骤。

【训练时设置】

  • 优化器:使用AdamW(β1 = 0.9,β2 = 0.999)
  • 预热学习率:线性学习速率进行250次迭代
  • 学习速率的衰减:预热后的初始学习率(lr)为8e−4。会训练90k次迭代(∼2SA-1B轮次),并在60k次迭代和86666次迭代时将lr减少了10倍。
  • batch size:256张图像。
  • 权重衰减:为了规范SAM,我们将权重衰减(wd)设置为0.1,并以0.4的速率应用drop path [53]。
  • 分层学习速率衰减[5](ld)为0.8。
  • 不应用数据增强。
  • 从MAE预训练的ViT-H初始化SAM。由于大图像编码器和1024×1024输入大小,我们将训练分布在256个gpu上。为了限制GPU内存的使用,我们训练每个GPU使用多达64个随机采样的掩码。
  • 此外,我们发现,轻微过滤SA-1B掩码,丢弃任何覆盖90%以上图像的掩码,可以定性地改善结果。

对于消融实验和其他训练上的变化(例如,text到掩码D.5),我们和上面设置不一致的如下。当只使用来自第一和第二个数据引擎阶段的数据进行训练时,我们使用 input with large-scale jitter [40]增加输入,其尺度范围为[0.1,2.0]。直观地说,当训练数据更有限时,数据增强可能会有所帮助。为了训练ViT-B和ViT-L,我们使用了180k次迭代,批处理大小为128次,分布在128个gpu上。我们分别为ViT-B/L设置了lr=8e−4/4e−4、ld = 0.6/0.8、wd = 0.1和dp = 0.6/0.4。