热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

文本分类上分微调技巧实战

目录 引言 How to Fine-Tune BERT for Text Classification 论文 微调策略 ITPT:继续预训练 学术论文分类挑战赛微调 huggingf

目录

  • 引言
  • How to Fine-Tune BERT for Text Classification 论文

    • 微调策略
    • ITPT:继续预训练
  • 学术论文分类挑战赛微调

    • huggingface工具介绍
    • bert模型介绍
    • 数据创建
    • 模型定义
    • 模型训练与评估
    • 模型改进

How to Fine-Tune BERT for Text Classification 论文

微调策略

  1. 处理长文本 我们知道BERT 的最大序列长度为 512,BERT 应用于文本分类的第一个问题是如何处理长度大于 512 的文本。本文尝试了以下方式处理长文章。

Truncation methods 截断法 文章的关键信息位于开头和结尾。 我们可以使用三种不同的截断文本方法来执行 BERT 微调。

head-only: keep the first 510 tokens 头部510个字符,加上两个特殊字符刚好是512 ;
tail-only: keep the last 510 tokens;尾部510个字符,同理加上两个特殊字符刚好是512 ;
head+tail: empirically select the first 128and the last 382 tokens.:尾部结合

Hierarchical methods 层级法 输入的文本首先被分成k = L/510个片段,喂入 BERT 以获得 k 个文本片段的表示向量。 每个分数的表示是最后一层的 [CLS] 标记的隐藏状态,然后我们使用均值池化、最大池化和自注意力来组合所有分数的表示。

  1. 不同层的特征 BERT 的每一层都捕获输入文本的不同特征。 文本研究了来自不同层的特征的有效性, 然后我们微调模型并记录测试错误率的性能。
文本分类上分微调技巧实战
image

我们可以看到:最后一层表征效果最好;最后4层进行max-pooling效果最好

  1. 灾难性遗忘 Catastrophic forgetting (灾难性遗忘)通常是迁移学习中的常见诟病,这意味着在学习新知识的过程中预先训练的知识会被遗忘。 因此,本文还研究了 BERT 是否存在灾难性遗忘问题。 我们用不同的学习率对 BERT 进行了微调,发现需要较低的学习率,例如 2e-5,才能使 BERT 克服灾难性遗忘问题。 在 4e-4 的较大学习率下,训练集无法收敛。
文本分类上分微调技巧实战
image

这个也深有体会,当预训练模型失效不能够收敛的时候多检查下超参数是否设置有问题。

  1. Layer-wise Decreasing Layer Rate 逐层降低学习率 下表 显示了不同基础学习率和衰减因子在 IMDb 数据集上的性能。 我们发现为下层分配较低的学习率对微调 BERT 是有效的,比较合适的设置是 ξ=0.95 和 lr=2.0e-5

为不同的BERT设置不同的学习率及衰减因子,BERT的表现如何?

文本分类上分微调技巧实战
image

η^l代表第几层的学习率

ITPT:继续预训练

ITPT:继续预训练
Bert是在通用的语料上进行预训练的,如果要在特定领域应用文本分类,数据分布一定是有一些差距的。这时候可以考虑进行深度预训练。

  1. Within-task pre-training:Bert在训练语料上进行预训练 In-domain pre-training:在同一领域上的语料进行预训练 Cross-domain pre-training:在不同领域上的语料进行预训练
    Within-task pretraining
    [图片上传失败…(image-177ffd-1627289320036)]
    BERT-ITPT-FiT 的意思是“BERT + with In-Task Pre-Training + Fine-Tuning”,上图表示IMDb 数据集上进行不同步数的继续预训练是有收益的。

  2. In-Domain 和 Cross-Domain Further Pre-Training

    文本分类上分微调技巧实战
    image

    我们发现几乎所有进一步的预训练模型在所有七个数据集上的表现都比原始 BERT 基础模型。 一般来说,域内预训练可以带来比任务内预训练更好的性能。 在小句子级 TREC 数据集上,任务内预训练会损害性能,而在使用 Yah 的领域预训练中。Yah. A.语料库可以在TREC上取得更好的结果。
    这篇论文与其他模型进行了比较,结果如下表所示:

    文本分类上分微调技巧实战
    image

改进1 Last 4 Layers Concatenating

class LastFourModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        cOnfig= AutoConfig.from_pretrained(PRE_TRAINED_MODEL_NAME)
        config.update({'output_hidden_states':True})
        self.model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME, cOnfig=config)
        self.linear = nn.Linear(4*HIDDEN_SIZE, n_classes)
        
    def forward(self, input_ids, attention_mask):
        
        outputs = self.model(input_ids, attention_mask)
        all_hidden_states = torch.stack(outputs[2])
        concatenate_pooling = torch.cat(
            (all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]), -1
        )
        concatenate_pooling = concatenate_pooling[:,0]
        output = self.linear(concatenate_pooling)
        
        return soutput

