热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

十分钟搞懂Pytorch如何读取MNIST数据集

前言本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧…正文在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的

前言

本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧…

正文

在阅读教程书籍《深度学习入门之Pytorch》时,文中是如此加载MNIST手写数字训练集的:

train_dataset = datasets.MNIST(root='./MNIST',train=True,transform=data_tf,download=True)

解释一下参数

datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。

train=True 代表我们读入的数据作为训练集(如果为true则从training.pt创建数据集,否则从test.pt创建数据集)

transform则是读入我们自己定义的数据预处理操作

download=True则是当我们的根目录(root)下没有数据集时,便自动下载。

如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:

在这里插入图片描述

其中我们所需要的文件主要在raw文件夹下

train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

接下来,书中是如此加载数据集的

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=5,
shuffle=True)

由于DataLoader为Pytorch内部封装好的函数,所以对于它的调用方法需要自行去查阅。

我在最开始疑惑的点:传入的根目录在下载好数据集后,为MNIST下两个文件夹,而processed和raw文件夹下还有诸多文件,所以到底是如何读入数据的呢?所以我决定将数据集下载后,通过读取本地的MINIST数据集并进行装载。

首先,自定义数据类来继承和重写Dataset抽象类

class DealDataset(Dataset):
""" 读取数据、初始化数据 """
def __init__(self, folder, data_name, label_name,transform=None):
(train_set, train_labels) = self.load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
self.train_set = train_set
self.train_labels = train_labels
self.transform = transform
def __getitem__(self, index):
img, target = self.train_set[index], int(self.train_labels[index])
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.train_set)
''' load_data也是我们自定义的函数,用途:读取数据集中的数据 ( 图片数据+标签label '''
def load_data(self,data_folder, data_name, label_name):
with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:
x_train = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
return (x_train, y_train)

接下来,调用我们自定义的数据类来加载数据集

trainDataset = DealDataset('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())
# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(
dataset=trainDataset,
batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片
shuffle=False,
)

通过这种方式便可以大概了解了读取数据集的过程。

接下来,我们来验证以下我们数据是否正确加载

# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:其实这里是用cv2.imshow来展示图片,但是我的代码是在jupyter notebook上写的,所以只能通过plt来代替加载。
在这里插入图片描述

数据加载成功~

深入探索

可以看到,在load_data函数中

y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

这个offset=8又是为啥呢?
我们进入MNIST数据集的官方页面进行查看
在这里插入图片描述

通过文档介绍,可以看到
offset的0000-0003是 magic number,所以跳过不读,
offset的0004-0007是items数目
接下来这些代表的就是标签

同理对于

x_train = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_train)

在这里插入图片描述

根据刚才的分析方法,也可以明白为什么offset=16了


完整代码


1.直接使用pytorch自带的mnist数据集加载

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt
data_tf = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])]
)
train_dataset = datasets.MNIST(root='./coding/learning/lrdata/MNIST',train=True,transform=data_tf,download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=5,
shuffle=True)
# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

p.s.:记得自己修改root根目录。

2.使用自定义的数据类加载本地MNIST数据集

