热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch训练数据以及测试全部代码(2)

p{‘trainBatch’:6,nAveGrad:1,lr:1e-07,wd:0.0005,momentum:0.9,epoch_size:10,optimizer:SGD()}

p={‘trainBatch’:6, 'nAveGrad':1, 'lr':1e-07, 'wd':0.0005, 'momentum':0.9,'epoch_size':10, 'optimizer':'SGD()'}最后一个optimizer的值是很长的字符串就不全部写出来了。这个字典长度是7。

其中的net 和criterion在稍后来进行讲解

if resume_epoch==0,那么从头开始训练 training from scratch;否则权重的初始化时一个已经训练好的模型,使用net.load_state_dict函数,这个函数是在torch.nn.Module类里面定义的一个函数。

def load_state_dict(self, state_dict, strict=True):r"""Copies parameters and buffers from :attr:`state_dict` intothis module and its descendants. If :attr:`strict` is ``True``, thenthe keys of :attr:`state_dict` must exactly match the keys returnedby this module's :meth:`~torch.nn.Module.state_dict` function.Arguments:state_dict (dict): a dict containing parameters andpersistent buffers.strict (bool, optional): whether to strictly enforce that the keysin :attr:`state_dict` match the keys returned by this module's:meth:`~torch.nn.Module.state_dict` function. Default: ``True``"""missing_keys = []unexpected_keys = []error_msgs = []# copy state_dict so _load_from_state_dict can modify itmetadata = getattr(state_dict, '_metadata', None)state_dict = state_dict.copy()if metadata is not None:state_dict._metadata = metadatadef load(module, prefix=''):module._load_from_state_dict(state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs)for name, child in module._modules.items():if child is not None:load(child, prefix + name + '.')load(self)

而里面的torch.load函数定义如下.map_location参数有三种形式:函数,字符串,字典

def load(f, map_location=None, pickle_module=pickle):"""Loads an object saved with :func:`torch.save` from a file.:meth:`torch.load` uses Python's unpickling facilities but treats storages,which underlie tensors, specially. They are first deserialized on theCPU and are then moved to the device they were saved from. If this fails(e.g. because the run time system doesn't have certain devices), an exceptionis raised. However, storages can be dynamically remapped to an alternativeset of devices using the `map_location` argument.If `map_location` is a callable, it will be called once for each serializedstorage with two arguments: storage and location. The storage argumentwill be the initial deserialization of the storage, residing on the CPU.Each serialized storage has a location tag associated with it whichidentifies the device it was saved from, and this tag is the secondargument passed to map_location. The builtin location tags are `'cpu'` forCPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors.`map_location` should return either None or a storage. If `map_location` returnsa storage, it will be used as the final deserialized object, already moved tothe right device. Otherwise, :math:`torch.load` will fall back to the defaultbehavior, as if `map_location` wasn't specified.If `map_location` is a string, it should be a device tag, where all tensorsshould be loaded.Otherwise, if `map_location` is a dict, it will be used to remap location tagsappearing in the file (keys), to ones that specify where to put thestorages (values).User extensions can register their own location tags and tagging anddeserialization methods using `register_package`.Args:f: a file-like object (has to implement read, readline, tell, and seek),or a string containing a file namemap_location: a function, string or a dict specifying how to remap storagelocationspickle_module: module used for unpickling metadata and objects (has tomatch the pickle_module used to serialize file)Example:>>> torch.load('tensors.pt')# Load all tensors onto the CPU>>> torch.load('tensors.pt', map_location='cpu')# Load all tensors onto the CPU, using a function>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)# Load all tensors onto GPU 1>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))# Map tensors from GPU 1 to GPU 0>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})# Load tensor from io.BytesIO object>>> with open('tensor.pt') as f:buffer = io.BytesIO(f.read())>>> torch.load(buffer)"""

设置使用GPU,这里是

torch.cuda.set_device(device=0)  告诉编码器cuda使用gpu0号

net.cuda() 将模型放在gpu0号上面

关于writer = SummaryWriter(log_dir=log_dir)这个函数在后面会讲解

num_img_tr = len(trainloader)# 1764
num_img_ts = len(testloader)# 242 这是batch数目


推荐阅读
  • 本文讨论了在openwrt-17.01版本中,mt7628设备上初始化启动时eth0的mac地址总是随机生成的问题。每次随机生成的eth0的mac地址都会写到/sys/class/net/eth0/address目录下,而openwrt-17.01原版的SDK会根据随机生成的eth0的mac地址再生成eth0.1、eth0.2等,生成后的mac地址会保存在/etc/config/network下。 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • Java序列化对象传给PHP的方法及原理解析
    本文介绍了Java序列化对象传给PHP的方法及原理,包括Java对象传递的方式、序列化的方式、PHP中的序列化用法介绍、Java是否能反序列化PHP的数据、Java序列化的原理以及解决Java序列化中的问题。同时还解释了序列化的概念和作用,以及代码执行序列化所需要的权限。最后指出,序列化会将对象实例的所有字段都进行序列化,使得数据能够被表示为实例的序列化数据,但只有能够解释该格式的代码才能够确定数据的内容。 ... [详细]
  • 如何使用Java获取服务器硬件信息和磁盘负载率
    本文介绍了使用Java编程语言获取服务器硬件信息和磁盘负载率的方法。首先在远程服务器上搭建一个支持服务端语言的HTTP服务,并获取服务器的磁盘信息,并将结果输出。然后在本地使用JS编写一个AJAX脚本,远程请求服务端的程序,得到结果并展示给用户。其中还介绍了如何提取硬盘序列号的方法。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文详细介绍了Spring的JdbcTemplate的使用方法,包括执行存储过程、存储函数的call()方法,执行任何SQL语句的execute()方法,单个更新和批量更新的update()和batchUpdate()方法,以及单查和列表查询的query()和queryForXXX()方法。提供了经过测试的API供使用。 ... [详细]
  • 本文讨论了在iOS平台中的Metal框架中,对于if语句中的判断条件的限制和处理方式。作者提到了在Metal shader中,判断条件不能写得太长太复杂,否则可能导致程序停留或没有响应。作者还分享了自己的经验,建议在CPU端进行处理,以避免出现问题。 ... [详细]
  • 本文详细介绍了如何使用MySQL来显示SQL语句的执行时间,并通过MySQL Query Profiler获取CPU和内存使用量以及系统锁和表锁的时间。同时介绍了效能分析的三种方法:瓶颈分析、工作负载分析和基于比率的分析。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 本文讨论了一个数列求和问题,该数列按照一定规律生成。通过观察数列的规律,我们可以得出求解该问题的算法。具体算法为计算前n项i*f[i]的和,其中f[i]表示数列中有i个数字。根据参考的思路,我们可以将算法的时间复杂度控制在O(n),即计算到5e5即可满足1e9的要求。 ... [详细]
  • 感谢大家对IT十八掌大数据的支持,今天的作业如下:1.实践PreparedStament的CRUD操作。2.对比Statement和PreparedStatement的大批量操作耗时?(1 ... [详细]
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • 上一章讲了如何制作数据集,接下来我们使用mmcls来实现多标签分类。 ... [详细]
  • pytorch Dropout过拟合的操作
    这篇文章主要介绍了pytorchDropout过拟合的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完 ... [详细]
  • 都会|可能会_###haohaohao###图神经网络之神器——PyTorch Geometric 上手 & 实战
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了###haohaohao###图神经网络之神器——PyTorchGeometric上手&实战相关的知识,希望对你有一定的参考价值。 ... [详细]
author-avatar
手机用户2502904457
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有