Keras 是一个用 Python 编写的开源人工神经网络库,它提供了简洁的 API 来展示模型的结构,这在调试网络时非常有用。torchsummary 是在 PyTorch 中模仿这一功能的一段代码,即 summary() 函数,目标是提供 print(your_model) 未给出的信息(各层的输出形状、参数量等),以补充其不足。
1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 from torchsummary import summary
5
class Net(nn.Module):
    """Small convolutional network used to demonstrate ``summary()``.

    Two 5x5 convolutions, each followed by 2x2 max pooling and ReLU, feed two
    fully connected layers.  ``forward`` flattens to 320 features, which
    matches a 1x28x28 input: 28 -> 24 -> 12 -> 8 -> 4, and 20 * 4 * 4 = 320.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores of shape ``(batch, 10)``.

        Args:
            x: Input tensor with one channel; the hard-coded flatten to 320
                features assumes 28x28 spatial size.

        Returns:
            Tensor of log-probabilities computed along dimension 1.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten the 20 feature maps into one 320-wide vector per sample.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly; follow the
        # module's train/eval flag so it is disabled in eval mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
23
# Use the GPU when one is available, otherwise fall back to the CPU.
# (Original note: PyTorch v0.4.0.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)

# Print the per-layer summary for a single-channel 28x28 input.
summary(model, (1, 28, 28))
28
29 >>>>>:
30 ----------------------------------------------------------------
31 Layer (type) Output Shape Param #
32 ================================================================
33 Conv2d-1 [-1, 10, 24, 24] 260
34 Conv2d-2 [-1, 20, 8, 8] 5,020
35 Dropout2d-3 [-1, 20, 8, 8] 0
36 Linear-4 [-1, 50] 16,050
37 Linear-5 [-1, 10] 510
38 ================================================================
39 Total params: 21,840
40 Trainable params: 21,840
41 Non-trainable params: 0
42 ----------------------------------------------------------------
43 Input size (MB): 0.00
44 Forward/backward pass size (MB): 0.06
45 Params size (MB): 0.08
46 Estimated Total Size (MB): 0.15
47 ----------------------------------------------------------------
1 import torch
2 from torchvision import models
3 from torchsummary import summary
4
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = models.vgg16().to(device)

# Print the per-layer summary for a 3-channel 224x224 input.
summary(vgg, (3, 224, 224))
9
10 >>>>>:
11 ----------------------------------------------------------------
12 Layer (type) Output Shape Param #
13 ================================================================
14 Conv2d-1 [-1, 64, 224, 224] 1,792
15 ReLU-2 [-1, 64, 224, 224] 0
16 Conv2d-3 [-1, 64, 224, 224] 36,928
17 ReLU-4 [-1, 64, 224, 224] 0
18 MaxPool2d-5 [-1, 64, 112, 112] 0
19 Conv2d-6 [-1, 128, 112, 112] 73,856
20 ReLU-7 [-1, 128, 112, 112] 0
21 Conv2d-8 [-1, 128, 112, 112] 147,584
22 ReLU-9 [-1, 128, 112, 112] 0
23 MaxPool2d-10 [-1, 128, 56, 56] 0
24 Conv2d-11 [-1, 256, 56, 56] 295,168
25 ReLU-12 [-1, 256, 56, 56] 0
26 Conv2d-13 [-1, 256, 56, 56] 590,080
27 ReLU-14 [-1, 256, 56, 56] 0
28 Conv2d-15 [-1, 256, 56, 56] 590,080
29 ReLU-16 [-1, 256, 56, 56] 0
30 MaxPool2d-17 [-1, 256, 28, 28] 0
31 Conv2d-18 [-1, 512, 28, 28] 1,180,160
32 ReLU-19 [-1, 512, 28, 28] 0
33 Conv2d-20 [-1, 512, 28, 28] 2,359,808
34 ReLU-21 [-1, 512, 28, 28] 0
35 Conv2d-22 [-1, 512, 28, 28] 2,359,808
36 ReLU-23 [-1, 512, 28, 28] 0
37 MaxPool2d-24 [-1, 512, 14, 14] 0
38 Conv2d-25 [-1, 512, 14, 14] 2,359,808
39 ReLU-26 [-1, 512, 14, 14] 0
40 Conv2d-27 [-1, 512, 14, 14] 2,359,808
41 ReLU-28 [-1, 512, 14, 14] 0
42 Conv2d-29 [-1, 512, 14, 14] 2,359,808
43 ReLU-30 [-1, 512, 14, 14] 0
44 MaxPool2d-31 [-1, 512, 7, 7] 0
45 Linear-32 [-1, 4096] 102,764,544
46 ReLU-33 [-1, 4096] 0
47 Dropout-34 [-1, 4096] 0
48 Linear-35 [-1, 4096] 16,781,312
49 ReLU-36 [-1, 4096] 0
50 Dropout-37 [-1, 4096] 0
51 Linear-38 [-1, 1000] 4,097,000
52 ================================================================
53 Total params: 138,357,544
54 Trainable params: 138,357,544
55 Non-trainable params: 0
56 ----------------------------------------------------------------
57 Input size (MB): 0.57
58 Forward/backward pass size (MB): 218.59
59 Params size (MB): 527.79
60 Estimated Total Size (MB): 746.96
61 ----------------------------------------------------------------