热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

libtorch学习笔记(11)保存和加载训练结果

保存和加载训练结果libtorch/pytorch提供了很好的

保存和加载训练结果

libtorch/pytorch提供了很好的serialize操作,可以很容把训练结果保存起来,最初我认为训练结果包括网络拓补结构,权重和偏置量等,后来发现应该只包含权重和偏置量,这样一来就需要在这个训练结果中存储一些扩展值,用于下一次重构网络。

保存网络权重和偏置量

torch::nn:Module提供了一个方法save方法,我们构建的网络一般会继承这个类,所以可以调用此方法存储网络权重和偏置量。

torch::serialize::OutputArchive archive;
save(archive);
archive.save_to(szTrainSetStateFilePath);

保存其他网络信息

只保存网络权重和偏置量是不够的,下次加载训练结果之前,还是需要先将Module注册好,然后再将网络权重和偏置量加载到当前网络的各个module中。比如对于VGG网络,需要保存哪种类型的VGG网络,需不需要在各个卷积层后面添加batchnorm层,是否需要使用32x32的小图片输入,还是224x224的大图片输入,以及最后输出多少个classes,当然也需要保存当前网络的所支持图片分类的标签。
从下面的代码可以看到如何将这些信息保存到训练结果中:

存储的关键字存储内容
VGG_labelsVGG支持的网络标签,比如0: 猫;1:狗
VGG_num_of_class最后输出的分类数,缺省是1000
VGG_configVGG网络类型,每个类型包含两个子类,带batchnorm和不带batchnorm
VGG_use_32x32_input使用32x32的小图片输入,还是224x224的大图片输入

int VGGNet::savenet(const char* szTrainSetStateFilePath)
{
// Save the net state to xxxx.pt and save the labels to xxxx.pt.label
char szLabel[MAX_LABEL_NAME] = { 0 };
try
{
torch::serialize::OutputArchive archive;
// Add nested archive here
c10::List<std::string> label_list;
for (size_t i = 0; i < m_image_labels.size(); i++)
{
memset(szLabel, 0, sizeof(szLabel));
WideCharToMultiByte(CP_UTF8, 0,
m_image_labels[i].c_str(), -1, szLabel, MAX_LABEL_NAME, NULL, NULL);
label_list.emplace_back((const char*)szLabel);
}
torch::IValue value(label_list);
archive.write("VGG_labels", label_list);
// also save the current network configuration
torch::IValue valNumClass(m_num_classes);
archive.write("VGG_num_of_class", valNumClass);
torch::IValue valNetConfig((int64_t)m_VGG_config);
archive.write("VGG_config", valNetConfig);
torch::IValue valUseSmallSize(m_use_32x32_input);
archive.write("VGG_use_32x32_input", valUseSmallSize);
save(archive);
archive.save_to(szTrainSetStateFilePath);
}
catch (...)
{
printf("Failed to save the trained VGG net state.\n");
return -1;
}
printf("Save the training result to %s.\n", szTrainSetStateFilePath);
return 0;
}

加载训练结果

从指定的预训练结果文档中,首先把分类标签载入,这就是当前训练好的网络所支持的多少种图像的分类,然后加载网络类型,这个主要用来构建网络拓补图,注册网络层模块,然后就是一些小的配置参数, 比如小图片还是大图片输入,网路最后输入的类数目等等。等到这些信息读取完毕后,就开始加载网络了,当网络的拓补结构,权重和偏置量张量都构建完毕后,再通过torch:nn::Module::load方法加载网络权重和偏置张量到网络各权重层中,这样一来网络就能恢复中训练后的状态,可以做分类、测试,甚至能基于之前训练结果再接着训练。
下面这段代码就是加载和还原上面保存下来的网络:

int VGGNet::loadnet(const char* szPreTrainSetStateFilePath)
{
wchar_t szLabel[MAX_LABEL_NAME] = { 0 };
try
{
torch::serialize::InputArchive archive;
archive.load_from(szPreTrainSetStateFilePath);
torch::IValue value;
if (archive.try_read("VGG_labels", value) && value.isList())
{
auto& label_list = value.toListRef();
for (size_t i = 0; i < label_list.size(); i++)
{
#ifdef _UNICODE
if (MultiByteToWideChar(CP_UTF8, 0,
label_list[i].toStringRef().c_str(), -1, szLabel, MAX_LABEL_NAME) <= 0)
m_image_labels.push_back(_T("Unknown"));
else
m_image_labels.push_back(szLabel);
#else
m_image_labels.push_back(label_list.get(i).toStringRef());
#endif
}
}
archive.read("VGG_num_of_class", value);
m_num_classes = (int)value.toInt();
archive.read("VGG_config", value);
m_VGG_config = (VGG_CONFIG)value.toInt();
m_bEnableBatchNorm = IS_BATCHNORM_ENABLED(m_VGG_config);
archive.read("VGG_use_32x32_input", value);
m_use_32x32_input = value.toBool();
m_imageprocessor.Init(m_use_32x32_input ? 32 : VGG_INPUT_IMG_WIDTH,
m_use_32x32_input ? 32 : VGG_INPUT_IMG_HEIGHT);
// Construct network layout,weight layers and so on
if (_Init() < 0)
{
printf("Failed to initialize VGG network {num_of_classes: %d, VGG config: %d, use_32x32_input: %s}.\n",
m_num_classes, m_VGG_config, m_use_32x32_input?"yes":"no");
return -1;
}
// Load the network state into the constructed neutral network
load(archive);
}
catch (...)
{
printf("Failed to load the pre-trained VGG net state.\n");
return -1;
}
return 0;
}

