作者:书友73892718 | 来源:互联网 | 2023-08-08 21:20
LeNet5网络和CIFAR10数据集 main函数--dataloader--train--test 1 import torch 2 from torch.utils.data import D…
LeNet5网络和CIFAR10数据集
main函数:dataloader → train → test
1 import torch
2 from torch.utils.data import DataLoader
3 from torchvision import datasets
4 from torchvision import transforms
5 from torch import nn,optim
6 from lenet5 import LeNet5
7
def main(epochs=1000, batch_size=32):
    """Train LeNet5 on CIFAR-10 and print test accuracy after every epoch.

    The dataset is downloaded into the ``cifar`` directory on first use.

    Args:
        epochs: Number of passes over the training set (default 1000).
        batch_size: Number of images per mini-batch for both loaders.
    """
    # One transform for both splits; Resize/ToTensor hold no per-split state.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transform, download=True)
    # DataLoader groups samples into batches so many images are processed at once.
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transform, download=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

    # Once loading works, sanity-check the shapes of a single batch.
    # (Python 3 iterators have no .next() method; use the next() builtin.)
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = LeNet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print(model)

    for epoch in range(epochs):

        model.train()
        for x, label in cifar_train:
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            # logits: [b, 10]
            logits = model(x)
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last mini-batch seen in this epoch.
        print(epoch, loss.item())

        model.eval()
        # No gradient bookkeeping is needed while evaluating.
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                # logits: [b, 10]
                logits = model(x)
                pred = logits.argmax(dim=1)
                # Accumulate the number of correct predictions batch by batch.
                total_correct += (pred == label).sum().item()
                # x.size(0) is the real batch size; the last batch may be smaller.
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, acc)
# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
LeNet5网络(lenet5.py)--用tmp张量测试输出维度
1 import torch
2 from torch import nn
3 from torch.nn import functional as F
4
class LeNet5(nn.Module):
    """LeNet-5-style convolutional network for the CIFAR-10 dataset.

    Takes images of shape [b, 3, 32, 32] and returns raw class logits of
    shape [b, 10]. No softmax is applied: the model is meant to be paired
    with ``nn.CrossEntropyLoss``, which applies log-softmax internally. The
    loss therefore lives outside the module, and ``forward`` needs no label.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. Shapes for a [b, 3, 32, 32] input:
        #   Conv2d(3 -> 6, k=5, s=1, p=0)   -> [b, 6, 28, 28]
        #   AvgPool2d(k=2, s=2)             -> [b, 6, 14, 14]
        #   Conv2d(6 -> 16, k=5, s=1, p=0)  -> [b, 16, 10, 10]
        #   AvgPool2d(k=2, s=2)             -> [b, 16, 5, 5]
        # Layer order is unchanged so state_dict keys (conv_unit.0.*,
        # conv_unit.2.*) stay compatible with existing checkpoints.
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )

        # Classifier head over the flattened conv features. The input size
        # 16*5*5 is the conv_unit output shape listed above. The ReLUs add
        # non-linearity between the linear layers; the author's note says
        # this counters vanishing gradients in the fully connected stack.
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Map a batch of images [b, 3, 32, 32] to class logits [b, 10]."""
        # [b, 3, 32, 32] -> [b, 16, 5, 5]
        x = self.conv_unit(x)
        # [b, 16, 5, 5] -> [b, 16*5*5]. Unlike .view(), flatten also works
        # when the conv output is non-contiguous (e.g. channels_last input).
        x = torch.flatten(x, start_dim=1)
        # [b, 16*5*5] -> [b, 10]
        return self.fc_unit(x)
def main():
    """Smoke-test LeNet5: feed a random batch of two images and print the output shape."""
    lenet = LeNet5()
    logits = lenet(torch.randn(2, 3, 32, 32))
    print('lenet_out:', logits.shape)
# Run the shape smoke test only when this module is executed directly.
if __name__ == '__main__':
    main()