Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

时间:2024-03-28 09:06:55

PAPER:https://arxiv.org/abs/1708.02551
CODE:https://github.com/DavyNeven/fastSceneUnderstanding

一、整体框架、流程

本文提出了一种基于量度学习Metic learning的用于语义分割和实例分割的方法。自定义LOSS训练CNN学习到一种metric,即从像素空间到高纬度空间的映射。使得同类(同实例)物体中的像素映射到高维空间后,得到的embedding vector之间的距离(L1、L2距离)相近,从而使用聚类的方式完成分割任务。

其流程简介如下:

Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

首先,上图是输入图片,下图是实例分割的ground truth。

Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

下面用一个二维的embedding空间来形象的解释系统的工作流程。如图所示,第二行显示的是通过不同迭代次数后的CNN所得的映射二维空间(这里为了方便显示和理解,实际中使用的维度往往更高),即将3通道的像素向量[r,g,b]映射到二维向量[u,v]。第二行显示的颜色即是将u,v分别赋给R,G通道得到的图像。直观地,假如最终学习到的CNN可以將Input图像中的不同instance中的像素映射为不同的颜色(大致如第二行最后一列的图像)即可完成Instance的分割。

第一行则是将embedding vector在二维坐标系的显示,一个点代表一个vector。为了方便区分,使用ground truth的颜色表示这些vector(没用颜色也没有关系)。可以看到,在迭代多次后,不同instance得到的embedding vectors聚集到了不同的区域内,这时使用聚类算法(例如mean-shift)就可以将像素聚类,从而实现分割。

下图显示的就是不同迭代次数下的聚类结果,得到这些结果的方法将在后面具体分析,这里我们先大致了解算法的流程即可。

Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

二、Metric learning 如何得到上节提到的映射

通过第一节的描述,我们看到embedding vector需要具有的特性是:

  1. 同instance内部像素的embedding vector在映射空间中要尽可能的临近(L1、L2距离)
  2. 不同instance的mean embedding vector(即在映射空间中聚类的中心点)要尽可能的远离

即希望embedding vectors 在映射空间中的位置如下图所示

Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

图中所示的intra-cluster pull force 即为上述条件1,inter-cluster push force即为上述条件2

为了满足这两个条件,可以进一步设计Loss function。具体公式如下:

Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

其中

​​​​​​ Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】代表groundtruth中instance数目,Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】代表某个instance中的像素个数。Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】代表instance中第i个像素产生的embedding vector,Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】是groundtruth中该instance的所有像素对应的embedding vectors在映射空间中的中心(mean vector)。

最后一项为正则项,目的是让映射空间中每个cluster(对应于instance)的中心与原点的距离不要过远。

训练过程中使用ground truth作为instance mask。上述LOSS仅涉及同一类object的不同instance之间的聚类学习。如果要用到多类object,则需要分别对每个类计算LOSS并累加。

三、测试过程(推理Inference)

  1. Input inference 产生embedding vectors (output featuremap)
  2. 使用Mean-shift算法进行聚类,聚类参数bandwidth使用训练时的 cluster 间margin

产生几个cluster 即有几个instance。如果涉及多类object,则首先进行semantic segmentation,然后使用semantic segmentation结果作为不同类的mask,再进一步进行Mean-shift聚类。

四、误差分析

  1. semantic segmentation
  2. Mean-shift

Semantic Instance Segmentation with a Discriminative Loss Function【论文详解】

前两行使用resnet38的semantic segmentation结果

后两行使用groundtruth的分割结果

Mean-shift代表使用聚类算法估计cluster中心,并用threshold筛选的结果

center threshold代表使用groundtruth得到的中心,并用threshold筛选的结果