热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用torchsummary打印torch每层形状

Keras是一个由Python编写的开源人工神经网络库,Keras包含一个简洁的API接口来呈现出你的模型的样子,这在debug过程中是非常有用的。这里有一段模仿pytorch的代

Keras是一个由Python编写的开源人工神经网络库,Keras包含一个简洁的API接口来呈现出你的模型的样子,这在debug过程中是非常有用的。这里有一段模仿pytorch的代码,It Is summary(), 目标就是提供完备的信息以补充 print(your_model) 的不足。

作者:sksq96

git地址:https://github.com/sksq96/pytorch-summary

 

安装:

pip install torchsummary

或者

git clone https://github.com/sksq96/pytorch-summary

使用范例:

from torchsummary import summary
summary(your_model, input_size
=(channels, H, W))

注意,input_size是建立一个前向传播的网络

CNN for MNIST 

1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 from torchsummary import summary
5
6 class Net(nn.Module):
7 def __init__(self):
8 super(Net, self).__init__()
9 self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
10 self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
11 self.conv2_drop = nn.Dropout2d()
12 self.fc1 = nn.Linear(320, 50)
13 self.fc2 = nn.Linear(50, 10)
14
15 def forward(self, x):
16 x = F.relu(F.max_pool2d(self.conv1(x), 2))
17 x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
18 x = x.view(-1, 320)
19 x = F.relu(self.fc1(x))
20 x = F.dropout(x, training=self.training)
21 x = self.fc2(x)
22 return F.log_softmax(x, dim=1)
23
24 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
25 model = Net().to(device)
26
27 summary(model, (1, 28, 28))
28
29 >>>>>:
30 ----------------------------------------------------------------
31 Layer (type) Output Shape Param #
32 ================================================================
33 Conv2d-1 [-1, 10, 24, 24] 260
34 Conv2d-2 [-1, 20, 8, 8] 5,020
35 Dropout2d-3 [-1, 20, 8, 8] 0
36 Linear-4 [-1, 50] 16,050
37 Linear-5 [-1, 10] 510
38 ================================================================
39 Total params: 21,840
40 Trainable params: 21,840
41 Non-trainable params: 0
42 ----------------------------------------------------------------
43 Input size (MB): 0.00
44 Forward/backward pass size (MB): 0.06
45 Params size (MB): 0.08
46 Estimated Total Size (MB): 0.15
47 ----------------------------------------------------------------

VGG16

1 import torch
2 from torchvision import models
3 from torchsummary import summary
4
5 device = torch.device(cuda if torch.cuda.is_available() else cpu)
6 vgg = models.vgg16().to(device)
7
8 summary(vgg, (3, 224, 224))
9
10 >>>>>:
11 ----------------------------------------------------------------
12 Layer (type) Output Shape Param #
13 ================================================================
14 Conv2d-1 [-1, 64, 224, 224] 1,792
15 ReLU-2 [-1, 64, 224, 224] 0
16 Conv2d-3 [-1, 64, 224, 224] 36,928
17 ReLU-4 [-1, 64, 224, 224] 0
18 MaxPool2d-5 [-1, 64, 112, 112] 0
19 Conv2d-6 [-1, 128, 112, 112] 73,856
20 ReLU-7 [-1, 128, 112, 112] 0
21 Conv2d-8 [-1, 128, 112, 112] 147,584
22 ReLU-9 [-1, 128, 112, 112] 0
23 MaxPool2d-10 [-1, 128, 56, 56] 0
24 Conv2d-11 [-1, 256, 56, 56] 295,168
25 ReLU-12 [-1, 256, 56, 56] 0
26 Conv2d-13 [-1, 256, 56, 56] 590,080
27 ReLU-14 [-1, 256, 56, 56] 0
28 Conv2d-15 [-1, 256, 56, 56] 590,080
29 ReLU-16 [-1, 256, 56, 56] 0
30 MaxPool2d-17 [-1, 256, 28, 28] 0
31 Conv2d-18 [-1, 512, 28, 28] 1,180,160
32 ReLU-19 [-1, 512, 28, 28] 0
33 Conv2d-20 [-1, 512, 28, 28] 2,359,808
34 ReLU-21 [-1, 512, 28, 28] 0
35 Conv2d-22 [-1, 512, 28, 28] 2,359,808
36 ReLU-23 [-1, 512, 28, 28] 0
37 MaxPool2d-24 [-1, 512, 14, 14] 0
38 Conv2d-25 [-1, 512, 14, 14] 2,359,808
39 ReLU-26 [-1, 512, 14, 14] 0
40 Conv2d-27 [-1, 512, 14, 14] 2,359,808
41 ReLU-28 [-1, 512, 14, 14] 0
42 Conv2d-29 [-1, 512, 14, 14] 2,359,808
43 ReLU-30 [-1, 512, 14, 14] 0
44 MaxPool2d-31 [-1, 512, 7, 7] 0
45 Linear-32 [-1, 4096] 102,764,544
46 ReLU-33 [-1, 4096] 0
47 Dropout-34 [-1, 4096] 0
48 Linear-35 [-1, 4096] 16,781,312
49 ReLU-36 [-1, 4096] 0
50 Dropout-37 [-1, 4096] 0
51 Linear-38 [-1, 1000] 4,097,000
52 ================================================================
53 Total params: 138,357,544
54 Trainable params: 138,357,544
55 Non-trainable params: 0
56 ----------------------------------------------------------------
57 Input size (MB): 0.57
58 Forward/backward pass size (MB): 218.59
59 Params size (MB): 527.79
60 Estimated Total Size (MB): 746.96
61 ----------------------------------------------------------------

 


