热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用HuggingfaceTrainer对自定义数据集进行文本分类

文本分类是一项常见的NLP任务,它根据文本的内容定义文本的类型、流派或主题。Huggingface🤗Transformers提供API和工具来轻松下载




文本分类是一项常见的 NLP 任务,它根据文本的内容定义文本的类型、流派或主题。Huggingface🤗 Transformers 提供 API 和工具来轻松下载和训练最先进的预训练模型。Huggingface Transformers 支持 PyTorch、TensorFlow 和 JAX 之间的框架互操作性。模型还可以导出为 ONNX 和 TorchScript 等格式,以便在生产环境中部署。

在这里插入图片描述
此博客将指导您使用 Huggingface Transformers 对自定义数据集的 Distillbert 进行微调。

DistilBERT 是一种小型、快速、便宜且轻便的 Transformer 模型,通过提取 BERT 基础进行训练。它比bert-base-uncased少了 40% 的参数,运行速度提高了 60%,同时保留了 BERT 在 GLUE 语言理解基准测试中超过 95% 的性能

本博客源代码:Google Colab


1.安装transformers库

按照以下链接中的这些说明安装 huggingface 库:https 😕/huggingface.co/docs/datasets/v1.11.0/installation.html

!pip install transformers
!pip install sentencepiece

我们将在本教程中使用 IMDb 数据集。IMDB 数据集是一个大型电影评论数据集。这是一个用于二元情感分类的数据集,包含比以前的基准数据集多得多的数据。他们提供了一组 25,000 条高度极端的电影评论用于训练,25,000 条用于测试。还有其他未标记的数据可供使用。我们可以通过以下步骤下载数据集。


2.下载数据集

!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

此数据被组织到每个示例一个文本文件的文件夹中pos。read_imdb_split函数将帮助我们读取和处理数据集。neg

from pathlib import Path
def read_imdb_split(split_dir):
split_dir = Path(split_dir)
texts = []
labels = []
for label_dir in ["pos", "neg"]:
for text_file in (split_dir/label_dir).iterdir():
texts.append(text_file.read_text())
labels.append(0 if label_dir is "neg" else 1)
return texts, labels
train_texts, train_labels = read_imdb_split('aclImdb/train')
test_texts, test_labels = read_imdb_split('aclImdb/test')

让我们使用 Scikit-learn 中的train_test_split实用程序从训练数据集创建训练和验证集。我们将使用验证数据集使用自定义度量函数来评估和调整我们的模型。

from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

3.预处理输入数据

下一步是对模型的输入数据进行预处理。与清理数据一起,标记化是任何 NLP 管道中的第一步。Tokenizer 将非结构化字符串转换为适合机器学习的数字数据结构。由于我们将使用预训练的 DistilBert,因此我们将使用 DistilBert 分词器进行分词。

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

4.数据集对象

torch.utils.data.Dataset 是 Pytorch 中表示数据集的抽象类。我们的自定义数据集将继承Dataset并覆盖以下方法:

__len__以便len(dataset)返回数据集的大小。
__getitem__支持索引,这样dataset[i]可以用来获取第i个样本。

import torch
class IMDbDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)






推荐阅读
  • 使用Python在SAE上开发新浪微博应用的初步探索
    最近重新审视了新浪云平台(SAE)提供的服务,发现其已支持Python开发。本文将详细介绍如何利用Django框架构建一个简单的新浪微博应用,并分享开发过程中的关键步骤。 ... [详细]
  • Scala 实现 UTF-8 编码属性文件读取与克隆
    本文介绍如何使用 Scala 以 UTF-8 编码方式读取属性文件,并实现属性文件的克隆功能。通过这种方式,可以确保配置文件在多线程环境下的一致性和高效性。 ... [详细]
  • 本文介绍了一种方法,通过使用Python的ctypes库来调用C++代码。具体实例为实现一个简单的加法器,并详细说明了从编写C++代码到编译及最终在Python中调用的全过程。 ... [详细]
  • PyCharm下载与安装指南
    本文详细介绍如何从官方渠道下载并安装PyCharm集成开发环境(IDE),涵盖Windows、macOS和Linux系统,同时提供详细的安装步骤及配置建议。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 本文探讨了图像标签的多种分类场景及其在以图搜图技术中的应用,涵盖了从基础理论到实际项目实施的全面解析。 ... [详细]
  • 如何用GPU服务器运行Python
    如何用GPU服务器运行Python-目录前言一、服务器登录1.1下载安装putty1.2putty远程登录 1.3查看GPU、显卡常用命令1.4Linux常用命令二、 ... [详细]
  • 图神经网络模型综述
    本文综述了图神经网络(Graph Neural Networks, GNN)的发展,从传统的数据存储模型转向图和动态模型,探讨了模型中的显性和隐性结构,并详细介绍了GNN的关键组件及其应用。 ... [详细]
  • 本文介绍了如何利用snownlp库对微博内容进行情感分析,包括安装、基本使用以及如何自定义训练模型以提高分析准确性。 ... [详细]
  • 整理于2020年10月下旬:总结过去,展望未来Itistoughtodayandtomorrowwillbetougher.butthedayaftertomorrowisbeau ... [详细]
  • 自然语言处理(NLP)——LDA模型:对电商购物评论进行情感分析
    目录一、2020数学建模美赛C题简介需求评价内容提供数据二、解题思路三、LDA简介四、代码实现1.数据预处理1.1剔除无用信息1.1.1剔除掉不需要的列1.1.2找出无效评论并剔除 ... [详细]
  • Java 中的 BigDecimal pow()方法,示例 ... [详细]
  • 本文探讨了如何在Python中处理长数据的完全显示问题,包括numpy数组、pandas DataFrame以及tensor类型的完整输出设置。 ... [详细]
  • 如何更换Anaconda和pip的国内镜像源
    本文详细介绍了如何通过国内多个知名镜像站(如北京外国语大学、中国科学技术大学、阿里巴巴等)更换Anaconda和pip的源,以提高软件包的下载速度和安装效率。 ... [详细]
  • 精选10款Python框架助力并行与分布式机器学习
    随着神经网络模型的不断深化和复杂化,训练这些模型变得愈发具有挑战性,不仅需要处理大量的权重,还必须克服内存限制等问题。本文将介绍10款优秀的Python框架,帮助开发者高效地实现分布式和并行化的深度学习模型训练。 ... [详细]
author-avatar
传奇gk_543
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有