热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch学习自动求梯度

2.3自动求梯度对函数计算梯度(gradient),Pytoch提供的autograd包能够根据输入和前向传播过程自动构建计算图,并执行反向传播。2.3.1概念Tensor是aut

2.3 自动求梯度

对函数计算梯度(gradient),Pytoch提供的autograd包能够根据输入和前向传播过程自动构建计算图,并执行反向传播。


2.3.1 概念

Tensor是autograd包的核心类

如果将其属性.requires_grad设置为True,它将开始追踪在其上的所有操作(这样就可以利用链式法则进行梯度传播)。

完成计算后,可以调用.backward()来完成所有梯度计算。此Tensor的梯度将累积到.grad属性中。

注:在y.backward()时,如果y是标量,则不需要为backward()传入任何参数;否则,需要传入一个与y同形的Tensor

如果不想继续追踪,调用.detach()将追踪记录分离出来(防止将来的计算被追踪,这样梯度就传不过去)。也可以使用with torch.no_grad()将不想被追踪的操作代码块包裹起来。(在评估模型时经常使用,因为在评估模型时,我们不需要计算可训练参数(requires_grad=True)的梯度)

Function是另外一个很重要的类。TensorFunction互相结合可以构建一个记录有整个计算过程的有向无环图(DAG)。每个Tensor都有一个.grad_fn属性,该属性即创建该TensorFunction。就是说该Tensor若是通过运算得到的,则grad_fn返回一个与这些运算相关的对象,否则是None。


2.3.2 Tensor

创建一个Tensor并设置requires_grad = True:

import torch

x = torch.ones(2,2,requires_grad = True)
print(x)
print(x.grad_fn)

tensor([[1., 1.],
[1., 1.]], requires_grad=True)
None

y = x + 2
print(y)
print(y.grad_fn)

tensor([[3., 3.],
[3., 3.]], grad_fn=)

注意x是直接创建的,所以返回的是None,而y是x进行了加法操作创建的,所以它有一个为的运算对象

像x这种直接创建的称为叶子节点,叶子节点对应的grad_fnNone

print(x.is_leaf,y.is_leaf)

True False

运算操作复杂化

z = y*y*3
out = z.mean()
print(z,out)

tensor([[27., 27.],
[27., 27.]], grad_fn=) tensor(27., grad_fn=)

通过.requires_grad_()来用in-place的方式改变requires_grad属性:

a = torch.randn(2,2) #缺失情况下默认 requires_grad = False
a = ((a*3)/(a-1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a*a).sum()
print(b.grad_fn)

False
True


2.3.2 梯度

因为out是一个标量,所以调用backward()时不需要指定求导变量:

out.backward() #等价于 out.backward(torch.tensor(1.))

out关于x的梯度 $\frac{d(out)}{dx}$

print(x.grad)

tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])

#再来反向传播一次,注意grad是累加的
out2 = x.sum()
out2.backward()
print(x.grad)
out3 = x.sum()
x.grad.data.zero_()
out3.backward()
print(x.grad)

tensor([[5.5000, 5.5000],
[5.5000, 5.5000]])
tensor([[1., 1.],
[1., 1.]])

举例:

x = torch.tensor([1.0,2.0,3.0,4.0],requires_grad=True)
y = 2*x
z = y.view(2,2)
print(z)

tensor([[2., 4.],
[6., 8.]], grad_fn=)

#现在 y 不是一个标量,所以在调用 backward 时需要传入一个和 y 同形的权重向量进行加权求和得到一个标量
v = torch.tensor([[1.0,0.1],[0.01,0.001]],dtype = torch.float)
z.backward(v)
print(x.grad)

tensor([2.0000, 0.2000, 0.0200, 0.0020])

注:x.grad是和x同形的张量

中断梯度追踪的例子:

x = torch.tensor(1.0,requires_grad = True)
y1 = x ** 2
with torch.no_grad():
y2 = x ** 3
y3 = y1+y2
print(x.requires_grad)
print(y1,y1.requires_grad)
print(y2,y2.requires_grad)
print(y3,y3.requires_grad)