import numpy as np
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
import gzip
import os
import torchvision
import cv2
import matplotlib.pyplot as plt
class DealDataset(Dataset):
""" 读取数据、初始化数据 """
def __init__(self, folder, data_name, label_name,transform=None):
(train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
self.train_set = train_set
self.train_labels = train_labels
self.transform = transform
def __getitem__(self, index):
img, target = self.train_set[index], int(self.train_labels[index])
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.train_set)
def load_data(data_folder, data_name, label_name):
with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:
x_train = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
return (x_train, y_train)
trainDataset = DealDataset('./coding/learning/lrdata/MNIST/MNIST/raw', "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())
# 训练数据和测试数据的装载
train_loader = torch.utils.data.DataLoader(
dataset=trainDataset,
batch_size=10, # 一个批次可以认为是一个包,每个包中含有10张图片
shuffle=False,
)
# 实现单张图片可视化
images, labels = next(iter(train_loader))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
plt.imshow(img)
plt.show()

参考

1.《深度学习入门之Pytorch》- 廖星宇
2.使用Pytorch进行读取本地的MINIST数据集并进行装载
3.顺藤摸瓜-mnist数据集的补充

在这里插入图片描述


推荐阅读
  • 本文详细介绍了Python的multiprocessing模块,该模块不仅支持本地并发操作,还支持远程操作。通过使用multiprocessing模块,开发者可以利用多核处理器的优势,提高程序的执行效率。 ... [详细]
  • django项目中使用手机号登录
    本文使用聚合数据的短信接口,需要先获取到申请接口的appkey和模板id项目目录下创建ubtils文件夹,定义返回随机验证码和调取短信接口的函数function.py文件se ... [详细]
  • Lua基本语法lua与C#的交互(相当简单详细的例子)
    lua脚本与C#的交互本文提供全流程,中文翻译。Chinar坚持将简单的生活方式,带给世人!(拥有更好的阅读体验——高分辨率用户请根据需求调整网页缩放比例)1LuaAndC#——L ... [详细]
  • Java 中静态和非静态嵌套类的区别 ... [详细]
  • 本文介绍了JSP的基本概念、常用标签及其功能,并通过示例详细说明了如何在JSP页面中使用Java代码。 ... [详细]
  • 解决 Pytest 运行时出现 FileNotFoundError 的方法
    在使用 Pytest 进行测试时,可能会遇到 FileNotFoundError 错误,提示无法找到指定的文件或目录。本文将探讨该错误的原因及解决方案。 ... [详细]
  • 本文将指导你通过 Gulp 和 Webpack 构建一个简单的用户登录界面,包括目录结构设置和关键文件的配置。 ... [详细]
  • 这个报错出现在userDao里面,sessionfactory没有注入。解决办法:spring整合Hibernate使用test测试时要把spring.xml和spring-hib ... [详细]
  • 深入解析Keras中的ImageDataGenerator参数
    本文详细探讨了Keras库中ImageDataGenerator类的各项参数及其功能,旨在帮助读者更好地理解和应用数据增强技术,提高模型训练效果。 ... [详细]
  • 本文旨在探讨计算机机房的有效管理与维护方法,包括合理的机房布局设计、高效的操作系统安装与恢复技术以及数据保护措施。随着信息技术教育的发展,计算机机房作为教学的重要组成部分,其稳定性和安全性直接影响到教学质量。文章分析了当前机房管理中存在的问题,并提出了针对性的解决方案。 ... [详细]
  • 利用Selenium框架解决SSO单点登录接口无法返回Token的问题
    针对接口自动化测试中遇到的SSO单点登录系统不支持通过API接口返回Token的问题,本文提供了一种解决方案,即通过UI自动化工具Selenium模拟用户登录过程,从浏览器的localStorage或sessionStorage中提取Token。 ... [详细]
  • 本文探讨了在使用Apache Flink向Kafka发送数据过程中遇到的事务频繁失败问题,并提供了详细的解决方案,包括必要的配置调整和最佳实践。 ... [详细]
  • 本文提供了解决在尝试重置MySQL root用户密码时遇到连接失败问题的方法,包括停止MySQL服务、以安全模式启动MySQL、手动更新用户表中的密码等步骤。 ... [详细]
  • 本文探讨了在Node.js环境中如何有效地捕获标准输出(stdout)的内容,并将其存储到变量中。通过具体的示例和解决方案,帮助开发者解决常见的输出捕获问题。 ... [详细]
  • 本文档提供了详细的MySQL安装步骤,包括解压安装文件、选择安装类型、配置MySQL服务以及设置管理员密码等关键环节,帮助用户顺利完成MySQL的安装。 ... [详细]
author-avatar
奕殫的泪
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有