改进2 模型层间差分学习率

def get_parameters(model, model_init_lr, multiplier, classifier_lr):
    parameters = []
    lr = model_init_lr
    for layer in range(12,-1,-1):
        layer_params = {
            'params': [p for n,p in model.named_parameters() if f'encoder.layer.{layer}.' in n],
            'lr': lr
        }
        parameters.append(layer_params)
        lr *= multiplier
    classifier_params = {
        'params': [p for n,p in model.named_parameters() if 'layer_norm' in n or 'linear' in n 
                   or 'pooling' in n],
        'lr': classifier_lr
    }
    parameters.append(classifier_params)
    return parameters
parameters=get_parameters(model,2e-5,0.95, 1e-4)
optimizer=AdamW(parameters)

改进3 ITPT 继续预训练

import warnings
import pandas as pd
from transformers import (AutoModelForMaskedLM,
                          AutoTokenizer, LineByLineTextDataset,
                          DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)

warnings.filterwarnings('ignore')

train_data = pd.read_csv('data/train/train.csv', sep='t')
test_data = pd.read_csv('data/test/test.csv', sep='t')
train_data['text'] = train_data['title'] + '.' + train_data['abstract']
test_data['text'] = test_data['title'] + '.' + test_data['abstract']
data = pd.concat([train_data, test_data])
data['text'] = data['text'].apply(lambda x: x.replace('n', ''))

text = 'n'.join(data.text.tolist())

with open('text.txt', 'w') as f:
    f.write(text)

model_name = 'roberta-base'
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained('./paper_roberta_base')

train_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="text.txt",  # mention train text file here
    block_size=256)

valid_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="text.txt",  # mention valid text file here
    block_size=256)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir="./paper_roberta_base_chk",  # select model path for checkpoint
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    evaluation_strategy='steps',
    save_total_limit=2,
    eval_steps=200,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    load_best_model_at_end=True,
    prediction_loss_Only=True,
    report_to="none")

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset)

trainer.train()
trainer.save_model(f'./paper_roberta_base')

推荐阅读
  • 模板引擎StringTemplate的使用方法和特点
    本文介绍了模板引擎StringTemplate的使用方法和特点,包括强制Model和View的分离、Lazy-Evaluation、Recursive enable等。同时,还介绍了StringTemplate语法中的属性和普通字符的使用方法,并提供了向模板填充属性的示例代码。 ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • 本文详细介绍了如何使用MySQL来显示SQL语句的执行时间,并通过MySQL Query Profiler获取CPU和内存使用量以及系统锁和表锁的时间。同时介绍了效能分析的三种方法:瓶颈分析、工作负载分析和基于比率的分析。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 本文介绍了一个在线急等问题解决方法,即如何统计数据库中某个字段下的所有数据,并将结果显示在文本框里。作者提到了自己是一个菜鸟,希望能够得到帮助。作者使用的是ACCESS数据库,并且给出了一个例子,希望得到的结果是560。作者还提到自己已经尝试了使用"select sum(字段2) from 表名"的语句,得到的结果是650,但不知道如何得到560。希望能够得到解决方案。 ... [详细]
  • ASP.NET2.0数据教程之十四:使用FormView的模板
    本文介绍了在ASP.NET 2.0中使用FormView控件来实现自定义的显示外观,与GridView和DetailsView不同,FormView使用模板来呈现,可以实现不规则的外观呈现。同时还介绍了TemplateField的用法和FormView与DetailsView的区别。 ... [详细]
  • MyBatis多表查询与动态SQL使用
    本文介绍了MyBatis多表查询与动态SQL的使用方法,包括一对一查询和一对多查询。同时还介绍了动态SQL的使用,包括if标签、trim标签、where标签、set标签和foreach标签的用法。文章还提供了相关的配置信息和示例代码。 ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • 本文介绍了在iOS开发中使用UITextField实现字符限制的方法,包括利用代理方法和使用BNTextField-Limit库的实现策略。通过这些方法,开发者可以方便地限制UITextField的字符个数和输入规则。 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
  • 本文介绍了Hive常用命令及其用途,包括列出数据表、显示表字段信息、进入数据库、执行select操作、导出数据到csv文件等。同时还涉及了在AndroidManifest.xml中获取meta-data的value值的方法。 ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
  • ①页面初始化----------收到客户端的请求,产生相应页面的Page对象,通过Page_Init事件进行page对象及其控件的初始化.②加载视图状态-------ViewSta ... [详细]
author-avatar
终渐疯分_501
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有