热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

多模态特征融合机制(含代码):TFN(TensorFusionNetwork)和LMF(LowrankMultimodalFusion)

文章目录写在前面简单的concatTFN融合策略LWF融合策略论文全称:《TensorFusionNetworkforMultimodalSentimentAnaly


文章目录

    • 写在前面
    • 简单的concat
    • TFN融合策略
    • LWF融合策略


论文全称:
《Tensor Fusion Network for Multimodal Sentiment Analysis》
《Efficient Low-rank Multimodal Fusion with Modality-Specific Factors》



写在前面

最近在做一个分类的比赛,想要用上数据中的多模态信息(主要是文本和图像特征),因此探索了一些多模态特征的融合机制,并记录下来。

下文中均以3种不同模态下的特征融合为例。并设A模态特征维度为512,B模态特征维度为1024,C模态特征维度为32

import torch
A = torch.randn(16, 512)
B = torch.randn(16, 1024)
C = torch.randn(16, 32)

简单的concat

concat既是最简单也是最常用的一种方式,直接在特征维度将不同模态特征进行拼接后,再送入后续的推理模块。

fusion_feature = torch.cat([A, B, C], dim=1)

TFN融合策略

原理简述

TFN来自17年EMNLP会议论文《Tensor Fusion Network for Multimodal Sentiment Analysis》,其主要考虑了inter-modalityintar-modality两个方面。也就是要求既能考虑各模态之间的特征融合,也要有效地利用各特定模态的特征。

在这里插入图片描述
图左为Early Fusion策略,其实就是之前提到的concat方法,图右展现了作者提出的TFN模块(Tensor Fusion Network)。具体做法就是首先对每个模态用1进行维度扩充,然后对不同模态求笛卡尔积

以两个模态为例,对zv,zlz_v,z_lzv,zl1先扩充一维,得到后的特征再进行outer product(外积,张量积)。可以看到,用1扩充后,即计算了两个模态间的特征相关性,又保留了特定模态的信息。
在这里插入图片描述
同理,对三个模态求得了笛卡尔积后([za;1]⨂[zb;1]⨂[zc;1][z_a; 1] \bigotimes [z_b; 1] \bigotimes [z_c; 1][za;1][zb;1][zc;1]),即计算了两两模态间的特征、三模态间的特征,又保留了各特定模态中的特征(见上图的Tensor Fusion细节)。

n = A.shape[0]
# 用 1 扩充维度
A = torch.cat([A, torch.ones(n, 1)], dim=1)
B = torch.cat([B, torch.ones(n, 1)], dim=1)
C = torch.cat([C, torch.ones(n, 1)], dim=1)
# 计算笛卡尔积
A = A.unsqueeze(2) # [n, A, 1]
B = B.unsqueeze(1) # [n, 1, B]
fusion_AB = torch.einsum('nxt, nty->nxy', A, B) # [n, A, B]
fusion_AB = fusion_AB.flatten(start_dim=1).unsqueeze(1) # [n, AxB, 1]
C = C.unsqueeze(1) # [n, 1, C]
fusion_ABC = torch.einsum('ntx, nty->nxy', fusion_AB, C) # [n, AxB, C]
fusion_ABC = fusion_ABC.flatten(start_dim=1) # [n, AxBxC]
# A, B, C分别代表原来的特征维度nA,nB,nC加上1

需要注意的是,实际编程实现时并未直接计算得到3-D的笛卡尔积,而是分别两两计算outer product



LWF融合策略

上面提到的TFN对计算了两/三模态间的相关性,也保留了单模态的相关性,但同时也大大地增加了特征维度。增加特征维度从而会影响计算效率以及增加内存消耗,并且TFN所增加的时间/空间复杂度都与输入模态数呈指数增加。并且参数量一多,就容易增加过拟合的风险。

LMF是发表于ACL2017年的工作,针对TFN的上述问题,作者采用了low-rank weight进行多模态融合,降低参数量的同时还提升了计算速度。


建议先看看这篇博客:LWF论文解读


TFN中的融合后的特征Z维度为d1xd2xd3x....dm,其中m表示模态数,i模态特征维度为di。后续要将其送入推理模块中,通常需要降到h维的特征F,此时需要一个维度为(d1xd2xd3x....dm)xh的(M+1阶)权重W进行全连接操作。

全连接操作中,W可以视为h个M阶矩阵,每个矩阵与融合特征Z计算后的结果为F中的一维。

LMF要做的是就是将W分解成M组与各模态相关low-rank因子。按照上述的视角,将W视为h个矩阵,每个特征矩阵Wk如下所示,其中使得分解成立的最小R称为秩(Rank)。
在这里插入图片描述
在LMF中,人为设定固定的秩r,得到每个Wk矩阵了,对特征矩阵进行重新排列,使其变为与模态m相关的特征Wm
在这里插入图片描述
为了更好地理解排列过程,我画了一张图,展示了3个模态时,秩为r,期望维度为h的情况:
在这里插入图片描述

在这里插入图片描述

那么对特征变换(Zd维特征)的过程可以拆分为如下过程:
在这里插入图片描述
在这里插入图片描述

Z本身也是由不同模态的外积得到的,那么组合起来可得到下式。
在这里插入图片描述
其中Λ\LambdaΛ表示像素级点乘。这样分解之后,避免了从各模态特征Zm去建模Z,并且可以扩展到不同数量的模态上,大大降低了时间复杂度。以3模态的融合为例,图例如下:

