热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

开发笔记:未雨绸缪:随手保存PyTorch训练模型

篇首语:本文由编程笔记#小编为大家整理,主要介绍了未雨绸缪:随手保存PyTorch训练模型相关的知识,希望对你有一定的参考价值。

篇首语:本文由编程笔记#小编为大家整理,主要介绍了未雨绸缪:随手保存 PyTorch 训练模型相关的知识,希望对你有一定的参考价值。


我们都知道,训练一个深度神经网络是需要挺长的时间的,即使是在高性能服务器上,有些训练也要持续几天之久。

不知道大家有没有遇到这种尴尬的情况:花了一天时间好不容易训练模型到 60% 啦,突然,机房要停电?学长要占用服务器?购买的 GPU 计算时间用完了等等。

怎么办??未雨绸缪:随手保存 PyTorch 训练模型

训练了一大半的模型不能功亏一篑呀,能不能把没训练完成的模型先保存下来,回头有机会了再加载接着训练?

那么今天我就给大家来介绍一个小技巧。

教你如何将未训练完成的 PyTorch 模型保存下来,而且不只是模型,训练过程中的优化器(optimizer),迭代数(epochs),以及正确率(score)等等,都可以以文件的形式保存下来,并在未来继续加载训练。

下面重点来啦!

(敲黑板)

我们利用 torch.save方法,使用这个方法来保存模型。

首先让我们封装一个用来保存模型的函数。

它有三个参数:

第一个参数就是我们要保存的模型状态以及优化器的状态,迭代次数等信息。



def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):

    torch.save(state, filename)

    if is_best:

        shutil.copyfile(filename, 'model_best.pth.tar')


然后我们用下面的方式来调用它:



save_checkpoint({

    'epoch': epoch + 1,

    'arch': args.arch,

    'state_dict': model.state_dict(),

    'best_prec1': best_prec1,

    'optimizer' : optimizer.state_dict(),

}, is_best)


这样我们就将模型保存为一个文件。

那要怎么重新加载呢?

像这样:



if args.resume:

    if os.path.isfile(args.resume):

        print("=> loading checkpoint '{}'".format(args.resume))

        # 通过参数指定要加载的模型文件名

        checkpoint = torch.load(args.resume)

        # 读取出保存的模型训练参数

        args.start_epoch = checkpoint['epoch']

        best_prec1 = checkpoint['best_prec1']

        # 重新加载模型训练进度

        model.load_state_dict(checkpoint['state_dict'])

        # 重新加载优化器进度

        optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    else:

        print("=> no checkpoint found at '{}'".format(args.resume))


这样就可以保存(save)&加载(load)训练中的 PyTorch 模型啦!善用SL(save&load)大法,不但不怕服务器突然掉链子,还能够把训练各个阶段的模型都保存下来,用于研究模型训练的各个步骤。




未雨绸缪:随手保存 PyTorch 训练模型




本节中的全部代码取自 PyTorch 的 ImageNet 官方代码,你可以在这里找到完整代码。


https://github.com/pytorch/examples/blob/master/imagenet/main.py#L139


本文编译自 PyTorch 官方论坛,原址:


https://discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/2610/7






未雨绸缪:随手保存 PyTorch 训练模型





推荐阅读:


为什么他们要来集智AI学园学习 PyTorch?



为什么机器学习研究者都投入了 PyTorch 的怀抱?



重磅系列课:火炬上的深度学习(下)














获取更多更有趣的AI教程吧!


学园网站:campus.swarma.org









 商务合作|zhangqian@swarma.org     

投稿转载|wangjiannan@swarma.org








点击学习PyTorch



推荐阅读
author-avatar
CC_橙_CC
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有