热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch之Dataset和TensorDataset

DeepLearning系列@cxxDatasetv.s.

Deep Learning系列 @cxx

Dataset v.s. TensorDataset

使用PyTorch搭建过Neural Network的小伙伴们都知道,在数据准备步骤里,我们需要把训练集的x和y分装在dataset里,然后将dataset分装到DataLoader中去,便于之后在搭建好的模型中训练。
简言之,dataset是用来做打包和预处理(比如输入资料路径自动读取);DataLoader则是将整个资料集(dataset)按照batch进行迭代分装或者shuffle(可以得到一个iterator以利于for循环读取)。

Dataset

如果使用继承Dataset的方式,那么在自定义的dataset类中必须给予__len__和__getitem__的定义。
进行图片处理的时候,可以定义一个transforms来随机旋转训练图片,将图片格式变成tensor等
(这里有一个坑)

假设我们读取了一个有如下格式的图片
在这里插入图片描述
将图片分装到dataset里,再放到dataloader里

from torch.utils.data import TensorDataset
batch_size = 128
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),]
)
test_transform = transforms.Compose(
[transforms.ToPILImage(),
transforms.ToTensor(),]
) #测试集不需要翻转或旋转图片
#继承Dataset
class ImgDataset(Dataset):
def __init__(self, x, y=None, transform=None):
self.x = x
self.y = y
# label is required to be a LongTensor
if y is not None:
self.y = torch.LongTensor(y)
self.transform = transform
def __len__(self):
return len(self.x)
def __getitem__(self, index):
X = self.x[index]
if self.transform is not None:
X = self.transform(X)
if self.y is not None:
Y = self.y[index]
return X, Y
else:
return X


#将dataset分装到dataloader里
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

接下来我们可以输出一个batch看看图片的格式
在这里插入图片描述
我们发现一个batch的x[0]的shape由原先的(128, 128, 3)变成了(3, 128, 128)。
原因在于transformers.toTensor()方法有自动转换维度的功能,它会将channel变成第一维(夺么坑爹的功能,导致我排查了好久不知道是哪里出了问题==)
具体可以参照这篇博客transforms.ToTensor()本身有维度转换功能

TensorDataset

张量资料集tensrdataset是最常见的形式,因为PyTorch本身有提供方便的TensorDataset给我们使用

torch.utils.data.TensorDataset(data_tensor, target_tensor)

用TensorDataset写会少写很多东西

#将资料转换成tensor
tsr_x_train, tsr_y_train = torch.tensor(x_train), torch.tensor(y_train)
tsr_x_val, tsr_y_val = torch.tensor(x_val), torch.tensor(y_val)
tsr_x_testing = torch.tensor(x_test)
#然后只需要一行就可以啦
train_dataset = TensorDataset(tsr_x_train, tsr_y_train)
val_dataset = TensorDataset(tsr_x_val, tsr_y_val)
#装入dataloader的步骤同上
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

我们跑一个loop看看这次维度是否被转换了
在这里插入图片描述
答案是:这次没有!
这次的x[0]的shape同我们一开始设置的shape,TensorDataset并没有帮我们把channel数调成第一维
这里真的要注意呀。


版权声明:本文为qq_43611080原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/qq_43611080/article/details/113575167
推荐阅读
  • 深入解析HTML5字符集属性:charset与defaultCharset
    本文将详细介绍HTML5中新增的字符集属性charset和defaultCharset,帮助开发者更好地理解和应用这些属性,以确保网页在不同环境下的正确显示。 ... [详细]
  • 解决Only fullscreen opaque activities can request orientation错误的方法
    本文介绍了在使用PictureSelectorLight第三方框架时遇到的Only fullscreen opaque activities can request orientation错误,并提供了一种有效的解决方案。 ... [详细]
  • 结城浩(1963年7月出生),日本资深程序员和技术作家,居住在东京武藏野市。他开发了著名的YukiWiki软件,并在杂志上发表了大量程序入门文章和技术翻译作品。结城浩著有30多本关于编程和数学的书籍,其中许多被翻译成英文和韩文。 ... [详细]
  • 探索Web 2.0新概念:Widget
    尽管你可能尚未注意到Widget,但正如几年前对RSS的陌生一样,这一概念正逐渐走入大众视野。据美国某权威杂志预测,2007年将是Widget年。本文将详细介绍Widget的定义、功能及其未来发展趋势。 ... [详细]
  • 网站访问全流程解析
    本文详细介绍了从用户在浏览器中输入一个域名(如www.yy.com)到页面完全展示的整个过程,包括DNS解析、TCP连接、请求响应等多个步骤。 ... [详细]
  • 微信公众号推送模板40036问题
    返回码错误码描述说明40001invalidcredential不合法的调用凭证40002invalidgrant_type不合法的grant_type40003invalidop ... [详细]
  • [转]doc,ppt,xls文件格式转PDF格式http:blog.csdn.netlee353086articledetails7920355确实好用。需要注意的是#import ... [详细]
  • MySQL 5.7 学习指南:SQLyog 中的主键、列属性和数据类型
    本文介绍了 MySQL 5.7 中主键(Primary Key)和自增(Auto-Increment)的概念,以及如何在 SQLyog 中设置这些属性。同时,还探讨了数据类型的分类和选择,以及列属性的设置方法。 ... [详细]
  • javascript分页类支持页码格式
    前端时间因为项目需要,要对一个产品下所有的附属图片进行分页显示,没考虑ajax一张张请求,所以干脆一次性全部把图片out,然 ... [详细]
  • Java swing 连连看小游戏  开发小系统 项目源代码 实训实验毕设
    Javaswing连连看小游戏开发小系统项目源代码实训实验能满足学习和二次开发可以作为初学者熟悉Java的学习,作为老师阶段性学习的一个成功检验不再是单调的理解老师空泛的知识,导入 ... [详细]
  • 利用 Python Socket 实现 ICMP 协议下的网络通信
    在计算机网络课程的2.1实验中,学生需要通过Python Socket编程实现一种基于ICMP协议的网络通信功能。与操作系统自带的Ping命令类似,该实验要求学生开发一个简化的、非标准的ICMP通信程序,以加深对ICMP协议及其在网络通信中的应用的理解。通过这一实验,学生将掌握如何使用Python Socket库来构建和解析ICMP数据包,并实现基本的网络探测功能。 ... [详细]
  • 在深入研究 UniApp 封装请求时,发现其请求 API 方法中使用了 `then` 和 `catch` 函数。通过详细分析,了解到这些函数是 Promise 对象的核心组成部分。Promise 是一种用于处理异步操作的结果的标准化方式,它提供了一种更清晰、更可控的方法来管理复杂的异步流程。本文将详细介绍 Promise 的基本概念、结构和常见应用场景,帮助开发者更好地理解和使用这一强大的工具。 ... [详细]
  • HTML 页面中调用 JavaScript 函数生成随机数值并自动展示
    在HTML页面中,通过调用JavaScript函数生成随机数值,并将其自动展示在页面上。具体实现包括构建HTML页面结构,定义JavaScript函数以生成随机数,以及在页面加载时自动调用该函数并将结果呈现给用户。 ... [详细]
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 本文介绍了如何利用 `matplotlib` 库中的 `FuncAnimation` 类将 Python 中的动态图像保存为视频文件。通过详细解释 `FuncAnimation` 类的参数和方法,文章提供了多种实用技巧,帮助用户高效地生成高质量的动态图像视频。此外,还探讨了不同视频编码器的选择及其对输出文件质量的影响,为读者提供了全面的技术指导。 ... [详细]
author-avatar
Rain雨露Dew
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有