热门标签 | HotTags
当前位置:  开发笔记 > 运维 > 正文

pytorch查看模型weight与grad方式

这篇文章主要介绍了pytorch查看模型weight与grad方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features – C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "", line 1, in 
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> NOnemodel= models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


推荐阅读
  • 本文介绍如何在阿里云环境中利用 Docker 容器化技术部署一个简单的 Flask Web 应用,并确保其可通过互联网访问。内容涵盖 Python 代码编写、Dockerfile 配置、镜像构建及容器运行等步骤。 ... [详细]
  • 本文介绍了一种方法,通过使用Python的ctypes库来调用C++代码。具体实例为实现一个简单的加法器,并详细说明了从编写C++代码到编译及最终在Python中调用的全过程。 ... [详细]
  • 2017年软件开发领域的七大变革
    随着技术的不断进步,2017年对软件开发人员而言将充满挑战与机遇。本文探讨了开发人员需要适应的七个关键变化,包括人工智能、聊天机器人、容器技术、应用程序版本控制、云测试环境、大众开发者崛起以及系统管理的云迁移。 ... [详细]
  • 本文详细介绍如何在华为鲲鹏平台上构建和使用适配ARM架构的Redis Docker镜像,解决常见错误并提供优化建议。 ... [详细]
  • Flutter 核心技术与混合开发模式深入解析
    本文深入探讨了 Flutter 的核心技术,特别是其混合开发模式,包括统一管理模式和三端分离模式,以及混合栈原理。通过对比不同模式的优缺点,帮助开发者选择最适合项目的混合开发策略。 ... [详细]
  • 使用Echarts for Weixin 小程序实现中国地图及区域点击事件
    本文介绍了如何使用Echarts for Weixin在微信小程序中构建中国地图,并实现区域点击事件。包括效果展示、条件准备和逻辑实现的具体步骤。 ... [详细]
  • 将字符串中的嵌套列表转换回嵌套列表 ... [详细]
  • 本文将探讨如何在 Struts2 中使用 ActionContext 和 ServletActionContext 来获取请求参数和会话信息,同时解释它们的内部机制和最佳实践。 ... [详细]
  • RTThread线程间通信
    线程中通信在裸机编程中,经常会使用全局变量进行功能间的通信,如某些功能可能由于一些操作而改变全局变量的值,另一个功能对此全局变量进行读取& ... [详细]
  • 本文介绍了存储器的基本原理及其分类,包括不同类型的存储介质和存储方式,并详细解释了各种存储器的特点和应用场景。 ... [详细]
  • 本文总结了近年来在实际项目中使用消息中间件的经验和常见问题,旨在为Java初学者和中级开发者提供实用的参考。文章详细介绍了消息中间件在分布式系统中的作用,以及如何通过消息中间件实现高可用性和可扩展性。 ... [详细]
  • DirectShow Filter 开发指南
    本文总结了 DirectShow Filter 的开发经验,重点介绍了 Source Filter、In-Place Transform Filter 和 Render Filter 的实现方法。通过使用 DirectShow 提供的类,可以简化 Filter 的开发过程。 ... [详细]
  • 如何在DedeCMS专题页节点文档中调用自定义模型字段?
    在完成DedeCMS专题页节点文章列表样式的修改后,如果需要在列表中显示自定义模型的字段,由于DedeCMS默认不支持这一功能,因此需要进行一些二次开发。本文将详细介绍如何通过修改模板文件和核心文件来实现这一需求。 ... [详细]
  • 本文详细介绍了CSS中元素的显示模式,包括块元素、行内元素和行内块元素的特性和应用场景。 ... [详细]
  • 本文介绍了如何将Spring属性占位符与Jersey的@Path和@ApplicationPath注解结合使用,以便在资源路径中动态解析属性值。 ... [详细]
author-avatar
天涯使者2602921991
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有