热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PatternExploitingTrainingMLM任务用于文本匹配【代码解读】

一、总结•原文:#PET-文本分类的又一种妙解:https:xv44586.github.io20201025pet#ccf问答匹配比赛(下

一、总结
• 原文:

# PET-文本分类的又一种妙解:https://xv44586.github.io/2020/10/25/pet/
# ccf问答匹配比赛(下):如何只用“bert”夺冠:https://xv44586.github.io/2021/01/20/ccf-qa-2/

在这里插入图片描述


三、代码注释

原始链接:https://github.com/xv44586/ccf_2020_qa_match

# -*- coding: utf-8 -*-
# @Date : 2020/11/4
# @Author : mingming.xu
# @Email : xv44586@gmail.com
# @File : ccf_2020_qa_match_pet.py
"""
Pattern-Exploiting Training(PET): 增加pattern,将任务转换为MLM任务。
线上f1: 0.761tips:切换模型时,修改对应config_path/checkpoint_path/dict_path路径以及build_transformer_model 内的参数
"""
import os
import numpy as np
import json
from tqdm import tqdm
import numpy as np
import pandas as pd
from toolkit4nlp.backend import keras, K
from toolkit4nlp.tokenizers import Tokenizer, load_vocab
from toolkit4nlp.models import build_transformer_model, Model
from toolkit4nlp.optimizers import *
from toolkit4nlp.utils import pad_sequences, DataGenerator
from toolkit4nlp.layers import *os.environ["CUDA_VISIBLE_DEVICES"] = "1"# PET-文本分类的又一种妙解:https://xv44586.github.io/2020/10/25/pet/
# ccf问答匹配比赛(下):如何只用“bert”夺冠:https://xv44586.github.io/2021/01/20/ccf-qa-2/
num_classes = 32
maxlen = 128
batch_size = 8# BERT baseconfig_path = 'data/pretrained/nezha/NEZHA-Base/bert_config.json'
checkpoint_path = 'data/pretrained/nezha/NEZHA-Base/model.ckpt-900000'
dict_path = 'data/pretrained/nezha/NEZHA-Base/vocab.txt'tokenizer = Tokenizer(dict_path, do_lower_case=True)# pattern
pattern = '下面两个句子的语义相似度较高:'
# tokenizer.encode的第一个位置是cls,所以mask的index要+1
tokens = ["CLS"]+list(pattern)
print(tokens[14])
mask_idx = [14]id2label = {0: '低',1: '高'
}label2id = {v: k for k, v in id2label.items()}
print('label2id:',label2id)#label2id: {'低': 0, '高': 1}
labels = list(id2label.values())
print('labels:',labels)#labels: ['低', '高']
# labels在token中的ids,encode的时候,第一个数是cls,所以取encode输出的tokens[1:-1],代表跳过了cls的
label_ids = np.array([tokenizer.encode(l)[0][1:-1] for l in labels])
print('label_ids:',label_ids)#label_ids: [[ 856] [7770]]# 这里本文其实没有用到
def random_masking(token_ids):"""对输入进行随机mask"""# n个随机数rands &#61; np.random.random(len(token_ids))source, target &#61; [], []for r, t in zip(rands, token_ids):# [mask, 0.15 * 0.8, t(本身), 0.15 * 0.9, 随机, 0.15, 本身&#xff0c;target&#61;0&#xff0c;其余target都为1]if r < 0.15 * 0.8:# 通过mask来预测targetsource.append(tokenizer._token_mask_id)target.append(t)elif r < 0.15 * 0.9:# 通过本身来预测targetsource.append(t)target.append(t)elif r < 0.15:# 通过随机token来预测targetsource.append(np.random.choice(tokenizer._vocab_size - 1) &#43; 1)target.append(t)else:# 通过本身->label&#61;0?source.append(t)target.append(0)return source, targetclass data_generator(DataGenerator):def __init__(self, prefix&#61;False, *args, **kwargs):super(data_generator, self).__init__(*args, **kwargs)self.prefix &#61; prefixdef __iter__(self, shuffle&#61;False):batch_token_ids, batch_segment_ids, batch_target_ids &#61; [], [], []# 拿到query和replyfor is_end, (q, r, label) in self.get_sample(shuffle):# 没有label的时候定义为Nonelabel &#61; int(label) if label is not None else None# 有label的时候&#xff0c;才添加前缀if label is not None or self.prefix:q &#61; pattern &#43; q# 拿到token_ids和segment_idtoken_ids, segment_ids &#61; tokenizer.encode(q, r, maxlen&#61;maxlen)# 本文没有用到这个if shuffle:# 这里做了随机mask&#xff0c;随机mask有点没看懂, 但是本文都没用到这个source_tokens, target_tokens &#61; random_masking(token_ids)else:# 理论上target_tokens就等于source_tokenssource_tokens, target_tokens &#61; token_ids[:], token_ids[:]# mask labelif label is not None:# 将label转化成token&#xff0c;因为是mlm任务&#xff0c;最终的label其实就是tokenlabel_ids &#61; tokenizer.encode(id2label[label])[0][1:-1]# pattern &#61; &#39;直接回答问题:&#39;# mask_idx &#61; [1]# 这里label_ids也只有一个&#xff0c;所以是直接复制# mask_idx代表的其实是label在原文中的位置for m, lb in zip(mask_idx, label_ids):# 这里相当于把原文的label更换成为mask_id# source_tokens[1] &#61; mask_id# 然后target_tokens[1] &#61; label_id(也就是label对应的token_id)# 这里只更改了label对应的token&#xff0c;其余部分不变source_tokens[m] &#61; tokenizer._token_mask_idtarget_tokens[m] &#61; lbelif self.prefix:# 这里就一个mask_id&#xff0c;如果有多个多个都直接赋值成为token_idfor i in mask_idx:source_tokens[i] &#61; tokenizer._token_mask_id# 最后拿到mlm任务的source_tokens,segment_ids,target_tokensbatch_token_ids.append(source_tokens)batch_segment_ids.append(segment_ids)batch_target_ids.append(target_tokens)if is_end or len(batch_token_ids) &#61;&#61; self.batch_size:# 满足batch_size要求了&#xff0c;把他yield出去batch_token_ids &#61; pad_sequences(batch_token_ids)batch_segment_ids &#61; pad_sequences(batch_segment_ids)batch_target_ids &#61; pad_sequences(batch_target_ids)# batch_target_ids是每个位置target的idyield [batch_token_ids, batch_segment_ids, batch_target_ids], None# 将原始的batch里面的内容置为空batch_token_ids, batch_segment_ids, batch_target_ids &#61; [], [], []class CrossEntropy(Loss):"""交叉熵作为loss&#xff0c;并mask掉输入部分"""def compute_loss(self, inputs, mask&#61;None):y_true, y_pred &#61; inputs# K.not_equal, 拿到y_true不为0的部分&#xff0c;然后转化成为floaty_mask &#61; K.cast(K.not_equal(y_true, 0), K.floatx())# 计算精度accuracy &#61; keras.metrics.sparse_categorical_accuracy(y_true, y_pred)# mask掉输入部分accuracy &#61; K.sum(accuracy * y_mask) / K.sum(y_mask)# 拿到acc精度self.add_metric(accuracy, name&#61;&#39;accuracy&#39;)# 拿到交叉熵loss &#61; K.sparse_categorical_crossentropy(y_true, y_pred)# maskloss &#61; K.sum(loss * y_mask) / K.sum(y_mask)return loss# tokenizer
# tokenizer &#61; Tokenizer(dict_path, do_lower_case&#61;True)def train(train_data, val_data, test_data, best_model_file, test_result_file):train_generator &#61; data_generator(data&#61;train_data &#43; test_data, batch_size&#61;batch_size)valid_generator &#61; data_generator(data&#61;val_data, batch_size&#61;batch_size)test_generator &#61; data_generator(data&#61;test_data, batch_size&#61;batch_size, prefix&#61;True)target_in &#61; Input(shape&#61;(None,))model &#61; build_transformer_model(config_path&#61;config_path,checkpoint_path&#61;checkpoint_path,with_mlm&#61;True, # with_nlm为True是不是返回的output就不一样了&#xff0c;应该返回的就是mlm的output# model&#61;&#39;bert&#39;, # 加载bert/Roberta/erniemodel&#61;&#39;nezha&#39;)output &#61; CrossEntropy(output_idx&#61;1)([target_in, model.output])# 输入的时候&#xff0c;添加一个target_in&#xff0c; 输出还是和之前一样train_model &#61; Model(model.inputs &#43; [target_in], output)# 梯度衰减&#43;梯度积累AdamW &#61; extend_with_weight_decay(Adam)AdamWG &#61; extend_with_gradient_accumulation(AdamW)opt &#61; AdamWG(learning_rate&#61;1e-5, exclude_from_weight_decay&#61;[&#39;Norm&#39;, &#39;bias&#39;], grad_accum_steps&#61;4)train_model.compile(opt)train_model.summary()def evaluate(data):P, R, TP &#61; 0., 0., 0.for d, _ in tqdm(data):x_true, y_true &#61; d[:2], d[2]# 拿到预测结果&#xff0c;已经转化为label_ids里面的index了y_pred &#61; predict(x_true)# 只取mask_idx对应的y -> 原始token -> 原始label中的indexy_true &#61; np.array([labels.index(tokenizer.decode(y)) for y in y_true[:, mask_idx]])# print(y_true, y_pred)# 计算f1R &#43;&#61; y_pred.sum()P &#43;&#61; y_true.sum()TP &#43;&#61; ((y_pred &#43; y_true) > 1).sum()print(P, R, TP)pre &#61; TP / Rrec &#61; TP / Preturn 2 * (pre * rec) / (pre &#43; rec)def predict(x):if len(x) &#61;&#61; 3:x &#61; x[:2]# 拿到mask_idx对应的output# todo:这里这个model为什么不是train_model啊?y_pred &#61; model.predict(x)[:, mask_idx]# 这个维度信息不太清楚# batch, 0,label_ids对应的值, label_ids应该是可能有多个id&#xff0c;对应分类的多个类别y_pred &#61; y_pred[:, 0, label_ids[:, 0]]# 最后是取得所有label_ids里面的最大值&#xff0c;得到mlm的预测结果的&#xff0c;这里面的mlm的预测的结果的个数与分类的label数一致y_pred &#61; y_pred.argmax(axis&#61;1)return y_predclass Evaluator(keras.callbacks.Callback):def __init__(self, valid_generator, best_pet_model_file&#61;"best_pet_model.weights"):self.best_acc &#61; 0.self.valid_generator &#61; valid_generatorself.best_pet_model_file &#61; best_pet_model_filedef on_epoch_end(self, epoch, logs&#61;None):acc &#61; evaluate(self.valid_generator)if acc > self.best_acc:self.best_acc &#61; accself.model.save_weights(self.best_pet_model_file)print(&#39;acc :{}, best acc:{}&#39;.format(acc, self.best_acc))def write_to_file(path, test_generator, test_data):preds &#61; []# 分批预测结果for x, _ in tqdm(test_generator):pred &#61; predict(x)preds.extend(pred)# 把原始的query&#xff0c;reply以及预测的p都写入到文件中ret &#61; []for data, p in zip(test_data, preds):if data[2] is None:label &#61; -1else:label &#61; data[2]ret.append([data[0], data[1], str(label), str(p)])with open(path, &#39;w&#39;) as f:for r in ret:f.write(&#39;\t&#39;.join(r) &#43; &#39;\n&#39;)evaluator &#61; Evaluator(valid_generator, best_model_file)train_model.fit_generator(train_generator.generator(),steps_per_epoch&#61;len(train_generator),epochs&#61;10,callbacks&#61;[evaluator])train_model.load_weights(best_model_file)write_to_file(test_result_file, test_generator, test_data)def load_pair_data(f, isshuffle&#61;False):data &#61; []df &#61; pd.read_csv(f)if isshuffle:df &#61; df.sample(frac&#61;1.0, random_state&#61;1234)columns &#61; list(df.columns)if &#39;text_a&#39; not in columns and &#39;query1&#39; in columns:df.rename(columns&#61;{&#39;query1&#39;:&#39;text_a&#39;, &#39;query2&#39;:&#39;text_b&#39;}, inplace&#61;True)for i in range(len(df)):can &#61; df.iloc[i]text_a &#61; can[&#39;text_a&#39;]text_b &#61; can[&#39;text_b&#39;]if &#39;label&#39; not in columns:label &#61; Noneelse:label &#61; int(can[&#39;label&#39;])if label &#61;&#61; -1:label &#61; Nonedata.append([text_a, text_b, label])return datadef load_data():""":return: [text_a, text_b, label]天池疫情文本匹配数据集"""data_dir &#61; &#39;../data/tianchi/&#39;train_file &#61; data_dir &#43; &#39;train_20200228.csv&#39;dev_file &#61; data_dir &#43; &#39;dev_20200228.csv&#39;test_file &#61; data_dir &#43; &#39;test.example_20200228.csv&#39;train_data &#61; load_pair_data(train_file)val_data &#61; load_pair_data(dev_file)test_data &#61; load_pair_data(test_file)return train_data, val_data, test_datadef test_data_generator():data_dir &#61; &#39;../data/tianchi/&#39;train_file &#61; data_dir &#43; &#39;train_20200228.csv&#39;data &#61; load_pair_data(train_file)train_generator &#61; data_generator(data&#61;data, batch_size&#61;batch_size)for d in train_generator:print(d)breakdef run():train_data, val_data, test_data &#61; load_data()best_model_file &#61; &#39;best_pet_model.weights&#39;test_result_file &#61; &#39;pet_submission.tsv&#39;train(train_data, val_data, test_data, best_model_file, test_result_file)if __name__ &#61;&#61; &#39;__main__&#39;:test_data_generator()run()

