热门标签 | HotTags
当前位置:  开发笔记 > 人工智能 > 正文

关于pytorch中网络loss传播和参数更新的理解

今天小编就为大家分享一篇关于pytorch中网络loss传播和参数更新的理解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56,但是pytorch的数量从87篇提升到了252篇。

TensorFlow: 228--->266

Keras: 42--->56

Pytorch: 87--->252

在使用pytorch中,自己有一些思考,如下:

1. loss计算和反向传播

import torch.nn as nn
 
criterion = nn.MSELoss().cuda()
 
output = model(input)
 
loss = criterion(output, target)
loss.backward()

通过定义损失函数:criterion,然后通过计算网络真实输出和真实标签之间的误差,得到网络的损失值:loss;

最后通过loss.backward()完成误差的反向传播,通过pytorch的内在机制完成自动求导得到每个参数的梯度。

需要注意,在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化或最大化,一般是通过梯度进行网络模型的参数更新,通过loss的计算和误差反向传播,我们得到网络中,每个参数的梯度值,后面我们再通过优化算法进行网络参数优化更新。

2. 网络参数更新

在更新网络参数时,我们需要选择一种调整模型参数更新的策略,即优化算法。

优化算法中,简单的有一阶优化算法:

其中 就是通常说的学习率, 是函数的梯度;

自己的理解是,对于复杂的优化算法,基本原理也是这样的,不过计算更加复杂。

在pytorch中,torch.optim是一个实现各种优化算法的包,可以直接通过这个包进行调用。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

注意:

1)在前面部分1中,已经通过loss的反向传播得到了每个参数的梯度,然后再本部分通过定义优化器(优化算法),确定了网络更新的方式,在上述代码中,我们将模型的需要更新的参数传入优化器。

2)注意优化器,即optimizer中,传入的模型更新的参数,对于网络中有多个模型的网络,我们可以选择需要更新的网络参数进行输入即可,上述代码,只会更新model中的模型参数。对于需要更新多个模型的参数的情况,可以参考以下代码:

optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': gru.parameters()}], lr=0.01) 3) 在优化前需要先将梯度归零,即optimizer.zeros()。

3. loss计算和参数更新

import torch.nn as nn
import torch
 
criterion = nn.MSELoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
output = model(input)
 
loss = criterion(output, target)
 
​optimizer.zero_grad() # 将所有参数的梯度都置零
loss.backward()    # 误差反向传播计算参数梯度
optimizer.step()    # 通过梯度做一步参数更新

以上这篇关于pytorch中网络loss传播和参数更新的理解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


推荐阅读
  • 本文探讨了图像标签的多种分类场景及其在以图搜图技术中的应用,涵盖了从基础理论到实际项目实施的全面解析。 ... [详细]
  • 吴恩达推出TensorFlow实践课程,Python基础即可入门,四个月掌握核心技能
    量子位报道,deeplearning.ai最新发布了TensorFlow实践课程,适合希望使用TensorFlow开发AI应用的学习者。该课程涵盖机器学习模型构建、图像识别、自然语言处理及时间序列预测等多个方面。 ... [详细]
  • 本文详细介绍了 TensorFlow 的入门实践,特别是使用 MNIST 数据集进行数字识别的项目。文章首先解析了项目文件结构,并解释了各部分的作用,随后逐步讲解了如何通过 TensorFlow 实现基本的神经网络模型。 ... [详细]
  • 计算机视觉初学者指南:如何顺利入门
    本文旨在为计算机视觉领域的初学者提供一套全面的入门指南,涵盖基础知识、技术工具、学习资源等方面,帮助读者快速掌握计算机视觉的核心概念和技术。 ... [详细]
  • 机器学习中的相似度度量与模型优化
    本文探讨了机器学习中常见的相似度度量方法,包括余弦相似度、欧氏距离和马氏距离,并详细介绍了如何通过选择合适的模型复杂度和正则化来提高模型的泛化能力。此外,文章还涵盖了模型评估的各种方法和指标,以及不同分类器的工作原理和应用场景。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 尽管深度学习带来了广泛的应用前景,其训练通常需要强大的计算资源。然而,并非所有开发者都能负担得起高性能服务器或专用硬件。本文探讨了如何在有限的硬件条件下(如ARM CPU)高效运行深度神经网络,特别是通过选择合适的工具和框架来加速模型推理。 ... [详细]
  • Coursera ML 机器学习
    2019独角兽企业重金招聘Python工程师标准线性回归算法计算过程CostFunction梯度下降算法多变量回归![选择特征](https:static.oschina.n ... [详细]
  • 机器学习核心概念与技术
    本文系统梳理了机器学习的关键知识点,涵盖模型评估、正则化、线性模型、支持向量机、决策树及集成学习等内容,并深入探讨了各算法的原理和应用场景。 ... [详细]
  • 解决PyCharm中安装PyTorch深度学习d2l包的问题
    本文详细介绍了如何在PyCharm中成功安装用于PyTorch深度学习的d2l包,包括环境配置、安装步骤及常见问题的解决方案。 ... [详细]
  • 2017年人工智能领域的十大里程碑事件回顾
    随着2018年的临近,我们一同回顾过去一年中人工智能领域的重要进展。这一年,无论是政策层面的支持,还是技术上的突破,都显示了人工智能发展的迅猛势头。以下是精选的2017年人工智能领域最具影响力的事件。 ... [详细]
  • 在Ubuntu 16.04中使用Anaconda安装TensorFlow
    本文详细介绍了如何在Ubuntu 16.04系统上通过Anaconda环境管理工具安装TensorFlow。首先,需要下载并安装Anaconda,然后配置环境变量以确保系统能够识别Anaconda命令。接着,创建一个特定的Python环境用于安装TensorFlow,并通过指定的镜像源加速安装过程。最后,通过一个简单的线性回归示例验证TensorFlow的安装是否成功。 ... [详细]
  • 强人工智能时代,区块链的角色与前景
    随着强人工智能的崛起,区块链技术在新的技术生态中扮演着怎样的角色?本文探讨了区块链与强人工智能之间的互补关系及其在未来技术发展中的重要性。 ... [详细]
  • 京东AI创新之路:周伯文解析京东AI战略的独特之处
    2018年4月15日,京东在北京举办了人工智能创新峰会,会上首次公开了京东AI的整体布局和发展方向。此次峰会不仅展示了京东在AI领域的最新成果,还标志着京东AI团队的首次集体亮相。本文将深入探讨京东AI的发展策略及其与BAT等公司的不同之处。 ... [详细]
  • 在中国医疗行业面临高度监管和市场垄断的背景下,医疗领域的创新面临诸多挑战。本文通过探讨技术变革与商业模式的结合,为医疗AI的未来发展提供了新的视角。 ... [详细]
author-avatar
訫鈊衅
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有