热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用HuggingfaceTrainer对自定义数据集进行文本分类

文本分类是一项常见的NLP任务,它根据文本的内容定义文本的类型、流派或主题。Huggingface🤗Transformers提供API和工具来轻松下载




文本分类是一项常见的 NLP 任务,它根据文本的内容定义文本的类型、流派或主题。Huggingface🤗 Transformers 提供 API 和工具来轻松下载和训练最先进的预训练模型。Huggingface Transformers 支持 PyTorch、TensorFlow 和 JAX 之间的框架互操作性。模型还可以导出为 ONNX 和 TorchScript 等格式,以便在生产环境中部署。

在这里插入图片描述
此博客将指导您使用 Huggingface Transformers 对自定义数据集的 Distillbert 进行微调。

DistilBERT 是一种小型、快速、便宜且轻便的 Transformer 模型,通过提取 BERT 基础进行训练。它比bert-base-uncased少了 40% 的参数,运行速度提高了 60%,同时保留了 BERT 在 GLUE 语言理解基准测试中超过 95% 的性能

本博客源代码:Google Colab


1.安装transformers库

按照以下链接中的这些说明安装 huggingface 库:https 😕/huggingface.co/docs/datasets/v1.11.0/installation.html

!pip install transformers
!pip install sentencepiece

我们将在本教程中使用 IMDb 数据集。IMDB 数据集是一个大型电影评论数据集。这是一个用于二元情感分类的数据集,包含比以前的基准数据集多得多的数据。他们提供了一组 25,000 条高度极端的电影评论用于训练,25,000 条用于测试。还有其他未标记的数据可供使用。我们可以通过以下步骤下载数据集。


2.下载数据集

!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

此数据被组织到每个示例一个文本文件的文件夹中pos。read_imdb_split函数将帮助我们读取和处理数据集。neg

from pathlib import Path
def read_imdb_split(split_dir):
split_dir = Path(split_dir)
texts = []
labels = []
for label_dir in ["pos", "neg"]:
for text_file in (split_dir/label_dir).iterdir():
texts.append(text_file.read_text())
labels.append(0 if label_dir is "neg" else 1)
return texts, labels
train_texts, train_labels = read_imdb_split('aclImdb/train')
test_texts, test_labels = read_imdb_split('aclImdb/test')

让我们使用 Scikit-learn 中的train_test_split实用程序从训练数据集创建训练和验证集。我们将使用验证数据集使用自定义度量函数来评估和调整我们的模型。

from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

3.预处理输入数据

下一步是对模型的输入数据进行预处理。与清理数据一起,标记化是任何 NLP 管道中的第一步。Tokenizer 将非结构化字符串转换为适合机器学习的数字数据结构。由于我们将使用预训练的 DistilBert,因此我们将使用 DistilBert 分词器进行分词。

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

4.数据集对象

torch.utils.data.Dataset 是 Pytorch 中表示数据集的抽象类。我们的自定义数据集将继承Dataset并覆盖以下方法:

__len__以便len(dataset)返回数据集的大小。
__getitem__支持索引,这样dataset[i]可以用来获取第i个样本。

import torch
class IMDbDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)






推荐阅读
  • 本文详细介绍了如何使用Python中的smtplib库来发送带有附件的邮件,并提供了完整的代码示例。作者:多测师_王sir,时间:2020年5月20日 17:24,微信:15367499889,公司:上海多测师信息有限公司。 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 2020年9月15日,Oracle正式发布了最新的JDK 15版本。本次更新带来了许多新特性,包括隐藏类、EdDSA签名算法、模式匹配、记录类、封闭类和文本块等。 ... [详细]
  • 机器学习算法:SVM(支持向量机)
    SVM算法(SupportVectorMachine,支持向量机)的核心思想有2点:1、如果数据线性可分,那么基于最大间隔的方式来确定超平面,以确保全局最优, ... [详细]
  • 本文节选自《NLTK基础教程——用NLTK和Python库构建机器学习应用》一书的第1章第1.2节,作者Nitin Hardeniya。本文将带领读者快速了解Python的基础知识,为后续的机器学习应用打下坚实的基础。 ... [详细]
  • 本文详细介绍了如何使用 Python 进行主成分分析(PCA),包括数据导入、预处理、模型训练和结果可视化等步骤。通过具体的代码示例,帮助读者理解和应用 PCA 技术。 ... [详细]
  • 近期,微信公众平台上的HTML5游戏引起了广泛讨论,预示着HTML5游戏将迎来新的发展机遇。磊友科技的赵霏,作为一名HTML5技术的倡导者,分享了他在微信平台上开发HTML5游戏的经验和见解。 ... [详细]
  • JUC(三):深入解析AQS
    本文详细介绍了Java并发工具包中的核心类AQS(AbstractQueuedSynchronizer),包括其基本概念、数据结构、源码分析及核心方法的实现。 ... [详细]
  • 利用python爬取豆瓣电影Top250的相关信息,包括电影详情链接,图片链接,影片中文名,影片外国名,评分,评价数,概况,导演,主演,年份,地区,类别这12项内容,然后将爬取的信息写入Exce ... [详细]
  • 使用jqTransform插件美化表单
    jqTransform 是由 DFC Engineering 开发的一款 jQuery 插件,专用于美化表单元素,操作简便,能够美化包括输入框、单选按钮、多行文本域、下拉选择框和复选框在内的所有表单元素。 ... [详细]
  • 在多线程并发环境中,普通变量的操作往往是线程不安全的。本文通过一个简单的例子,展示了如何使用 AtomicInteger 类及其核心的 CAS 无锁算法来保证线程安全。 ... [详细]
  • 本文将详细介绍如何在Mac上安装Jupyter Notebook,并提供一些常见的问题解决方法。通过这些步骤,您将能够顺利地在Mac上运行Jupyter Notebook。 ... [详细]
  • 本文详细介绍了 InfluxDB、collectd 和 Grafana 的安装与配置流程。首先,按照启动顺序依次安装并配置 InfluxDB、collectd 和 Grafana。InfluxDB 作为时序数据库,用于存储时间序列数据;collectd 负责数据的采集与传输;Grafana 则用于数据的可视化展示。文中提供了 collectd 的官方文档链接,便于用户参考和进一步了解其配置选项。通过本指南,读者可以轻松搭建一个高效的数据监控系统。 ... [详细]
  • 本文介绍了如何使用Python的Paramiko库批量更新多台服务器的登录密码。通过示例代码展示了具体实现方法,确保了操作的高效性和安全性。Paramiko库提供了强大的SSH2协议支持,使得远程服务器管理变得更加便捷。此外,文章还详细说明了代码的各个部分,帮助读者更好地理解和应用这一技术。 ... [详细]
  • 大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式
    大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式 ... [详细]
author-avatar
传奇gk_543
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有