推荐阅读
  • 本文节选自《NLTK基础教程——用NLTK和Python库构建机器学习应用》一书的第1章第1.2节,作者Nitin Hardeniya。本文将带领读者快速了解Python的基础知识,为后续的机器学习应用打下坚实的基础。 ... [详细]
  • 通过将常用的外部命令集成到VSCode中,可以提高开发效率。本文介绍如何在VSCode中配置和使用自定义的外部命令,从而简化命令执行过程。 ... [详细]
  • Linux CentOS 7 安装PostgreSQL 9.5.17 (源码编译)
    近日需要将PostgreSQL数据库从Windows中迁移到Linux中,LinuxCentOS7安装PostgreSQL9.5.17安装过程特此记录。安装环境&#x ... [详细]
  • 技术分享:使用 Flask、AngularJS 和 Jinja2 构建高效前后端交互系统
    技术分享:使用 Flask、AngularJS 和 Jinja2 构建高效前后端交互系统 ... [详细]
  • 基于Net Core 3.0与Web API的前后端分离开发:Vue.js在前端的应用
    本文介绍了如何使用Net Core 3.0和Web API进行前后端分离开发,并重点探讨了Vue.js在前端的应用。后端采用MySQL数据库和EF Core框架进行数据操作,开发环境为Windows 10和Visual Studio 2019,MySQL服务器版本为8.0.16。文章详细描述了API项目的创建过程、启动步骤以及必要的插件安装,为开发者提供了一套完整的开发指南。 ... [详细]
  • 在OpenShift上部署基于MongoDB和Node.js的多层应用程序
    本文档详细介绍了如何在OpenShift 4.x环境中部署一个包含MongoDB数据库和Node.js后端及前端的多层应用程序。通过逐步指导,读者可以轻松完成整个部署过程。 ... [详细]
  • Python 3 Scrapy 框架执行流程详解
    本文详细介绍了如何在 Python 3 环境下安装和使用 Scrapy 框架,包括常用命令和执行流程。Scrapy 是一个强大的 Web 抓取框架,适用于数据挖掘、监控和自动化测试等多种场景。 ... [详细]
  • 在JavaWeb开发中,文件上传是一个常见的需求。无论是通过表单还是其他方式上传文件,都必须使用POST请求。前端部分通常采用HTML表单来实现文件选择和提交功能。后端则利用Apache Commons FileUpload库来处理上传的文件,该库提供了强大的文件解析和存储能力,能够高效地处理各种文件类型。此外,为了提高系统的安全性和稳定性,还需要对上传文件的大小、格式等进行严格的校验和限制。 ... [详细]
  • 本文介绍了如何使用 Node.js 和 Express(4.x 及以上版本)构建高效的文件上传功能。通过引入 `multer` 中间件,可以轻松实现文件上传。首先,需要通过 `npm install multer` 安装该中间件。接着,在 Express 应用中配置 `multer`,以处理多部分表单数据。本文详细讲解了 `multer` 的基本用法和高级配置,帮助开发者快速搭建稳定可靠的文件上传服务。 ... [详细]
  • 如何将Python与Excel高效结合:常用操作技巧解析
    本文深入探讨了如何将Python与Excel高效结合,涵盖了一系列实用的操作技巧。文章内容详尽,步骤清晰,注重细节处理,旨在帮助读者掌握Python与Excel之间的无缝对接方法,提升数据处理效率。 ... [详细]
  • 利用 JavaScript 和 Node.js 验证时间的有效性
    本文探讨了如何使用 JavaScript 和 Node.js 验证时间的有效性。通过编写一个 `isTime` 函数,我们可以确保输入的时间格式正确且有效。该函数利用正则表达式匹配时间字符串,检查其是否符合常见的日期时间格式,如 `YYYY-MM-DD` 或 `HH:MM:SS`。此外,我们还介绍了如何处理不同时间格式的转换和验证,以提高代码的健壮性和可靠性。 ... [详细]
  • 如何使用 `org.apache.tomcat.websocket.server.WsServerContainer.findMapping()` 方法及其代码示例解析 ... [详细]
  • 本指南介绍了如何在ASP.NET Web应用程序中利用C#和JavaScript实现基于指纹识别的登录系统。通过集成指纹识别技术,用户无需输入传统的登录ID即可完成身份验证,从而提升用户体验和安全性。我们将详细探讨如何配置和部署这一功能,确保系统的稳定性和可靠性。 ... [详细]
  • 深入探索HTTP协议的学习与实践
    在初次访问某个网站时,由于本地没有缓存,服务器会返回一个200状态码的响应,并在响应头中设置Etag和Last-Modified等缓存控制字段。这些字段用于后续请求时验证资源是否已更新,从而提高页面加载速度和减少带宽消耗。本文将深入探讨HTTP缓存机制及其在实际应用中的优化策略,帮助读者更好地理解和运用HTTP协议。 ... [详细]
  • 在Ubuntu系统中安装Android SDK的详细步骤及解决“Failed to fetch URL https://dlssl.google.com/”错误的方法
    在Ubuntu 11.10 x64系统中安装Android SDK的详细步骤,包括配置环境变量和解决“Failed to fetch URL https://dlssl.google.com/”错误的方法。本文详细介绍了如何在该系统上顺利安装并配置Android SDK,确保开发环境的稳定性和高效性。此外,还提供了解决网络连接问题的实用技巧,帮助用户克服常见的安装障碍。 ... [详细]
author-avatar
泉州多棱汽车销售服务有限公司
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有