热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch笔记05自定义数据读取方式orch.utils.data.Dataset与Dataloader

0.本章内容在pytorch中,提供了一种十分方便的数据读取机制,即使用torch.utils.data.Dataset与Dataloader组合得到数据迭代器。在每次训练时,利用

0. 本章内容

在pytorch中,提供了一种十分方便的数据读取机制,即使用torch.utils.data.Dataset与Dataloader组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个batch数据,并能在输出时对数据进行相应的预处理或数据增广操作。

同时,pytorch可视觉包torchvision中,继承torch.utils.data.Dataset,预定义了许多常用的数据集,并提供了许多常用的数据增广函数。

本章主要进行下列介绍:

  • torch.utils.data.Dataset与Dataloader的理解
  • torchvision中的datasets
  • torchvision ImageFolder
  • torchvision transforms

具体代码可以在 XavierLinNow/pytorch_note_CN得到

1. torch.utils.data.Dataset与torch.utils.data.DataLoader的理解

  1. pytorch提供了一个数据读取的方法,其由两个类构成:torch.utils.data.Dataset和DataLoader
  2. 我们要自定义自己数据读取的方法,就需要继承torch.utils.data.Dataset,并将其封装到DataLoader中
  3. torch.utils.data.Dataset表示该数据集,继承该类可以重载其中的方法,实现多种数据读取及数据预处理方式
  4. torch.utils.data.DataLoader 封装了Data对象,实现单(多)进程迭代器输出数据集

下面我们分别介绍下torch.utils.data.Dataset以及DataLoader

1.1 torch.utils.data.Dataset

  1. 要自定义自己的Dataset类,至少要重载两个方法,__len__, __getitem__
  2. __len__返回的是数据集的大小
  3. __getitem__实现索引数据集中的某一个数据
  4. 除了这两个基本功能,还可以在__getitem__时对数据进行预处理,或者是直接在硬盘中读取数据,对于超大的数据集还可以使用lmdb来读取

下面将简单实现一个返回torch.Tensor类型的数据集

from torch.utils.data import DataLoader, Dataset
import torch
class TensorDataset(Dataset):
# TensorDataset继承Dataset, 重载了__init__, __getitem__, __len__
# 实现将一组Tensor数据对封装成Tensor数据集
# 能够通过index得到数据集的数据,能够通过len,得到数据集大小
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
def __len__(self):
return self.data_tensor.size(0)
# 生成数据
data_tensor = torch.randn(4, 3)
target_tensor = torch.rand(4)
# 将数据封装成Dataset
tensor_dataset = TensorDataset(data_tensor, target_tensor)
# 可使用索引调用数据
print 'tensor_data[0]: ', tensor_dataset[0]
''' 输出 tensor_data[0]: ( 0.6804 -1.2515 1.6084 [torch.FloatTensor of size 3] , 0.2058754563331604) '''
# 可返回数据len
print 'len os tensor_dataset: ', len(tensor_dataset)
''' 输出: len os tensor_dataset: 4 '''

1.2 torch.utils.data.Dataloader

  1. Dataloader将Dataset或其子类封装成一个迭代器
  2. 这个迭代器可以迭代输出Dataset的内容
  3. 同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程

tensor_dataloader = DataLoader(tensor_dataset, # 封装的对象
batch_size=2, # 输出的batchsize
shuffle=True, # 随机输出
num_workers=0) # 只有1个进程
# 以for循环形式输出
for data, target in tensor_dataloader:
print(data, target)
# 输出一个batch
print 'one batch tensor data: ', iter(tensor_dataloader).next()
# 输出batch数量
print 'len of batchtensor: ', len(list(iter(tensor_dataloader)))
''' 输出: one batch tensor data: [ 0.6804 -1.2515 1.6084 -0.1156 -1.1552 0.1866 [torch.FloatTensor of size 2x3] , 0.2059 0.6452 [torch.DoubleTensor of size 2] ] len of batchtensor: 2 '''

2. torchvision.datasets

  1. pytorch专门针对视觉实现了一个torchvision包,里面包括了许多常用的CNN模型以及一些数据集
  2. torchvision.datasets包含了MNIST,cifar10等数据集,他们都是通过继承上述Dataset类实现的

2.1 调用torchvision自带的cifar10数据集

