热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch学习笔记之神经网络包nn和优化器optim

PyTorch学习笔记之神经网络包 nn 和优化器 optim


torch.nn 是专门为神经网络设计的模块化接口。构建于 Autograd 之上,可用来定义和运行神经网络。下面介绍几个常用的类:

注: torch.nn 为了方便使用,将它设置成 nn 的别名。

除了 nn 别名以外,我们还引用了 nn.functional,这个包中包含了神经网络中使用的一些常用的函数,这些函数的特点是,不具有可学习的参数(ReLu, pool, DropOut 等),这些函数可以放在构造函数中,也可以不放,但是这里建议不放。

一般情况下,我们将 nn.functional 设置为大写的 F, 这样缩写方便调用

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

  • 定义一个网络

PyTorch 中已经为我们准备好了现成的网络模型,只要继承 nn.module,并实现它的 forward 方法, PyTorch 会根据 autograd,自动实现 backward 函数,在 forward 函数中可使用任何 tensor 支持的函数,还可以使用 if、for 循环、print、 log 等 Python 语法,写法和标准的 Python 写法一致。

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

网络的可学习参数通过 net.parameters() 返回。

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

net.named_parameters 可以同时返回可学习的参数及名称。

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

forward 函数的输入和输出都是 Tensor。

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

在反向传播前,先要将所有参数的梯度清零。

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

注意 :torch. nn 只支持 mini-batches,不支持一次只输入一个样本,即一次必须是一个 batch。

也就是说,就算我们输入一个样本,也会对样本进行分批,所以,所有的输入都会增加一个维度,我们对比下刚才的nn中定义为3维,但是我们人工创建时多增加了一个维度,变为了4 维,最前面的 1 即为 batch-size。

 

  • 损失函数

在 nn 中 PyTorch 还预制了常用的损失函数,下面我们用 MSELOSS 用来计算均方误差

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

  • 优化器

在反向传播中计算完所有参数的梯度后,还需要使用优化方法来更新网络的权重和参数,例如随机梯度下降法(SGD)的更新策略如下:

 torch.optim中实现大多数的优化方法,例如:RMSProp、Adam、SGD 等,下面我们使用SGD 做个简单的样例:

PyTorch学习笔记之神经网络包 nn 和优化器 optim

 

 

 

 

 


推荐阅读
  • 本文介绍了lua语言中闭包的特性及其在模式匹配、日期处理、编译和模块化等方面的应用。lua中的闭包是严格遵循词法定界的第一类值,函数可以作为变量自由传递,也可以作为参数传递给其他函数。这些特性使得lua语言具有极大的灵活性,为程序开发带来了便利。 ... [详细]
  • Python正则表达式学习记录及常用方法
    本文记录了学习Python正则表达式的过程,介绍了re模块的常用方法re.search,并解释了rawstring的作用。正则表达式是一种方便检查字符串匹配模式的工具,通过本文的学习可以掌握Python中使用正则表达式的基本方法。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 闭包一直是Java社区中争论不断的话题,很多语言都支持闭包这个语言特性,闭包定义了一个依赖于外部环境的自由变量的函数,这个函数能够访问外部环境的变量。本文以JavaScript的一个闭包为例,介绍了闭包的定义和特性。 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
  • 本博文基于《Amalgamationofproteinsequence,structureandtextualinformationforimprovingprote ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
  • 本人学习笔记,知识点均摘自于网络,用于学习和交流(如未注明出处,请提醒,将及时更正,谢谢)OS:我学习是为了上 ... [详细]
  • PHP图片截取方法及应用实例
    本文介绍了使用PHP动态切割JPEG图片的方法,并提供了应用实例,包括截取视频图、提取文章内容中的图片地址、裁切图片等问题。详细介绍了相关的PHP函数和参数的使用,以及图片切割的具体步骤。同时,还提供了一些注意事项和优化建议。通过本文的学习,读者可以掌握PHP图片截取的技巧,实现自己的需求。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文介绍了多因子选股模型在实际中的构建步骤,包括风险源分析、因子筛选和体系构建,并进行了模拟实证回测。在风险源分析中,从宏观、行业、公司和特殊因素四个角度分析了影响资产价格的因素。具体包括宏观经济运行和宏经济政策对证券市场的影响,以及行业类型、行业生命周期和行业政策对股票价格的影响。 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
author-avatar
x修者x
無限者:www.wuxianzhe.cn
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有