热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

神经网络打印模型参数及参数名字和数量

神经网络打印模型参数及参数名字和数量在设计和优化神经网络模型性

神经网络打印模型参数及参数名字和数量

在设计和优化神经网络模型性能时,很多时候需要考虑模型的参数量和计算复杂度,下面一个栗子可以帮助我们快速查看模型的参数。
** 举个栗子,如有错误,欢迎大家批评指正 **
本文链接:神经网络打印模型参数及参数名字和数量
https://blog.csdn.net/leiduifan6944/article/details/103690228

exp:

import torch
from torch import nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3*4*4, 3*5*5)
self.conv1 = nn.Sequential(
nn.Conv2d(3, 4, 1, 1), # conv1.0
nn.BatchNorm2d(4), # conv1.1
nn.LeakyReLU(), # conv1.2
nn.Conv2d(4, 4, 3, 1), # conv1.3
nn.BatchNorm2d(4), # conv1.4
nn.LeakyReLU(), # conv1.5
)
self.fc2 = nn.Linear(4*3*3, 10)
def forward(self, entry):
entry = entry.reshape(-1, 3*4*4)
fc1_out = self.fc1(entry)
fc1_out = fc1_out.reshape(-1, 3, 5, 5)
conv1_out = self.conv1(fc1_out)
conv1_out = conv1_out.reshape(-1, 4*3*3)
fc2_out = self.fc2(conv1_out)
return fc2_out
if __name__ == '__main__':
x = torch.Tensor(2, 3, 4, 4)
net = Net()
out = net(x)
print('%14s : %s' % ('out.shape', out.shape))
print('---------------华丽丽的分隔线---------------')
# -------------方法1--------------
sum_ = 0
for name, param in net.named_parameters():
mul = 1
for size_ in param.shape:
mul *= size_ # 统计每层参数个数
sum_ += mul # 累加每层参数个数
print('%14s : %s' % (name, param.shape)) # 打印参数名和参数数量
# print('%s' % param) # 这样可以打印出参数,由于过多,我就不打印了
print('参数个数:', sum_) # 打印参数量

# -------------方法2--------------
for param in net.parameters():
print(param.shape)
# print(param)
# -------------方法3--------------
params = list(net.parameters())
for param in params:
print(param.shape)
# print(param)

以下是方法1的输出效果:

(方法2和方法3没贴出效果,个人比较喜欢用方法1,因为可以看到当前打印的是哪一层网络的参数)

out.shape : torch.Size([2, 10])
---------------华丽丽的分隔线---------------
fc1.weight : torch.Size([75, 48])
fc1.bias : torch.Size([75])
conv1.0.weight : torch.Size([4, 3, 1, 1])
conv1.0.bias : torch.Size([4])
conv1.1.weight : torch.Size([4])
conv1.1.bias : torch.Size([4])
conv1.3.weight : torch.Size([4, 4, 3, 3])
conv1.3.bias : torch.Size([4])
conv1.4.weight : torch.Size([4])
conv1.4.bias : torch.Size([4])
fc2.weight : torch.Size([10, 36])
fc2.bias : torch.Size([10])
参数个数: 4225

推荐阅读
  • 深入浅出TensorFlow数据读写机制
    本文详细介绍TensorFlow中的数据读写操作,包括TFRecord文件的创建与读取,以及数据集(dataset)的相关概念和使用方法。 ... [详细]
  • golang常用库:配置文件解析库/管理工具viper使用
    golang常用库:配置文件解析库管理工具-viper使用-一、viper简介viper配置管理解析库,是由大神SteveFrancia开发,他在google领导着golang的 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • Java 中 Writer flush()方法,示例 ... [详细]
  • Java 中的 BigDecimal pow()方法,示例 ... [详细]
  • 1:有如下一段程序:packagea.b.c;publicclassTest{privatestaticinti0;publicintgetNext(){return ... [详细]
  • 本文详细介绍了如何在Linux系统上安装和配置Smokeping,以实现对网络链路质量的实时监控。通过详细的步骤和必要的依赖包安装,确保用户能够顺利完成部署并优化其网络性能监控。 ... [详细]
  • 本文介绍了Java并发库中的阻塞队列(BlockingQueue)及其典型应用场景。通过具体实例,展示了如何利用LinkedBlockingQueue实现线程间高效、安全的数据传递,并结合线程池和原子类优化性能。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 主要用了2个类来实现的,话不多说,直接看运行结果,然后在奉上源代码1.Index.javaimportjava.awt.Color;im ... [详细]
  • 本文详细介绍了 Dockerfile 的编写方法及其在网络配置中的应用,涵盖基础指令、镜像构建与发布流程,并深入探讨了 Docker 的默认网络、容器互联及自定义网络的实现。 ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 将Web服务部署到Tomcat
    本文介绍了如何在JDeveloper 12c中创建一个Java项目,并将其打包为Web服务,然后部署到Tomcat服务器。内容涵盖从项目创建、编写Web服务代码、配置相关XML文件到最终的本地部署和验证。 ... [详细]
  • 本文介绍了如何在C#中启动一个应用程序,并通过枚举窗口来获取其主窗口句柄。当使用Process类启动程序时,我们通常只能获得进程的句柄,而主窗口句柄可能为0。因此,我们需要使用API函数和回调机制来准确获取主窗口句柄。 ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
author-avatar
nzl
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有