热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch学习自动求梯度

2.3自动求梯度对函数计算梯度(gradient),Pytoch提供的autograd包能够根据输入和前向传播过程自动构建计算图,并执行反向传播。2.3.1概念Tensor是aut

2.3 自动求梯度

对函数计算梯度(gradient),Pytoch提供的autograd包能够根据输入和前向传播过程自动构建计算图,并执行反向传播。


2.3.1 概念

Tensor是autograd包的核心类

如果将其属性.requires_grad设置为True,它将开始追踪在其上的所有操作(这样就可以利用链式法则进行梯度传播)。

完成计算后,可以调用.backward()来完成所有梯度计算。此Tensor的梯度将累积到.grad属性中。

注:在y.backward()时,如果y是标量,则不需要为backward()传入任何参数;否则,需要传入一个与y同形的Tensor

如果不想继续追踪,调用.detach()将追踪记录分离出来(防止将来的计算被追踪,这样梯度就传不过去)。也可以使用with torch.no_grad()将不想被追踪的操作代码块包裹起来。(在评估模型时经常使用,因为在评估模型时,我们不需要计算可训练参数(requires_grad=True)的梯度)

Function是另外一个很重要的类。TensorFunction互相结合可以构建一个记录有整个计算过程的有向无环图(DAG)。每个Tensor都有一个.grad_fn属性,该属性即创建该TensorFunction。就是说该Tensor若是通过运算得到的,则grad_fn返回一个与这些运算相关的对象,否则是None。


2.3.2 Tensor

创建一个Tensor并设置requires_grad = True:

import torch

x = torch.ones(2,2,requires_grad = True)
print(x)
print(x.grad_fn)

tensor([[1., 1.],
[1., 1.]], requires_grad=True)
None

y = x + 2
print(y)
print(y.grad_fn)

tensor([[3., 3.],
[3., 3.]], grad_fn=)

注意x是直接创建的,所以返回的是None,而y是x进行了加法操作创建的,所以它有一个为的运算对象

像x这种直接创建的称为叶子节点,叶子节点对应的grad_fnNone

print(x.is_leaf,y.is_leaf)

True False

运算操作复杂化

z = y*y*3
out = z.mean()
print(z,out)

tensor([[27., 27.],
[27., 27.]], grad_fn=) tensor(27., grad_fn=)

通过.requires_grad_()来用in-place的方式改变requires_grad属性:

a = torch.randn(2,2) #缺失情况下默认 requires_grad = False
a = ((a*3)/(a-1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a*a).sum()
print(b.grad_fn)

False
True


2.3.2 梯度

因为out是一个标量,所以调用backward()时不需要指定求导变量:

out.backward() #等价于 out.backward(torch.tensor(1.))

out关于x的梯度 $\frac{d(out)}{dx}$

print(x.grad)

tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])

#再来反向传播一次,注意grad是累加的
out2 = x.sum()
out2.backward()
print(x.grad)
out3 = x.sum()
x.grad.data.zero_()
out3.backward()
print(x.grad)

tensor([[5.5000, 5.5000],
[5.5000, 5.5000]])
tensor([[1., 1.],
[1., 1.]])

举例:

x = torch.tensor([1.0,2.0,3.0,4.0],requires_grad=True)
y = 2*x
z = y.view(2,2)
print(z)

tensor([[2., 4.],
[6., 8.]], grad_fn=)

#现在 y 不是一个标量,所以在调用 backward 时需要传入一个和 y 同形的权重向量进行加权求和得到一个标量
v = torch.tensor([[1.0,0.1],[0.01,0.001]],dtype = torch.float)
z.backward(v)
print(x.grad)

tensor([2.0000, 0.2000, 0.0200, 0.0020])

注:x.grad是和x同形的张量

中断梯度追踪的例子:

x = torch.tensor(1.0,requires_grad = True)
y1 = x ** 2
with torch.no_grad():
y2 = x ** 3
y3 = y1+y2
print(x.requires_grad)
print(y1,y1.requires_grad)
print(y2,y2.requires_grad)
print(y3,y3.requires_grad)

True
tensor(1., grad_fn=

) True
tensor(1.) False
tensor(2., grad_fn=) True

y2是没有grad_fn而且y2.requires_grad = False,而y3有,接下来对y3求梯度

y3.backward()
print(x.grad)

tensor(2.)

想要修改tensor的数值又不希望被autograd记录(不影响反向传播),可以对tensor.data操作

x = torch.ones(1,requires_grad = True)
print(x.data) #还是一个tensor
print(x.data.requires_grad) #独立于计算图之外
y = 2*x
x.data *= 100 #只改变了值,不会记录在计算图,不影响梯度传播
y.backward()
print(x) #修改data会影响tensor值
print(x.grad)

tensor([1.])
False
tensor([100.], requires_grad=True)
tensor([2.])


推荐阅读
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
  • Redux入门指南
    本文介绍Redux的基本概念和工作原理,帮助初学者理解如何使用Redux管理应用程序的状态。Redux是一个用于JavaScript应用的状态管理库,特别适用于React项目。 ... [详细]
  • 本文详细介绍了如何使用 Yii2 的 GridView 组件在列表页面实现数据的直接编辑功能。通过具体的代码示例和步骤,帮助开发者快速掌握这一实用技巧。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本章将深入探讨移动 UI 设计的核心原则,帮助开发者构建简洁、高效且用户友好的界面。通过学习设计规则和用户体验优化技巧,您将能够创建出既美观又实用的移动应用。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 从 .NET 转 Java 的自学之路:IO 流基础篇
    本文详细介绍了 Java 中的 IO 流,包括字节流和字符流的基本概念及其操作方式。探讨了如何处理不同类型的文件数据,并结合编码机制确保字符数据的正确读写。同时,文中还涵盖了装饰设计模式的应用,以及多种常见的 IO 操作实例。 ... [详细]
  • 基因组浏览器中的Wig格式解析
    本文详细介绍了Wiggle(Wig)格式及其在基因组浏览器中的应用,涵盖variableStep和fixedStep两种主要格式的特点、适用场景及具体使用方法。同时,还提供了关于数据值和自定义参数的补充信息。 ... [详细]
  • 基于KVM的SRIOV直通配置及性能测试
    SRIOV介绍、VF直通配置,以及包转发率性能测试小慢哥的原创文章,欢迎转载目录?1.SRIOV介绍?2.环境说明?3.开启SRIOV?4.生成VF?5.VF ... [详细]
  • 在项目部署后,Node.js 进程可能会遇到不可预见的错误并崩溃。为了及时通知开发人员进行问题排查,我们可以利用 nodemailer 插件来发送邮件提醒。本文将详细介绍如何配置和使用 nodemailer 实现这一功能。 ... [详细]
  • 中科院学位论文排版指南
    随着毕业季的到来,许多即将毕业的学生开始撰写学位论文。本文介绍了使用LaTeX排版学位论文的方法,特别是针对中国科学院大学研究生学位论文撰写规范指导意见的最新要求。LaTeX以其精确的控制和美观的排版效果成为许多学者的首选。 ... [详细]
  • 本文介绍了如何在多线程环境中实现异步任务的事务控制,确保任务执行的一致性和可靠性。通过使用计数器和异常标记字段,系统能够准确判断所有异步线程的执行结果,并根据结果决定是否回滚或提交事务。 ... [详细]
  • 本文介绍了如何使用JavaScript的Fetch API与Express服务器进行交互,涵盖了GET、POST、PUT和DELETE请求的实现,并展示了如何处理JSON响应。 ... [详细]
author-avatar
周周微商互联
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有