热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用HuggingfaceTrainer对自定义数据集进行文本分类

文本分类是一项常见的NLP任务,它根据文本的内容定义文本的类型、流派或主题。Huggingface🤗Transformers提供API和工具来轻松下载




文本分类是一项常见的 NLP 任务,它根据文本的内容定义文本的类型、流派或主题。Huggingface🤗 Transformers 提供 API 和工具来轻松下载和训练最先进的预训练模型。Huggingface Transformers 支持 PyTorch、TensorFlow 和 JAX 之间的框架互操作性。模型还可以导出为 ONNX 和 TorchScript 等格式,以便在生产环境中部署。

在这里插入图片描述
此博客将指导您使用 Huggingface Transformers 对自定义数据集的 Distillbert 进行微调。

DistilBERT 是一种小型、快速、便宜且轻便的 Transformer 模型,通过提取 BERT 基础进行训练。它比bert-base-uncased少了 40% 的参数,运行速度提高了 60%,同时保留了 BERT 在 GLUE 语言理解基准测试中超过 95% 的性能

本博客源代码:Google Colab


1.安装transformers库

按照以下链接中的这些说明安装 huggingface 库:https 😕/huggingface.co/docs/datasets/v1.11.0/installation.html

!pip install transformers
!pip install sentencepiece

我们将在本教程中使用 IMDb 数据集。IMDB 数据集是一个大型电影评论数据集。这是一个用于二元情感分类的数据集,包含比以前的基准数据集多得多的数据。他们提供了一组 25,000 条高度极端的电影评论用于训练,25,000 条用于测试。还有其他未标记的数据可供使用。我们可以通过以下步骤下载数据集。


2.下载数据集

!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

此数据被组织到每个示例一个文本文件的文件夹中pos。read_imdb_split函数将帮助我们读取和处理数据集。neg

from pathlib import Path
def read_imdb_split(split_dir):
split_dir = Path(split_dir)
texts = []
labels = []
for label_dir in ["pos", "neg"]:
for text_file in (split_dir/label_dir).iterdir():
texts.append(text_file.read_text())
labels.append(0 if label_dir is "neg" else 1)
return texts, labels
train_texts, train_labels = read_imdb_split('aclImdb/train')
test_texts, test_labels = read_imdb_split('aclImdb/test')

让我们使用 Scikit-learn 中的train_test_split实用程序从训练数据集创建训练和验证集。我们将使用验证数据集使用自定义度量函数来评估和调整我们的模型。

from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

3.预处理输入数据

下一步是对模型的输入数据进行预处理。与清理数据一起,标记化是任何 NLP 管道中的第一步。Tokenizer 将非结构化字符串转换为适合机器学习的数字数据结构。由于我们将使用预训练的 DistilBert,因此我们将使用 DistilBert 分词器进行分词。

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

4.数据集对象

torch.utils.data.Dataset 是 Pytorch 中表示数据集的抽象类。我们的自定义数据集将继承Dataset并覆盖以下方法:

__len__以便len(dataset)返回数据集的大小。
__getitem__支持索引,这样dataset[i]可以用来获取第i个样本。

import torch
class IMDbDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)






