作者:可心 | 来源:互联网 | 2023-09-12 11:31
先放上完整的训练测试代码：# -*- coding: utf-8 -*- Created on Fri Aug 7 15:10:16 2020 @author: wj import torch fro…
先放上完整的训练测试代码：
"""
Created on Fri Aug 7 15:10:16 2020
@author: wj
"""
import torch
from torch import nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# ---- Hyper-parameters ----
# Each MNIST image is fed to the LSTM as a sequence of rows; INPUT_SIZE is the
# number of features per row (i.e. per time step).
# TIME_STEP (rows per image) is kept for reference; the visible code never reads it.
num_epoches = 10  # number of passes over the training set
BATCH_SIZE = 128
TIME_STEP = 28
INPUT_SIZE = 28
LR = 0.01  # learning rate handed to the Adam optimizer

# ---- Data ----
# ToTensor converts each image to a float tensor scaled to [0, 1].
# download=True on the training split fetches MNIST into ./data if missing.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

# Shuffle only for training; evaluation walks the test set in a fixed order.
train_loader, test_loader = (
    DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False),
)
class RNN(nn.Module):
    """Single-layer LSTM sequence classifier.

    In this script each MNIST image is fed as a sequence of rows. The input
    ``x`` is batch-first, ``(batch, seq_len, input_size)``. The LSTM output at
    the last time step goes through a linear layer that yields
    ``num_classes`` unnormalized scores (logits).

    Args:
        input_size: Features per time step. ``None`` (the default) uses the
            module-level ``INPUT_SIZE`` constant, as the original class did.
        hidden_size: Number of LSTM hidden units (default 64).
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, input_size=None, hidden_size=64, num_classes=10):
        super().__init__()
        if input_size is None:
            # Resolved at construction time so edits to INPUT_SIZE are honoured.
            input_size = INPUT_SIZE
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the batch ``x``."""
        # No initial (h_0, c_0) is passed, so the LSTM starts from zero states.
        r_out, _ = self.rnn(x)
        # With batch_first=True, r_out is (batch, seq_len, hidden_size);
        # classify from the output at the final time step only.
        return self.out(r_out[:, -1, :])
# Place the model on the GPU when one is available, otherwise on the CPU,
# so the script also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rnn = RNN().to(device)
print(rnn)
# Build the optimizer only after the model has been moved to its device.
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
# Takes raw logits and integer class labels; default reduction averages over the batch.
criterion = nn.CrossEntropyLoss()
# Train for num_epoches epochs. After each epoch, report the mean training
# loss/accuracy, then the loss/accuracy on the test set.
# Batches are sent to whatever device the model's parameters live on, so the
# loop works whether the model was placed on the GPU or the CPU.
device = next(rnn.parameters()).device
for epoch in range(num_epoches):
    print('epoch {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0.0
    running_acc = 0.0
    rnn.train()
    for imgs, labels in train_loader:
        # Drop the singleton channel dim so each image row is one LSTM time step.
        imgs = imgs.squeeze(1).to(device)
        labels = labels.to(device)
        out = rnn(imgs)
        loss = criterion(out, labels)
        # The criterion returns a batch mean; weight by batch size so the
        # epoch total can be divided by the dataset size below.
        running_loss += loss.item() * labels.size(0)
        _, pred = torch.max(out, 1)
        running_acc += (pred == labels).sum().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, running_loss / len(train_dataset), running_acc / len(train_dataset)))

    rnn.eval()
    eval_loss = 0.0
    eval_acc = 0.0
    # Evaluation needs no gradients: disabling autograd avoids building a
    # graph for every test batch and keeps memory usage down.
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs = imgs.squeeze(1).to(device)
            labels = labels.to(device)
            out = rnn(imgs)
            loss = criterion(out, labels)
            eval_loss += loss.item() * labels.size(0)
            _, pred = torch.max(out, 1)
            eval_acc += (pred == labels).sum().item()
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
        eval_loss / len(test_dataset), eval_acc / len(test_dataset)))
最后放上训练结果：
本文地址:https://blog.csdn.net/weixin_45738220/article/details/107881269