热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

一起读Bert文本分类代码(pytorch篇六)

Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了。这次想和大家一起读的是huggingface的pytorch-pretrained

Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了。这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码examples里的文本分类任务run_classifier。

关于源代码可以在huggingface的github中找到。

huggingface/pytorch-pretrained-BERTgithub.com《一起读Bert文本分类代码 (pytorch篇 六)》

在前五篇文章中我分别介绍了数据预处理部分和模型部分:

周剑:一起读Bert文本分类代码 (pytorch篇 一)zhuanlan.zhihu.com《一起读Bert文本分类代码 (pytorch篇 六)》
周剑:一起读Bert文本分类代码 (pytorch篇 二)zhuanlan.zhihu.com《一起读Bert文本分类代码 (pytorch篇 六)》
周剑:一起读Bert文本分类代码 (pytorch篇 三)zhuanlan.zhihu.com《一起读Bert文本分类代码 (pytorch篇 六)》
周剑:一起读Bert文本分类代码 (pytorch篇 四)zhuanlan.zhihu.com《一起读Bert文本分类代码 (pytorch篇 六)》
周剑:一起读Bert文本分类代码 (pytorch篇 五)zhuanlan.zhihu.com《一起读Bert文本分类代码 (pytorch篇 六)》

在这篇文章中,我会带大家继续读examples.run_classifier.py文件中的训练以及预测部分。

我们接着model继续往下读:

if args.fp16:
model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = DDP(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
# Prepare optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
if args.fp16:
try:
from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
optimizer = FusedAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
bias_correction=False,
max_grad_norm=1.0)
if args.loss_scale == 0:
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
else:
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_optimization_steps)

这段代码主要是处理单精度和双精度,把模型添加到cpu或者gpu中,定义优化函数的作用。

其中可以选择的优化函数有FusedAdam,它是NVIDIA开源面向精简混合精度和分布式训练的Pytorch扩展的优化函数,具体可以去NVIDIA的apex包下查看。还可以选择BertAdam,它是自定义Adam优化器,具体的代码可以在pytorch_pretrained_bert.optimization.py中找到。

继续往下看主函数,这段代码主要是训练和保存模型:

global_step = 0
nb_tr_steps = 0
tr_loss = 0
if args.do_train:
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
train_sampler = RandomSampler(train_data)
else:
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train()
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step()
optimizer.zero_grad()
global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file)

这段代码的主要内容是执行前面定义数据与处理函数,将数据处理好后feed入模型中,自动计算backward并更新模型参数。

这里值得一提的地方有两个:第一,DistributedSampler是pytorch中将数据自动分发给多gpu的函数。pytorch在1.0版本后对多gpu训练进行了优化,现在已经很不错了。第二,在这段代码最后保存模型时,建议只保存模型参数,即模型本身。这样每次读取的时候就只需要读取模型参数,不需要从保存的文件中读取模型结构,能加快模型读取速度。但只保存模型参数的话,就需要自己先定义模型,然后读取参数。

继续读主函数,这段代码主要是预测用:

model_state_dict = torch.load(output_model_file)
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
model.to(device)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = processor.get_dev_examples(args.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer)
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
model.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
logits = model(input_ids, segment_ids, input_mask)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy = accuracy(logits, label_ids)
eval_loss += tmp_eval_loss.mean().item()
eval_accuracy += tmp_eval_accuracy
nb_eval_examples += input_ids.size(0)
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
loss = tr_loss/nb_tr_steps if args.do_train else None
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'global_step': global_step,
'loss': loss}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))

这段代码主要是读取模型并预测输出结果用。如果使用多gpu,还是会把数据计算自动分配到各个gpu上。但是,numpy()只能使用于cpu。所以,这里需要注意,如果使用gpu训练或者预测,之使用numpy记得转到cpu中。这段代码最后输出一个txt的预测文件。

这样就带大家读完了所有examples.run_classifier.py中的代码和其调用的Bert模型的代码。


推荐阅读
  • SpringBoot整合SpringSecurity+JWT实现单点登录
    SpringBoot整合SpringSecurity+JWT实现单点登录,Go语言社区,Golang程序员人脉社 ... [详细]
  • AstridDAO 专访:波卡稳定币黑马 BAI
    加入Pol ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • x86 linux的进程调度,x86体系结构下Linux2.6.26的进程调度和切换
    进程调度相关数据结构task_structtask_struct是进程在内核中对应的数据结构,它标识了进程的状态等各项信息。其中有一项thread_struct结构的 ... [详细]
  • 在Kubernetes上部署JupyterHub的步骤和实验依赖
    本文介绍了在Kubernetes上部署JupyterHub的步骤和实验所需的依赖,包括安装Docker和K8s,使用kubeadm进行安装,以及更新下载的镜像等。 ... [详细]
  • 学习SLAM的女生,很酷
    本文介绍了学习SLAM的女生的故事,她们选择SLAM作为研究方向,面临各种学习挑战,但坚持不懈,最终获得成功。文章鼓励未来想走科研道路的女生勇敢追求自己的梦想,同时提到了一位正在英国攻读硕士学位的女生与SLAM结缘的经历。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • 本文介绍了PhysioNet网站提供的生理信号处理工具箱WFDB Toolbox for Matlab的安装和使用方法。通过下载并添加到Matlab路径中或直接在Matlab中输入相关内容,即可完成安装。该工具箱提供了一系列函数,可以方便地处理生理信号数据。详细的安装和使用方法可以参考本文内容。 ... [详细]
  • 本文介绍了Perl的测试框架Test::Base,它是一个数据驱动的测试框架,可以自动进行单元测试,省去手工编写测试程序的麻烦。与Test::More完全兼容,使用方法简单。以plural函数为例,展示了Test::Base的使用方法。 ... [详细]
  • Linux如何安装Mongodb的详细步骤和注意事项
    本文介绍了Linux如何安装Mongodb的详细步骤和注意事项,同时介绍了Mongodb的特点和优势。Mongodb是一个开源的数据库,适用于各种规模的企业和各类应用程序。它具有灵活的数据模式和高性能的数据读写操作,能够提高企业的敏捷性和可扩展性。文章还提供了Mongodb的下载安装包地址。 ... [详细]
  • python中安装并使用redis相关的知识
    本文介绍了在python中安装并使用redis的相关知识,包括redis的数据缓存系统和支持的数据类型,以及在pycharm中安装redis模块和常用的字符串操作。 ... [详细]
  • 本文介绍了在Ubuntu 11.10 x64环境下安装Android开发环境的步骤,并提供了解决常见问题的方法。其中包括安装Eclipse的ADT插件、解决缺少GEF插件的问题以及解决无法找到'userdata.img'文件的问题。此外,还提供了相关插件和系统镜像的下载链接。 ... [详细]
  • C#设计模式之八装饰模式(Decorator Pattern)【结构型】
    一、引言今天我们要讲【结构型】设计模式的第三个模式,该模式是【装饰模式】,英文名称:DecoratorPattern。我第一次看到这个名称想到的是另外一个词语“装修”,我就说说我对“装修”的理 ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
author-avatar
灵绾绾
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有