其他需要注意点

如果要接着之前的网络继续训练,这个时候需要检查之前网络训练结果和当前网络配置是否一致,如果不一致的话,需要停止训练,或者删除之前的训练结果重新训练网络。

代码和测试

对应的测试代码已经放到GitHub,这是一些基本用法:

  • 查看训练结构网络状态

VGGNet.exe state I:\catdog.pt

  • 训练网络,训练结构存放到I:\catdog.pt,如果已经存在在其基础上继续训练

VGGNet.exe train I:\CatDog I:\catdog.pt --bn -b 64 -l 0.0001 --showloss 10

在这里插入图片描述

  • 加载网络训练结构,验证测试集,得到准确率

VGGNet.exe verify I:\CatDog I:\catdog.pt

在这里插入图片描述

  • 网上下来一些图片,加载之前训练的网络,随机测试这个图片类型

VGGNet.exe classify I:\catdog.pt I:\test.png

在这里插入图片描述


推荐阅读
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • vue使用
    关键词: ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • IhaveconfiguredanactionforaremotenotificationwhenitarrivestomyiOsapp.Iwanttwodiff ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 本文介绍了UVALive6575题目Odd and Even Zeroes的解法,使用了数位dp和找规律的方法。阶乘的定义和性质被介绍,并给出了一些例子。其中,部分阶乘的尾零个数为奇数,部分为偶数。 ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • Redis底层数据结构之压缩列表的介绍及实现原理
    本文介绍了Redis底层数据结构之压缩列表的概念、实现原理以及使用场景。压缩列表是Redis为了节约内存而开发的一种顺序数据结构,由特殊编码的连续内存块组成。文章详细解释了压缩列表的构成和各个属性的含义,以及如何通过指针来计算表尾节点的地址。压缩列表适用于列表键和哈希键中只包含少量小整数值和短字符串的情况。通过使用压缩列表,可以有效减少内存占用,提升Redis的性能。 ... [详细]
  • 纠正网上的错误:自定义一个类叫java.lang.System/String的方法
    本文纠正了网上关于自定义一个类叫java.lang.System/String的错误答案,并详细解释了为什么这种方法是错误的。作者指出,虽然双亲委托机制确实可以阻止自定义的System类被加载,但通过自定义一个特殊的类加载器,可以绕过双亲委托机制,达到自定义System类的目的。作者呼吁读者对网上的内容持怀疑态度,并带着问题来阅读文章。 ... [详细]
  • 本文讨论了使用差分约束系统求解House Man跳跃问题的思路与方法。给定一组不同高度,要求从最低点跳跃到最高点,每次跳跃的距离不超过D,并且不能改变给定的顺序。通过建立差分约束系统,将问题转化为图的建立和查询距离的问题。文章详细介绍了建立约束条件的方法,并使用SPFA算法判环并输出结果。同时还讨论了建边方向和跳跃顺序的关系。 ... [详细]
  • Java学习笔记之面向对象编程(OOP)
    本文介绍了Java学习笔记中的面向对象编程(OOP)内容,包括OOP的三大特性(封装、继承、多态)和五大原则(单一职责原则、开放封闭原则、里式替换原则、依赖倒置原则)。通过学习OOP,可以提高代码复用性、拓展性和安全性。 ... [详细]
  • 本文讨论了clone的fork与pthread_create创建线程的不同之处。进程是一个指令执行流及其执行环境,其执行环境是一个系统资源的集合。在调用系统调用fork创建一个进程时,子进程只是完全复制父进程的资源,这样得到的子进程独立于父进程,具有良好的并发性。但是二者之间的通讯需要通过专门的通讯机制,另外通过fork创建子进程系统开销很大。因此,在某些情况下,使用clone或pthread_create创建线程可能更加高效。 ... [详细]
  • 本文介绍了如何使用Express App提供静态文件,同时提到了一些不需要使用的文件,如package.json和/.ssh/known_hosts,并解释了为什么app.get('*')无法捕获所有请求以及为什么app.use(express.static(__dirname))可能会提供不需要的文件。 ... [详细]
  • JDK源码学习之HashTable(附带面试题)的学习笔记
    本文介绍了JDK源码学习之HashTable(附带面试题)的学习笔记,包括HashTable的定义、数据类型、与HashMap的关系和区别。文章提供了干货,并附带了其他相关主题的学习笔记。 ... [详细]
author-avatar
手机用户2602919763
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有