热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

BERT句向量(一):SentenceBERT

前言句向量:能够表征整个句子语义的向量,目前效果比较好的方法还是通过bert模型结构来实现,也是本文的主题。有了句向量,

前言

句向量:能够表征整个句子语义的向量,目前效果比较好的方法还是通过bert模型结构来实现,也是本文的主题。

有了句向量,我们可以用来进行聚类,处理大规模的文本相似度比较,或者基于语义搜索的信息检索。

例如搜索系统中的输入query和匹配文档document、Q&A任务的问题和答案等等,都可以转化为计算两个句子的语义相似/相关度,相关度最高的n个作为模型的返回结果。


题外话

这种类似的模型一般称为passage retrieval models,即段落检索,有两个代表:


  1. sparse models:BM25、TF-IDF等;
  2. dense models(DPR,Dense Passage Retrieval):将query和doc(question和passage/answer)都转化为稠密向量,然后通过faiss等工具进行相关召回。

原生Bert

原生的BERT模型在诸多句子分类和句子对的回归任务上都取得了state-of-the-art的表现,它使用一种 cross-encoder的结构:将两个句子拼接输入到模型,经过带有self attention的transformer网络得到最终的预测值。

但这种做法不适用于大量句子对的回归任务,例如给定10000个句子,找出每个句子最相似的句子,那么每个句子就得需要与其他所有句子进行两两组合,才能得到与所有句子的相似度,即需要进行n*(n-1)/2= 49995000次的推理计算,这显然是不合理的。

这其实与推荐场景类似,采用这种结构的话,query需要与所有的doc进行分别计算,才能分数相关度最高的doc,这是不现实。所以这种做法一般是放在后面的排序阶段

而在此之前,一般会先经过召回阶段,则是需要事先将所有doc输入到bert模型,提取出句向量进行存储,实际使用时,实时计算query的句向量,然后通过faiss等ann工具,来从所有doc中召回相关度最高的n个。

因此,sentence-bert此时就派上用场,它使得bert模型能够提取表征句子语义的句向量。


Sentence-BERT

相关论文:《Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks》


pooling strategies

其实原生bert模型本身是具备句向量提取的能力,一般是以下3种方法,sentence-bert也是采用相同的方法:


  1. CLS:使用[CLS]字符最后一层的输出向量,作为句向量;
  2. MEAN:使用句子的所有字符的最后一层输出向量,计算它们的均值,作为句向量;
  3. MAX:使用句子的所有字符的最后一层输出向量,所有字符向量对应位置提取最大值,作为句向量。

但是,如果**直接使用原生bert模型来提取句向量,效果十分不理想,甚至不如GloVe提取的句向量。**


fine-tune

所以,作者提出一种针对句向量,对bert模型进行微调的方法,包括无监督和监督训练。

请添加图片描述

fine-tune的三种结构:

1. Classification Objective Function

如图1的分类结构,句子A和句子B输入到同个bert模型(参数绑定),然后使用[CLS]向量或者所有字符的向量均值得到A的句向量u、B的句向量v,然后拼接u、v和 element-wise的 |u-v|,最后通过softmax做一个k分类,loss为cross-entropy;

请添加图片描述

2. Regression Objective Function

如图2的回归结构,同样的方法得到u和v, 再经过cosine函数得到u和v的相似度,使用MSE( mean-squared-error)作为loss;

3. Triplet Objective Function

最后一种为三元组结构,如下式,句子a和p为负例,a和n为正例,sas_asa为句子a的句向量,方法同上。这个结构是让负例句子的距离要尽量比正例的大

请添加图片描述

其中|| · ||是距离度量,例如欧式距离,

ξ\xiξ为 margin ,控制负例和正例句子的距离差最小为ξ\xiξ


inference

推理阶段,按照上图2的做法,两个句子u和v输入到Sentence-BERT结构微调后的模型,选择一种pooling策略,得到句子的向量,然后使用cosine函数来计算两个句子的相似/相关度。


无监督训练

作者使用 SNLI(Bowman et al., 2015) 和Multi-Genre NLI(Williams et al., 2018)两个公开的数据集,带有三种标签contradiction、eintailment、neutral的句子对。

使用Classifification Objective Function来对bert模型进行微调,详细参数为:batch_size为16、Adam optimizer、2e-5的学习率、10%的线性学习率warmup,采用MEAN的pooling策略。