import torchvision.datasets as dset
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np
%matplotlib inline
def imshow(img, is_unnormlize=False):
if is_unnormlize:
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 载入cifar数据集
trainset = dset.CIFAR10(root='../data', # 数据集路径
train=True, # 载入train set
download=True, # 如果未下载数据集,则自动下载。
# 建议直接下载后压缩到root的路径
transform=transforms.ToTensor() # 转换成Tensor才能被封装为DataLoader
)
# 封装成loader
trainloader = DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 显示图片
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))

显示图片

《Pytorch笔记05-自定义数据读取方式orch.utils.data.Dataset与Dataloader》

2.2 直接从硬盘中载入自己的图像

torch.datasets包中的ImageFolder支持我们直接从硬盘中按照固定路径格式载入每张数据,其格式如下:

  • 根目录/类别/图像

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

3. torchvision.transforms

  1. 在刚才,我们见到生成cifar数据集时有一个参数transform,这个参数就是实现各种预处理
  2. 在torchvision.transforms中,有多种预测方式,如scale,centercrop
  3. 我们可以使用Compose将这些预处理方式组成transforms list,对图像进行多种处理
  4. 需要注意,由于这些transform是基于PIL的,因此Compose中,Scale等预处理需要先调用,ToTensor需要后与他们
  5. 如果觉得torchvision自带的预处理不够多,可以使用https://github.com/ncullen93/torchsample 中的transforms

# 定义transform
transform = torchvision.transforms.Compose(
[transforms.RandomCrop(20),
transforms.ToTensor(), # ToTensor需要在预处理之后进行
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] )
# 载入cifar数据集
trainset = dset.CIFAR10(root='../data', # 数据集路径
train=True, # 载入train set
download=True, # 如果未下载数据集,则自动下载。
# 建议直接下载后压缩到root的路径
transform=transform # 进行预处理
)
# 封装成loader
trainloader = DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 显示图片
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images), True)

经过数据增广的数据:

《Pytorch笔记05-自定义数据读取方式orch.utils.data.Dataset与Dataloader》


推荐阅读
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文介绍了使用Spark实现低配版高斯朴素贝叶斯模型的原因和原理。随着数据量的增大,单机上运行高斯朴素贝叶斯模型会变得很慢,因此考虑使用Spark来加速运行。然而,Spark的MLlib并没有实现高斯朴素贝叶斯模型,因此需要自己动手实现。文章还介绍了朴素贝叶斯的原理和公式,并对具有多个特征和类别的模型进行了讨论。最后,作者总结了实现低配版高斯朴素贝叶斯模型的步骤。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • 本文介绍了利用ARMA模型对平稳非白噪声序列进行建模的步骤及代码实现。首先对观察值序列进行样本自相关系数和样本偏自相关系数的计算,然后根据这些系数的性质选择适当的ARMA模型进行拟合,并估计模型中的位置参数。接着进行模型的有效性检验,如果不通过则重新选择模型再拟合,如果通过则进行模型优化。最后利用拟合模型预测序列的未来走势。文章还介绍了绘制时序图、平稳性检验、白噪声检验、确定ARMA阶数和预测未来走势的代码实现。 ... [详细]
  • 我用Tkinter制作了一个图形用户界面,有两个主按钮:“开始”和“停止”。请您就如何使用“停止”按钮终止“开始”按钮为以下代码调用的已运行功能提供建议 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • 本文介绍了如何使用n3-charts绘制以日期为x轴的数据,并提供了相应的代码示例。通过设置x轴的类型为日期,可以实现对日期数据的正确显示和处理。同时,还介绍了如何设置y轴的类型和其他相关参数。通过本文的学习,读者可以掌握使用n3-charts绘制日期数据的方法。 ... [详细]
  • OpenMap教程4 – 图层概述
    本文介绍了OpenMap教程4中关于地图图层的内容,包括将ShapeLayer添加到MapBean中的方法,OpenMap支持的图层类型以及使用BufferedLayer创建图像的MapBean。此外,还介绍了Layer背景标志的作用和OMGraphicHandlerLayer的基础层类。 ... [详细]
  • 本文介绍了使用readlink命令获取文件的完整路径的简单方法,并提供了一个示例命令来打印文件的完整路径。共有28种解决方案可供选择。 ... [详细]
author-avatar
卫宇欢试试
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有