热门标签 | HotTags
当前位置:  开发笔记 > 前端 > 正文

pytorch获取非叶子节点的grad

参考url: https:mathpretty.com12509.html在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯

参考url: https://mathpretty.com/12509.html

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:



  • retain_grad()

  • hook


retain_grad()

retain_grad()显式地保存非叶节点的梯度, 当然代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)

def forwrad(x, y, w1, w2):
# 其中 x,y 为输入数据,w为该函数所需要的参数
z_1 = torch.mm(w1, x)
z_1.retain_grad()
y_1
= torch.sigmoid(z_1)
y_1.retain_grad()
z_2
= torch.mm(w2, y_1)
z_2.retain_grad()
y_2
= torch.sigmoid(z_2)
y_2.retain_grad()
loss
= 1/2*(((y_2 - y)**2).sum())
loss.retain_grad()
return loss, z_1, y_1, z_2, y_2
# 测试代码
x = torch.tensor([[1.0]])
y
= torch.tensor([[1.0], [0.0]])
w1
= torch.tensor([[1.0], [2.0]], requires_grad=True)
w2
= torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# w2 = torch.tensor([[3.0, 1.0], [1.0, 6.0]], requires_grad=True)
#
正向
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# 反向
loss.backward() # 反向传播,计算梯度
print(loss.grad)
print(y_2.grad)
print(z_2.grad)

hook的使用

使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.

# 我们可以定义一个hook来保存中间的变量
grads = {} # 存储节点名称与节点的grad
def save_grad(name):
def hook(grad):
grads[name]
= grad
return hook
def forwrad(x, y, w1, w2):
# 其中 x,y 为输入数据,w为该函数所需要的参数
z_1 = torch.mm(w1, x)
y_1
= torch.sigmoid(z_1)
z_2
= torch.mm(w2, y_1)
y_2
= torch.sigmoid(z_2)
loss
= 1/2*(((y_2 - y)**2).sum())
return loss, z_1, y_1, z_2, y_2
# 测试代码
x = torch.tensor([[1.0]])
y
= torch.tensor([[1.0], [0.0]])
w1
= torch.tensor([[1.0], [2.0]], requires_grad=True)
w2
= torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向传播
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# hook中间节点
z_1.register_hook(save_grad('z_1'))
y_1.register_hook(save_grad(
'y_1'))
z_2.register_hook(save_grad(
'z_2'))
y_2.register_hook(save_grad(
'y_2'))
loss.register_hook(save_grad(
'loss'))
# 反向传播
loss.backward()
print(grads['z_1'])
print(grads['y_1'])
print(grads['z_2'])
print(grads['y_2'])
print(grads['loss'])

 



推荐阅读
  • 实验六团队作业2:团队项目选题实验时间2019-4-18(19)Deadline:2019-4-2410:00,以团队随笔博文提交至班级博 ... [详细]
  • 腾讯投资十勇士,后者致力于精品游戏研发
    企查查APP显示,近日,浙江十勇士网络科技有限公司发生工商变更,新增股东广西腾讯创业投资有限公司,同时公司注册资本由1319.44万元人民币增加至1552.28万元人民币。企查查信 ... [详细]
  • 普通调用https:www.cnblogs.comYogurshinep3913073.htmlhttps:zhidao.baidu.comquestion531286375.h ... [详细]
  • 请写出一下程序的输出内容***2018032122:02:03**Brief:**Author:ZhangJianWei**Email:Dream_Dog163.com* ... [详细]
  • 原文链接https://www.cnblogs.com/zhouzhendong/p/9161514.html ... [详细]
  • WebBrowser控件(1)
    WindowsPhone7内置了一个强大的网络浏览器,该浏览器的内核是基于桌面版的InternetExplorer7(Mango版基于InternetE ... [详细]
  • 1.安装brewinstallnginx(需要安装homebrew)2.执行nginx直接启动nginx服务3.nginx-sreloadstop4.配 ... [详细]
  • 存一下吧……以后有用……http:blog.csdn.netclove_uniquearticledetails50630280转载于:https:www.cnblogs.comk ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 本文介绍了腾讯最近开源的BERT推理模型TurboTransformers,该模型在推理速度上比PyTorch快1~4倍。TurboTransformers采用了分层设计的思想,通过简化问题和加速开发,实现了快速推理能力。同时,文章还探讨了PyTorch在中间层延迟和深度神经网络中存在的问题,并提出了合并计算的解决方案。 ... [详细]
  • mapreduce数据去重的实现方法
    本文介绍了利用mapreduce实现数据去重的方法,同时还介绍了人工智能AI领域中常用的框架和工具,包括Keras、PyTorch、MXNet、TensorFlow和PaddlePaddle,并提供了深度学习实战的代码下载链接。 ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • S3D算法详解
    S3D论文详解论文地址:RethinkingSpatiotemporalFeatureLearning:Speed-AccuracyTrade-offsinVide ... [详细]
  • navicat生成er图_实践案例丨ACL2020 KBQA 基于查询图生成回答多跳复杂问题
    摘要:目前复杂问题包括两种:含约束的问题和多跳关系问题。本文对ACL2020KBQA基于查询图生成的方法来回答多跳复杂问题这一论文工作进行了解读 ... [详细]
author-avatar
mobiledu2502903113
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有