然后在STS数据集上进行验证,由于未使用到目标数据集,因此可以认为是无监督训练,具体效果如下:

( STS12-STS16:SemEval 2012-2016, STSb: STSbenchmark, SICK-R: SICK relatedness dataset,这些数据集带有0-5级的相关程度)

明显看出微调后的sentence-bert比原生bert的句向量效果提升了许多,并且使用RoBERTa可以进一步提升效果。

(作者也是做了实验,才得出原生bert句向量甚至不如GloVe的结论)
请添加图片描述


监督训练

上面提到,STS数据的标签是0-5级的相关程度,作者使用了regression objective function的结构进行微调SBERT。

实验了两种监督训练方案:


  1. 仅使用STSb数据进行监督训练;
  2. 先在NLI数据进行训练,然后再使用STSb数据

结果如下:

监督训练比无监督训练效果进一步提升,并且BERT的模型大小影响较大,BERT-large比base提升3-4点;

但使用RoBERTa未没有明显的效果提升。
请添加图片描述


代码实现

tensorflow1.x:https://github.com/QunBB/DeepLearning/tree/main/NLP/sentence_bert/sbert

pytorch推荐使用:Sentence-Transformers


推荐阅读
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • HashMap的相关问题及其底层数据结构和操作流程
    本文介绍了关于HashMap的相关问题,包括其底层数据结构、JDK1.7和JDK1.8的差异、红黑树的使用、扩容和树化的条件、退化为链表的情况、索引的计算方法、hashcode和hash()方法的作用、数组容量的选择、Put方法的流程以及并发问题下的操作。文章还提到了扩容死链和数据错乱的问题,并探讨了key的设计要求。对于对Java面试中的HashMap问题感兴趣的读者,本文将为您提供一些有用的技术和经验。 ... [详细]
  • 判断编码是否可立即解码的程序及电话号码一致性判断程序
    本文介绍了两个编程题目,一个是判断编码是否可立即解码的程序,另一个是判断电话号码一致性的程序。对于第一个题目,给出一组二进制编码,判断是否存在一个编码是另一个编码的前缀,如果不存在则称为可立即解码的编码。对于第二个题目,给出一些电话号码,判断是否存在一个号码是另一个号码的前缀,如果不存在则说明这些号码是一致的。两个题目的解法类似,都使用了树的数据结构来实现。 ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • 本文详细介绍了Java中vector的使用方法和相关知识,包括vector类的功能、构造方法和使用注意事项。通过使用vector类,可以方便地实现动态数组的功能,并且可以随意插入不同类型的对象,进行查找、插入和删除操作。这篇文章对于需要频繁进行查找、插入和删除操作的情况下,使用vector类是一个很好的选择。 ... [详细]
  • 本文介绍了在满足特定条件时如何在输入字段中使用默认值的方法和相应的代码。当输入字段填充100或更多的金额时,使用50作为默认值;当输入字段填充有-20或更多(负数)时,使用-10作为默认值。文章还提供了相关的JavaScript和Jquery代码,用于动态地根据条件使用默认值。 ... [详细]
  • 1Lock与ReadWriteLock1.1LockpublicinterfaceLock{voidlock();voidlockInterruptibl ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 求解连通树的最小长度及优化
    本文介绍了求解连通树的最小长度的方法,并通过四边形不等式进行了优化。具体方法为使用状态转移方程求解树的最小长度,并通过四边形不等式进行优化。 ... [详细]
  • 文章目录题目:二叉搜索树中的两个节点被错误地交换。基本思想1:中序遍历题目:二叉搜索树中的两个节点被错误地交换。请在不改变其结构的情况下 ... [详细]
  • 查找给定字符串的所有不同回文子字符串原文:https://www ... [详细]
  • Flutter 布局(四) Baseline、FractionallySizedBox、IntrinsicHeight、IntrinsicWidth详解
    本文主要介绍Flutter布局中的Baseline、FractionallySizedBox、IntrinsicHeight、IntrinsicWidth四种控件,详细介绍了其布局 ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
  • WPF之Binding初探
      初学wpf,经常被Binding搞晕,以下记录写Binding的基础。首先,盗用张图。这图形象的说明了Binding的机理。对于Binding,意思是数据绑定,基本用法是:1、 ... [详细]
author-avatar
手机用户2502885633
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有