推荐阅读
  • 本文介绍了lua语言中闭包的特性及其在模式匹配、日期处理、编译和模块化等方面的应用。lua中的闭包是严格遵循词法定界的第一类值,函数可以作为变量自由传递,也可以作为参数传递给其他函数。这些特性使得lua语言具有极大的灵活性,为程序开发带来了便利。 ... [详细]
  • HDU 2372 El Dorado(DP)的最长上升子序列长度求解方法
    本文介绍了解决HDU 2372 El Dorado问题的一种动态规划方法,通过循环k的方式求解最长上升子序列的长度。具体实现过程包括初始化dp数组、读取数列、计算最长上升子序列长度等步骤。 ... [详细]
  • 本文讨论了Alink回归预测的不完善问题,指出目前主要针对Python做案例,对其他语言支持不足。同时介绍了pom.xml文件的基本结构和使用方法,以及Maven的相关知识。最后,对Alink回归预测的未来发展提出了期待。 ... [详细]
  • 本文讨论了如何优化解决hdu 1003 java题目的动态规划方法,通过分析加法规则和最大和的性质,提出了一种优化的思路。具体方法是,当从1加到n为负时,即sum(1,n)sum(n,s),可以继续加法计算。同时,还考虑了两种特殊情况:都是负数的情况和有0的情况。最后,通过使用Scanner类来获取输入数据。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • 1,关于死锁的理解死锁,我们可以简单的理解为是两个线程同时使用同一资源,两个线程又得不到相应的资源而造成永无相互等待的情况。 2,模拟死锁背景介绍:我们创建一个朋友 ... [详细]
  • 《数据结构》学习笔记3——串匹配算法性能评估
    本文主要讨论串匹配算法的性能评估,包括模式匹配、字符种类数量、算法复杂度等内容。通过借助C++中的头文件和库,可以实现对串的匹配操作。其中蛮力算法的复杂度为O(m*n),通过随机取出长度为m的子串作为模式P,在文本T中进行匹配,统计平均复杂度。对于成功和失败的匹配分别进行测试,分析其平均复杂度。详情请参考相关学习资源。 ... [详细]
  • 动态规划算法的基本步骤及最长递增子序列问题详解
    本文详细介绍了动态规划算法的基本步骤,包括划分阶段、选择状态、决策和状态转移方程,并以最长递增子序列问题为例进行了详细解析。动态规划算法的有效性依赖于问题本身所具有的最优子结构性质和子问题重叠性质。通过将子问题的解保存在一个表中,在以后尽可能多地利用这些子问题的解,从而提高算法的效率。 ... [详细]
  • 高质量SQL书写的30条建议
    本文提供了30条关于优化SQL的建议,包括避免使用select *,使用具体字段,以及使用limit 1等。这些建议是基于实际开发经验总结出来的,旨在帮助读者优化SQL查询。 ... [详细]
  • 本文内容为asp.net微信公众平台开发的目录汇总,包括数据库设计、多层架构框架搭建和入口实现、微信消息封装及反射赋值、关注事件、用户记录、回复文本消息、图文消息、服务搭建(接入)、自定义菜单等。同时提供了示例代码和相关的后台管理功能。内容涵盖了多个方面,适合综合运用。 ... [详细]
  • 本文介绍了使用Java实现大数乘法的分治算法,包括输入数据的处理、普通大数乘法的结果和Karatsuba大数乘法的结果。通过改变long类型可以适应不同范围的大数乘法计算。 ... [详细]
  • Mac OS 升级到11.2.2 Eclipse打不开了,报错Failed to create the Java Virtual Machine
    本文介绍了在Mac OS升级到11.2.2版本后,使用Eclipse打开时出现报错Failed to create the Java Virtual Machine的问题,并提供了解决方法。 ... [详细]
  • 本文讲述了作者通过点火测试男友的性格和承受能力,以考验婚姻问题。作者故意不安慰男友并再次点火,观察他的反应。这个行为是善意的玩人,旨在了解男友的性格和避免婚姻问题。 ... [详细]
  • 本文详细介绍了Linux中进程控制块PCBtask_struct结构体的结构和作用,包括进程状态、进程号、待处理信号、进程地址空间、调度标志、锁深度、基本时间片、调度策略以及内存管理信息等方面的内容。阅读本文可以更加深入地了解Linux进程管理的原理和机制。 ... [详细]
  • Java验证码——kaptcha的使用配置及样式
    本文介绍了如何使用kaptcha库来实现Java验证码的配置和样式设置,包括pom.xml的依赖配置和web.xml中servlet的配置。 ... [详细]
author-avatar
b01453901
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有