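Overview
This article provides reference code for anyone who wants to build a simple deep learning image classification network with PyTorch and Fashion MNIST.
We walk through every working part of such a network: loading the data, defining the network, optimizing the weights on a GPU, and evaluating performance.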
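Preparing the Fashion MNIST dataset
Fashion MNIST is a dataset of 70,000 grayscale images of clothing items (28×28 pixels each) in 10 classes, split into 60,000 training images and 10,000 test images.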
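The dataset stores each of the 10 classes as an integer from 0 to 9. As a small convenience, the mapping to readable names can be sketched as follows. The names follow the standard Fashion MNIST label order; `CLASS_NAMES` and `label_name` are hypothetical helpers introduced here for illustration and are not part of the original code.

```python
# Standard Fashion MNIST label order (index -> class name).
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_name(index):
    """Return the class name for an integer label (hypothetical helper)."""
    return CLASS_NAMES[index]

print(label_name(0))
print(label_name(9))
```

A helper like this makes the numeric labels printed during evaluation easier to read.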
1. Check whether a GPU is available
import torch
print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
2. Download and load the Fashion MNIST dataset
import torch
from torchvision import datasets, transforms
# Instructions from here:
# https://www.kaggle.com/ishvindersethi22/fashion-mnist-using-pytorch/data
# Define a transform: convert to a tensor in [0, 1], then apply (x - 0) / 0.5,
# which rescales pixel values to [0, 2]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.], [0.5])])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
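`transforms.Normalize(mean, std)` computes `(x - mean) / std` for each channel. A minimal plain-Python sketch of that arithmetic, using the mean 0 and standard deviation 0.5 passed above, shows how pixel values in [0, 1] end up in [0, 2] and how the mapping is inverted for display. The function names `normalize` and `unnormalize` are illustrative only, not torchvision APIs.

```python
MEAN, STD = 0.0, 0.5  # same values as passed to transforms.Normalize above

def normalize(x, mean=MEAN, std=STD):
    """Per-pixel arithmetic performed by Normalize."""
    return (x - mean) / std

def unnormalize(y, mean=MEAN, std=STD):
    """Inverse mapping, useful before plotting an image."""
    return y * std + mean

for pixel in (0.0, 0.5, 1.0):
    y = normalize(pixel)
    print(pixel, "->", y, "->", unnormalize(y))
```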
3. Display a sample image
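import matplotlib.pyplot as plt

def imshow(img):
    """Display a single-channel image tensor of shape (1, H, W)."""
    img = img / 2 + 0.  # unnormalize: undo Normalize([0.], [0.5])
    img = img.squeeze()  # drop the channel dimension -> (H, W)
    plt.imshow(img, cmap='gray')

# Fetch one batch and show its first image
images, label = next(iter(trainloader))
imshow(images[0, :])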
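Each batch drawn from the loader is a tensor of shape `(64, 1, 28, 28)`, except possibly the last one. Assuming the standard split of 60,000 training and 10,000 test images and the `DataLoader` default of keeping the final partial batch (`drop_last=False`), the number of batches per epoch can be worked out as below. `batch_counts` is a hypothetical helper for illustration.

```python
import math

def batch_counts(num_samples, batch_size):
    """Return (number of batches, size of the last batch) when the
    final partial batch is kept."""
    num_batches = math.ceil(num_samples / batch_size)
    last = num_samples - (num_batches - 1) * batch_size
    return num_batches, last

print(batch_counts(60_000, 64))  # training set: (938, 32)
print(batch_counts(10_000, 64))  # test set: (157, 16)
```

The first number is what `len(trainloader)` returns, which is why dividing a summed loss by it gives a per-batch average.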
4. Define and train the network
from collections import OrderedDict
from torch import optim, nn

hidden_units = [4, 8, 16]
output_units = 10

class Flatten(nn.Module):
    """Flatten (N, C, H, W) to (N, C*H*W)."""
    def forward(self, input):
        return input.view(input.size(0), -1)

# Fully convolutional classifier; the spatial size goes 28 -> 14 -> 7 -> 4 -> 1
model_d = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, hidden_units[0], 3, stride=2, padding=1)),
    ('Relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(hidden_units[0], hidden_units[1], 3, stride=2, padding=1)),
    ('Relu2', nn.ReLU()),
    ('conv3', nn.Conv2d(hidden_units[1], hidden_units[2], 3, stride=2, padding=1)),
    ('Relu3', nn.ReLU()),
    ('conv4', nn.Conv2d(hidden_units[2], output_units, 4, stride=4, padding=0)),
    ('log_softmax', nn.LogSoftmax(dim=1))
]))
model_d.to(device)

optimizer_d = optim.Adam(model_d.parameters(), lr=0.01)
criterion = nn.NLLLoss()  # expects log-probabilities, hence the LogSoftmax layer

epochs = 10
for i in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer_d.zero_grad()
        # Run classification model; the output has shape (N, 10, 1, 1)
        predicted_labels = model_d(images)
        classification_loss = criterion(Flatten()(predicted_labels), labels)
        # Optimize classification weights
        classification_loss.backward()
        optimizer_d.step()
        running_loss += classification_loss.item()
    # Report the average loss per batch for this epoch
    print(f"{i} Training loss: {running_loss/len(trainloader)}")
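The network has no fully connected layer: the strided convolutions shrink the 28×28 input until a single 1×1 position with 10 channels remains, which is why the output is flattened before it reaches the loss. For a convolution without dilation, the output size is `floor((size + 2*padding - kernel) / stride) + 1`. The sketch below traces that formula through the four layers; `conv_out` is a hypothetical helper written for this check.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for conv1..conv4, copied from the model definition
layers = [(3, 2, 1), (3, 2, 1), (3, 2, 1), (4, 4, 0)]

size = 28  # Fashion MNIST images are 28x28
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [28, 14, 7, 4, 1]
```

A check like this is worth repeating whenever the kernel size, stride, or input resolution changes, because the final layer only yields one prediction per image if the spatial size reaches exactly 1.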
5. Evaluate the network
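total_correct = 0
total_num = 0
model_d.eval()  # inference mode; a no-op for these layers, but good practice
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        # exp of the log-probabilities gives class probabilities, shape (N, 10)
        ps = Flatten()(torch.exp(model_d(images)))
        # Index of the most likely class for each image, shape (1, N)
        predictions = ps.topk(1, 1, True, True)[1].t()
        correct = predictions.eq(labels.view(1, -1))
        total_correct += correct.sum().item()
        total_num += images.shape[0]
print('Accuracy:', total_correct / float(total_num))

# Inspect one sample from the last test batch
index = 0
print('Correct Label:', labels[index].item())
print('Predicted Label:', predictions[0, index].item())
imshow(images[index, :].cpu())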
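Because the model ends in `LogSoftmax`, its outputs are log-probabilities: `NLLLoss` picks out the negative log-probability of the true class, and `torch.exp` recovers ordinary probabilities. Since `exp` is monotonic, the predicted class is the same whether the arg-max is taken before or after it. A plain-Python sketch of that arithmetic for a single sample follows; the logits are made-up numbers and the function names are illustrative, not PyTorch APIs.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]

def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])

logits = [1.0, 3.0, 0.5]  # made-up scores for a 3-class example
log_probs = log_softmax(logits)
probs = [math.exp(lp) for lp in log_probs]

print("probabilities sum to", round(sum(probs), 6))
print("predicted class:", argmax(log_probs), argmax(probs))
print("loss if true class is 1:", nll_loss(log_probs, 1))
```

In other words, the `torch.exp` call in the evaluation loop matters only if the probabilities themselves are of interest; the accuracy would be identical without it.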
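Key takeaways
This article walked through all the working parts of a deep learning image classification network, one of the most basic AI applications and a good starting point for learning.
We loaded the Fashion MNIST dataset
We defined a simple deep convolutional network
We optimized the network weights with the Adam optimizer on a GPU
We evaluated the network and reached roughly 85% accuracy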