热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

迁移学习论文(五):LearningSemanticRepresentationsforUnsupervisedDomainAdaptation论文原理及复现工作

目录前言原理阐述文章介绍模型结构模型总述超参数设置总结前言本文属于我迁移学习专栏里的一篇,该专栏用于记录本人研究生阶段相关迁移学习论文的原理阐述以及复现工作。本专栏

目录

  • 前言
  • 原理阐述
    • 文章介绍
    • 模型结构
      • 模型总述
    • 超参数设置
  • 总结


前言
  • 本文属于我迁移学习专栏里的一篇,该专栏用于记录本人研究生阶段相关迁移学习论文的原理阐述以及复现工作。
  • 本专栏的文章主要内容为解释原理,论文具体的翻译及复现代码在文章的github中。

原理阐述

文章介绍


  • 这篇文章于2018年发表在ICML会议,作者是Shaoan Xie、Zibin Zheng、Liang Chen、Chuan Chen。
  • 这篇文章解决的主要问题是如何利用伪标签来进行域适应。之前的方法都忽略了样本的语义信息,比如之前的算法可能将目标域的背包映射到源域的小汽车附近。 这篇文章最要的贡献就是提出了 moving semantic transfer network 这个网络,简称mstn,其主要是通过对齐源域(有标签)和 目标域(伪标签,网络预测一个标签)相同类别的中心,以学习到样本的语义信息。

模型结构


  • 模型是这样的:
    在这里插入图片描述

模型总述


  • 上述模型的G特征提取器和F标签分类器以及D域分类器与DANN中的特征提取器、标签分类器和全局域分类器是一样的,这里不展开研究了。
  • 这个论文有价值的地方在于使用了伪标签,提出了semantic transfer loss,这个论文中的方法其实我也有考虑到过,我是受了DAAN的启发,但DAAN应该是受了该文的启发,因为DAAN是2019年发表的。DAAN中的局部域分类器也是将样本的每个类单独分开计算损失,但是DAAN计算的是域分类损失,而MSTN考虑的是MSE,因为相同类别经过特征提取之后的特征应当是相近的,这对应域适应中的条件概率损失。
  • 但是MSTN考虑到了两个问题,1.每次抽取样本可能会使得某些类别没有抽取到样本,那么就无从计算MSE。2.伪标签可能是不准确的,这样可能导致相反的效果,比如使一个书包的特征和一个汽车的特征进行对齐。
  • MSTN的解决办法非常有意思:
    在这里插入图片描述
    对每个类维护一个全局特征CTk或者CSkC^k_{T}或者C^k_{S}CTkCSk,每次使用CTk或者CSkC^k_{T}或者C^k_{S}CTkCSk来计算损失,CTk或者CSkC^k_{T}或者C^k_{S}CTkCSk的计算同时考虑当前的CTk或者CSkC^k_{T}或者C^k_{S}CTkCSk和本次根据样本生成的平均特征。所以就算本次抽取样本中没有某一类的样本,也可以根据该类上一次的CTk或者CSkC^k_{T}或者C^k_{S}CTkCSk来计算,同时假如有错误的伪标签也因为占比不大所以影响不大。
  • 其实MSTN这种解决办法也是尽可能的削弱错误影响,并没有根本上解决这些问题。

超参数设置


  • 学习率采用衰减,
    在这里插入图片描述
    p是迭代次数占总的比例,学习率每次迭代更新一次,

def train(epoch, model, sourceDataLoader, targetDataLoader,DEVICE,args):learningRate=args.lr/math.pow((1+10*(epoch-1)/args.epoch),0.75)

  • 损失函数在这里插入图片描述
    三项分别是标签分类损失,域分类损失,semantic transfer loss,其中γ=λγ=λγ=λ,λ遵循下面的公式:
    在这里插入图片描述
    里面的上图的γ可不是损失函数中的γ,上图的p设置为当前batchid占总的比例,如下代码所示:

lenSourceDataLoader = len(sourceDataLoader)for batch_idx, (sourceData, sourceLabel) in tqdm.tqdm(enumerate(sourceDataLoader),total=lenSourceDataLoader,desc='Train epoch {}'.format(epoch),ncols=80,leave=False):p = float(batch_idx + 1 + epoch * lenSourceDataLoader) / args.epoch / lenSourceDataLoaderalpha = 2. / (1. + np.exp(-10 * p)) - 1

  • CNN 采用的是AlexNet作为基本结构,fc7后面接了一个bottleneck layer(瓶颈层,主要作用是降维)。
  • 鉴别器,我们采用的是RevGard相同的结构:x-》1024-》1024-》2
  • 超参数的设置:θ = 0.7。

总结
  • 该文总体来说提供了一种思路,但是我觉得伪标签的问题其实并没有办法真正解决,会限制该类模型的上限并不会很高。

推荐阅读
author-avatar
生活趣图分享
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有