热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

mnisttorch加载fashion_简单的使用PyTorch和FashionMNIST数据集进行深度学习图像分类...

概述本文的目的是为那些想要使用PyTorch和FashionMNIST进行简单深度学习图像分类网络的人提供示例参考代码。在本文中,我们将演示深度学习图像分类网络的所有

概述

本文的目的是为那些想要使用PyTorch和Fashion MNIST进行简单深度学习图像分类网络的人提供示例参考代码。

在本文中,我们将演示深度学习图像分类网络的所有工作部分,包括加载数据,定义网络,优化GPU上的权重以及评估性能。

整理Fashion MNIST数据集

Fashion MNIST是一个包含70,000个灰度图像和10个类的数据集。

1.检查GPU是否可用

import torch

print(torch.cuda.is_available())

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)

2.下载并加载Fashion MNIST数据集

import torch

from torchvision import datasets, transforms

import helper

# Instructions from here:

# https://www.kaggle.com/ishvindersethi22/fashion-mnist-using-pytorch/data

# Define a transform to normalize the data

transform = transforms.Compose([transforms.ToTensor(),

transforms.Normalize([0.], [0.5])])

# Download and load the training data

trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data

testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

3.显示样本图像

import matplotlib.pyplot as plt

import torchvision

import numpy as np

def imshow(img):

img = img / 2 + 0. # unnormalize

img = img.squeeze()

plt.imshow(img, cmap='gray')

images, label = next(iter(trainloader))

imshow(images[0, :])

4.定义和训练网络

from collections import OrderedDict

from torch import optim, nn

hidden_units = [4, 8, 16]

output_units = 10

class Flatten(nn.Module):

def forward(self, input):

return input.view(input.size(0), -1)

model_d = nn.Sequential(OrderedDict([

('conv1', nn.Conv2d(1, hidden_units[0], 3, stride=2, padding=1)),

('Relu1', nn.ReLU()),

('conv2', nn.Conv2d(hidden_units[0], hidden_units[1], 3, stride=2, padding=1)),

('Relu2', nn.ReLU()),

('conv3', nn.Conv2d(hidden_units[1], hidden_units[2], 3, stride=2, padding=1)),

('Relu3', nn.ReLU()),

('conv4', nn.Conv2d(hidden_units[2], output_units, 4, stride=4, padding=0)),

('log_softmax', nn.LogSoftmax(dim = 1))

]))

model_d.to(device)

optimizer_d = optim.Adam(model_d.parameters(), lr = 0.01)

criterion = nn.NLLLoss()

epochs = 10

for i in range(epochs):

running_classification_loss = 0

running_cycle_consistent_loss = 0

running_loss = 0

for images, labels in trainloader:

images, labels = images.to(device), labels.to(device)

optimizer_d.zero_grad()

# Run classification model

predicted_labels = model_d(images)

classification_loss = criterion(Flatten()(predicted_labels), labels)

# Optimize classification weights

classification_loss.backward()

optimizer_d.step()

running_classification_loss += classification_loss.item()

running_loss = running_classification_loss

else:

print(f"{i} Training loss: {running_loss/len(trainloader)}")

5.评估网络

total_correct = 0

total_num = 0

for images, labels in testloader:

images, labels = images.to(device), labels.to(device)

ps = Flatten()(torch.exp(model_d(images)))

predictions = ps.topk(1, 1, True, True)[1].t()

correct = predictions.eq(labels.view(1, -1))

total_correct += correct.sum().cpu().numpy()

total_num += images.shape[0]

print('Accuracy:', total_correct / float(total_num))

print('Correct Label:', labels[0].item())

print('Predicted Label:', predictions[0, 0].item())

index = 0

imshow(images[index, :].cpu())

小贴士

本文演示了深度学习图像分类网络的所有工作部分,我们可以使用最基本的人工智能应用来进行简单的学习。

我们加载了Fashion MNIST数据集

定义一个简单的深度卷积网络

我们使用GPU上的Adam优化器优化网络权重

我们评估网络并达到约85%的准确度



推荐阅读
author-avatar
无味18_380
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有