前言
- 本文属于我迁移学习专栏里的一篇,该专栏用于记录本人研究生阶段相关迁移学习论文的原理阐述以及复现工作。
- 本专栏的文章主要内容为解释原理,论文具体的翻译及复现代码在文章的github中。
原理阐述
文章介绍
- 这篇文章于2018年发表在ICML会议,作者是Shaoan Xie、Zibin Zheng、Liang Chen、Chuan Chen。
- 这篇文章解决的主要问题是如何利用伪标签来进行域适应。之前的方法都忽略了样本的语义信息,比如之前的算法可能将目标域的背包映射到源域的小汽车附近。 这篇文章最要的贡献就是提出了 moving semantic transfer network 这个网络,简称mstn,其主要是通过对齐源域(有标签)和 目标域(伪标签,网络预测一个标签)相同类别的中心,以学习到样本的语义信息。
模型结构
- 模型是这样的:
模型总述
- 上述模型的G特征提取器和F标签分类器以及D域分类器与DANN中的特征提取器、标签分类器和全局域分类器是一样的,这里不展开研究了。
- 这个论文有价值的地方在于使用了伪标签,提出了semantic transfer loss,这个论文中的方法其实我也有考虑到过,我是受了DAAN的启发,但DAAN应该是受了该文的启发,因为DAAN是2019年发表的。DAAN中的局部域分类器也是将样本的每个类单独分开计算损失,但是DAAN计算的是域分类损失,而MSTN考虑的是MSE,因为相同类别经过特征提取之后的特征应当是相近的,这对应域适应中的条件概率损失。
- 但是MSTN考虑到了两个问题,1.每次抽取样本可能会使得某些类别没有抽取到样本,那么就无从计算MSE。2.伪标签可能是不准确的,这样可能导致相反的效果,比如使一个书包的特征和一个汽车的特征进行对齐。
- MSTN的解决办法非常有意思:
对每个类维护一个全局特征CTk或者CSkC^k_{T}或者C^k_{S}CTk或者CSk,每次使用CTk或者CSkC^k_{T}或者C^k_{S}CTk或者CSk来计算损失,CTk或者CSkC^k_{T}或者C^k_{S}CTk或者CSk的计算同时考虑当前的CTk或者CSkC^k_{T}或者C^k_{S}CTk或者CSk和本次根据样本生成的平均特征。所以就算本次抽取样本中没有某一类的样本,也可以根据该类上一次的CTk或者CSkC^k_{T}或者C^k_{S}CTk或者CSk来计算,同时假如有错误的伪标签也因为占比不大所以影响不大。 - 其实MSTN这种解决办法也是尽可能的削弱错误影响,并没有根本上解决这些问题。
超参数设置
- 学习率采用衰减,
p是迭代次数占总的比例,学习率每次迭代更新一次,
def train(epoch, model, sourceDataLoader, targetDataLoader,DEVICE,args):learningRate=args.lr/math.pow((1+10*(epoch-1)/args.epoch),0.75)
- 损失函数
三项分别是标签分类损失,域分类损失,semantic transfer loss,其中γ=λγ=λγ=λ,λ遵循下面的公式:
里面的上图的γ可不是损失函数中的γ,上图的p设置为当前batchid占总的比例,如下代码所示:
lenSourceDataLoader = len(sourceDataLoader)for batch_idx, (sourceData, sourceLabel) in tqdm.tqdm(enumerate(sourceDataLoader),total=lenSourceDataLoader,desc='Train epoch {}'.format(epoch),ncols=80,leave=False):p = float(batch_idx + 1 + epoch * lenSourceDataLoader) / args.epoch / lenSourceDataLoaderalpha = 2. / (1. + np.exp(-10 * p)) - 1
- CNN 采用的是AlexNet作为基本结构,fc7后面接了一个bottleneck layer(瓶颈层,主要作用是降维)。
- 鉴别器,我们采用的是RevGard相同的结构:x-》1024-》1024-》2
- 超参数的设置:θ = 0.7。
总结
- 该文总体来说提供了一种思路,但是我觉得伪标签的问题其实并没有办法真正解决,会限制该类模型的上限并不会很高。