True
tensor(1., grad_fn=

) True
tensor(1.) False
tensor(2., grad_fn=) True

y2是没有grad_fn而且y2.requires_grad = False,而y3有,接下来对y3求梯度

y3.backward()
print(x.grad)

tensor(2.)

想要修改tensor的数值又不希望被autograd记录(不影响反向传播),可以对tensor.data操作

x = torch.ones(1,requires_grad = True)
print(x.data) #还是一个tensor
print(x.data.requires_grad) #独立于计算图之外
y = 2*x
x.data *= 100 #只改变了值,不会记录在计算图,不影响梯度传播
y.backward()
print(x) #修改data会影响tensor值
print(x.grad)

tensor([1.])
False
tensor([100.], requires_grad=True)
tensor([2.])


推荐阅读
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • Exploring issues and solutions when defining multiple Faust agents programmatically. ... [详细]
  • 在AngularJS中,有时需要在表单内包含某些控件,但又不希望这些控件导致表单变为脏状态。例如,当用户对表单进行修改后,表单的$dirty属性将变为true,触发保存对话框。然而,对于一些导航或辅助功能控件,我们可能并不希望它们触发这种行为。 ... [详细]
  • iOS如何实现手势
    这篇文章主要为大家展示了“iOS如何实现手势”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“iOS ... [详细]
  • 本文旨在探讨Swift中的Closure与Objective-C中的Block之间的区别与联系,通过定义、使用方式以及外部变量捕获等方面的比较,帮助开发者更好地理解这两种机制的特点及应用场景。 ... [详细]
  • 长期从事ABAP开发工作的专业人士,在面对行业新趋势时,往往需要重新审视自己的发展方向。本文探讨了几位资深专家对ABAP未来走向的看法,以及开发者应如何调整技能以适应新的技术环境。 ... [详细]
  • 小编给大家分享一下Vue3中如何提高开发效率,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获, ... [详细]
  • 使用 Babylon.js 实现地球模型与切片地图交互(第三部分)
    本文继续探讨在上一章节中构建的地球模型基础上,如何通过自定义的 `CameraEarthWheelControl` 类来实现更精细的地图缩放控制。我们将深入解析该类的实现细节,并展示其在实际项目中的应用。 ... [详细]
  • 探索CNN的可视化技术
    神经网络的可视化在理论学习与实践应用中扮演着至关重要的角色。本文深入探讨了三种有效的CNN(卷积神经网络)可视化方法,旨在帮助读者更好地理解和优化模型。 ... [详细]
  • egg实现登录鉴权(七):权限管理
    权限管理包含三部分:访问页面的权限,操作功能的权限和获取数据权限。页面权限:登录用户所属角色的可访问页面的权限功能权限:登录用户所属角色的可访问页面的操作权限数据权限:登录用户所属 ... [详细]
  • 本文探讨了异步编程的发展历程,从最初的AJAX异步回调到现代的Promise、Generator+Co以及Async/Await等技术。文章详细分析了Promise的工作原理及其源码实现,帮助开发者更好地理解和使用这一重要工具。 ... [详细]
  • 本文详细介绍如何在 Apache 中设置虚拟主机,包括基本配置和高级设置,帮助用户更好地理解和使用虚拟主机功能。 ... [详细]
  • 函子(Functor)是函数式编程中的一个重要概念,它不仅是一个特殊的容器,还提供了一种优雅的方式来处理值和函数。本文将详细介绍函子的基本概念及其在函数式编程中的应用,包括如何通过函子控制副作用、处理异常以及进行异步操作。 ... [详细]
  • Logging all MySQL queries into the Slow Log
    MySQLoptionallylogsslowqueriesintotheSlowQueryLog–orjustSlowLog,asfriendscallit.However,Thereareseveralreasonstologallqueries.Thislistisnotexhaustive:Belowyoucanfindthevariablestochange,astheyshouldbewritteninth ... [详细]
  • 在Qt框架中,信号与槽机制是一种独特的组件间通信方式。本文探讨了这一机制相较于传统的C风格回调函数所具有的优势,并分析了其潜在的不足之处。 ... [详细]
author-avatar
周周微商互联
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有