推荐阅读
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • 基于PgpoolII的PostgreSQL集群安装与配置教程
    本文介绍了基于PgpoolII的PostgreSQL集群的安装与配置教程。Pgpool-II是一个位于PostgreSQL服务器和PostgreSQL数据库客户端之间的中间件,提供了连接池、复制、负载均衡、缓存、看门狗、限制链接等功能,可以用于搭建高可用的PostgreSQL集群。文章详细介绍了通过yum安装Pgpool-II的步骤,并提供了相关的官方参考地址。 ... [详细]
  • Skywalking系列博客1安装单机版 Skywalking的快速安装方法
    本文介绍了如何快速安装单机版的Skywalking,包括下载、环境需求和端口检查等步骤。同时提供了百度盘下载地址和查询端口是否被占用的命令。 ... [详细]
  • 安装mysqlclient失败解决办法
    本文介绍了在MAC系统中,使用django使用mysql数据库报错的解决办法。通过源码安装mysqlclient或将mysql_config添加到系统环境变量中,可以解决安装mysqlclient失败的问题。同时,还介绍了查看mysql安装路径和使配置文件生效的方法。 ... [详细]
  • EPICS Archiver Appliance存储waveform记录的尝试及资源需求分析
    本文介绍了EPICS Archiver Appliance存储waveform记录的尝试过程,并分析了其所需的资源容量。通过解决错误提示和调整内存大小,成功存储了波形数据。然后,讨论了储存环逐束团信号的意义,以及通过记录多圈的束团信号进行参数分析的可能性。波形数据的存储需求巨大,每天需要近250G,一年需要90T。然而,储存环逐束团信号具有重要意义,可以揭示出每个束团的纵向振荡频率和模式。 ... [详细]
  • Windows下配置PHP5.6的方法及注意事项
    本文介绍了在Windows系统下配置PHP5.6的步骤及注意事项,包括下载PHP5.6、解压并配置IIS、添加模块映射、测试等。同时提供了一些常见问题的解决方法,如下载缺失的msvcr110.dll文件等。通过本文的指导,读者可以轻松地在Windows系统下配置PHP5.6,并解决一些常见的配置问题。 ... [详细]
  • 本文介绍了在mac环境下使用nginx配置nodejs代理服务器的步骤,包括安装nginx、创建目录和文件、配置代理的域名和日志记录等。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 本文介绍了在CentOS上安装Python2.7.2的详细步骤,包括下载、解压、编译和安装等操作。同时提供了一些注意事项,以及测试安装是否成功的方法。 ... [详细]
  • 【shell】网络处理:判断IP是否在网段、两个ip是否同网段、IP地址范围、网段包含关系
    本文介绍了使用shell脚本判断IP是否在同一网段、判断IP地址是否在某个范围内、计算IP地址范围、判断网段之间的包含关系的方法和原理。通过对IP和掩码进行与计算,可以判断两个IP是否在同一网段。同时,还提供了一段用于验证IP地址的正则表达式和判断特殊IP地址的方法。 ... [详细]
  • CEPH LIO iSCSI Gateway及其使用参考文档
    本文介绍了CEPH LIO iSCSI Gateway以及使用该网关的参考文档,包括Ceph Block Device、CEPH ISCSI GATEWAY、USING AN ISCSI GATEWAY等。同时提供了多个参考链接,详细介绍了CEPH LIO iSCSI Gateway的配置和使用方法。 ... [详细]
  • 本文介绍了在Mac上安装Xamarin并使用Windows上的VS开发iOS app的方法,包括所需的安装环境和软件,以及使用Xamarin.iOS进行开发的步骤。通过这种方法,即使没有Mac或者安装苹果系统,程序员们也能轻松开发iOS app。 ... [详细]
  • Python操作MySQL(pymysql模块)详解及示例代码
    本文介绍了使用Python操作MySQL数据库的方法,详细讲解了pymysql模块的安装和连接MySQL数据库的步骤,并提供了示例代码。内容涵盖了创建表、插入数据、查询数据等操作,帮助读者快速掌握Python操作MySQL的技巧。 ... [详细]
  • 本文介绍了在Windows系统下安装Python、setuptools、pip和virtualenv的步骤,以及安装过程中需要注意的事项。详细介绍了Python2.7.4和Python3.3.2的安装路径,以及如何使用easy_install安装setuptools。同时提醒用户在安装完setuptools后,需要继续安装pip,并注意不要将Python的目录添加到系统的环境变量中。最后,还介绍了通过下载ez_setup.py来安装setuptools的方法。 ... [详细]
  • 本文总结了使用不同方式生成 Dataframe 的方法,包括通过CSV文件、Excel文件、python dictionary、List of tuples和List of dictionary。同时介绍了一些注意事项,如使用绝对路径引入文件和安装xlrd包来读取Excel文件。 ... [详细]
author-avatar
传奇gk_543
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有