在这里插入图片描述
从上图可知,最后的由多模态特征Zm融合成h维特征的过程就变成了:每个模态分别构建r个权重矩阵,融合后对各模态特征进行矩阵乘法,得到一个h维的特征;然后再将各模态得到的h维特征进行像素级乘法即可。代码如下:

import torch
import torch.nn as nn
from torch.nn.parameter import ParameterA = torch.randn(16, 512)
B = torch.randn(16, 1024)
C = torch.randn(16, 32)n = A.shape[0]
A = torch.cat([A, torch.ones(n, 1)], dim=1)
B = torch.cat([B, torch.ones(n, 1)], dim=1)
C = torch.cat([C, torch.ones(n, 1)], dim=1)# 假设所设秩: R = 4, 期望融合后的特征维度: h = 128
R, h = 4, 128
Wa = Parameter(torch.Tensor(R, A.shape[1], h))
Wb = Parameter(torch.Tensor(R, B.shape[1], h))
Wc = Parameter(torch.Tensor(R, C.shape[1], h))
Wf = Parameter(torch.Tensor(1, R))
bias = Parameter(torch.Tensor(1, h))# 分解后,并行提取各模态特征
fusion_A = torch.matmul(A, Wa)
fusion_B = torch.matmul(B, Wb)
fusion_C = torch.matmul(C, Wc)# 利用一个Linear再进行特征融合(融合R维度)
funsion_ABC = fusion_A * fusion_B * fusion_C
funsion_ABC = torch.matmul(Wf, funsion_ABC.permute(1,0,2)).squeeze() + bias

推荐阅读
  • 本文详细介绍了Java反射机制的基本概念、获取Class对象的方法、反射的主要功能及其在实际开发中的应用。通过具体示例,帮助读者更好地理解和使用Java反射。 ... [详细]
  • 本文介绍如何使用OpenCV和线性支持向量机(SVM)模型来开发一个简单的人脸识别系统,特别关注在只有一个用户数据集时的处理方法。 ... [详细]
  • Spring – Bean Life Cycle
    Spring – Bean Life Cycle ... [详细]
  • 单片微机原理P3:80C51外部拓展系统
      外部拓展其实是个相对来说很好玩的章节,可以真正开始用单片机写程序了,比较重要的是外部存储器拓展,81C55拓展,矩阵键盘,动态显示,DAC和ADC。0.IO接口电路概念与存 ... [详细]
  • 在多线程并发环境中,普通变量的操作往往是线程不安全的。本文通过一个简单的例子,展示了如何使用 AtomicInteger 类及其核心的 CAS 无锁算法来保证线程安全。 ... [详细]
  • [转]doc,ppt,xls文件格式转PDF格式http:blog.csdn.netlee353086articledetails7920355确实好用。需要注意的是#import ... [详细]
  • 本文对比了杜甫《喜晴》的两种英文翻译版本:a. Pleased with Sunny Weather 和 b. Rejoicing in Clearing Weather。a 版由 alexcwlin 翻译并经 Adam Lam 编辑,b 版则由哈佛大学的宇文所安教授 (Prof. Stephen Owen) 翻译。 ... [详细]
  • 在《Cocos2d-x学习笔记:基础概念解析与内存管理机制深入探讨》中,详细介绍了Cocos2d-x的基础概念,并深入分析了其内存管理机制。特别是针对Boost库引入的智能指针管理方法进行了详细的讲解,例如在处理鱼的运动过程中,可以通过编写自定义函数来动态计算角度变化,利用CallFunc回调机制实现高效的游戏逻辑控制。此外,文章还探讨了如何通过智能指针优化资源管理和避免内存泄漏,为开发者提供了实用的编程技巧和最佳实践。 ... [详细]
  • JUC(三):深入解析AQS
    本文详细介绍了Java并发工具包中的核心类AQS(AbstractQueuedSynchronizer),包括其基本概念、数据结构、源码分析及核心方法的实现。 ... [详细]
  • 本文介绍了几种常用的图像相似度对比方法,包括直方图方法、图像模板匹配、PSNR峰值信噪比、SSIM结构相似性和感知哈希算法。每种方法都有其优缺点,适用于不同的应用场景。 ... [详细]
  • 本文介绍如何通过 Python 的 `unittest` 和 `functools` 模块封装一个依赖方法,用于管理测试用例之间的依赖关系。该方法能够确保在某个测试用例失败时,依赖于它的其他测试用例将被跳过。 ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • 解决问题:1、批量读取点云las数据2、点云数据读与写出3、csf滤波分类参考:https:github.comsuyunzzzCSF论文题目ÿ ... [详细]
  • 在 Kubernetes 中,Pod 的调度通常由集群的自动调度策略决定,这些策略主要关注资源充足性和负载均衡。然而,在某些场景下,用户可能需要更精细地控制 Pod 的调度行为,例如将特定的服务(如 GitLab)部署到特定节点上,以提高性能或满足特定需求。本文深入解析了 Kubernetes 的亲和性调度机制,并探讨了多种优化策略,帮助用户实现更高效、更灵活的资源管理。 ... [详细]
  • 如何使用 net.sf.extjwnl.data.Word 类及其代码示例详解 ... [详细]
author-avatar
GYuan83_844
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有