热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch之Dataset和TensorDataset

DeepLearning系列@cxxDatasetv.s.

Deep Learning系列 @cxx

Dataset v.s. TensorDataset

使用PyTorch搭建过Neural Network的小伙伴们都知道,在数据准备步骤里,我们需要把训练集的x和y分装在dataset里,然后将dataset分装到DataLoader中去,便于之后在搭建好的模型中训练。
简言之,dataset是用来做打包和预处理(比如输入资料路径自动读取);DataLoader则是将整个资料集(dataset)按照batch进行迭代分装或者shuffle(可以得到一个iterator以利于for循环读取)。

Dataset

如果使用继承Dataset的方式,那么在自定义的dataset类中必须给予__len__和__getitem__的定义。
进行图片处理的时候,可以定义一个transforms来随机旋转训练图片,将图片格式变成tensor等
(这里有一个坑)

假设我们读取了一个有如下格式的图片
在这里插入图片描述
将图片分装到dataset里,再放到dataloader里

from torch.utils.data import TensorDataset
batch_size = 128
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),]
)
test_transform = transforms.Compose(
[transforms.ToPILImage(),
transforms.ToTensor(),]
) #测试集不需要翻转或旋转图片
#继承Dataset
class ImgDataset(Dataset):
def __init__(self, x, y=None, transform=None):
self.x = x
self.y = y
# label is required to be a LongTensor
if y is not None:
self.y = torch.LongTensor(y)
self.transform = transform
def __len__(self):
return len(self.x)
def __getitem__(self, index):
X = self.x[index]
if self.transform is not None:
X = self.transform(X)
if self.y is not None:
Y = self.y[index]
return X, Y
else:
return X


#将dataset分装到dataloader里
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

接下来我们可以输出一个batch看看图片的格式
在这里插入图片描述
我们发现一个batch的x[0]的shape由原先的(128, 128, 3)变成了(3, 128, 128)。
原因在于transformers.toTensor()方法有自动转换维度的功能,它会将channel变成第一维(夺么坑爹的功能,导致我排查了好久不知道是哪里出了问题==)
具体可以参照这篇博客transforms.ToTensor()本身有维度转换功能

TensorDataset

张量资料集tensrdataset是最常见的形式,因为PyTorch本身有提供方便的TensorDataset给我们使用

torch.utils.data.TensorDataset(data_tensor, target_tensor)

用TensorDataset写会少写很多东西

#将资料转换成tensor
tsr_x_train, tsr_y_train = torch.tensor(x_train), torch.tensor(y_train)
tsr_x_val, tsr_y_val = torch.tensor(x_val), torch.tensor(y_val)
tsr_x_testing = torch.tensor(x_test)
#然后只需要一行就可以啦
train_dataset = TensorDataset(tsr_x_train, tsr_y_train)
val_dataset = TensorDataset(tsr_x_val, tsr_y_val)
#装入dataloader的步骤同上
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

我们跑一个loop看看这次维度是否被转换了
在这里插入图片描述
答案是:这次没有!
这次的x[0]的shape同我们一开始设置的shape,TensorDataset并没有帮我们把channel数调成第一维
这里真的要注意呀。


版权声明:本文为qq_43611080原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/qq_43611080/article/details/113575167
推荐阅读
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 利用决策树预测NBA比赛胜负的Python数据挖掘实践
    本文通过使用2013-14赛季NBA赛程与结果数据集以及2013年NBA排名数据,结合《Python数据挖掘入门与实践》一书中的方法,展示如何应用决策树算法进行比赛胜负预测。我们将详细讲解数据预处理、特征工程及模型评估等关键步骤。 ... [详细]
  • Python 异步编程:深入理解 asyncio 库(上)
    本文介绍了 Python 3.4 版本引入的标准库 asyncio,该库为异步 IO 提供了强大的支持。我们将探讨为什么需要 asyncio,以及它如何简化并发编程的复杂性,并详细介绍其核心概念和使用方法。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 优化ListView性能
    本文深入探讨了如何通过多种技术手段优化ListView的性能,包括视图复用、ViewHolder模式、分批加载数据、图片优化及内存管理等。这些方法能够显著提升应用的响应速度和用户体验。 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • 深入理解Tornado模板系统
    本文详细介绍了Tornado框架中模板系统的使用方法。Tornado自带的轻量级、高效且灵活的模板语言位于tornado.template模块,支持嵌入Python代码片段,帮助开发者快速构建动态网页。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 本文详细介绍了 Dockerfile 的编写方法及其在网络配置中的应用,涵盖基础指令、镜像构建与发布流程,并深入探讨了 Docker 的默认网络、容器互联及自定义网络的实现。 ... [详细]
  • 数据库内核开发入门 | 搭建研发环境的初步指南
    本课程将带你从零开始,逐步掌握数据库内核开发的基础知识和实践技能,重点介绍如何搭建OceanBase的开发环境。 ... [详细]
  • Python自动化处理:从Word文档提取内容并生成带水印的PDF
    本文介绍如何利用Python实现从特定网站下载Word文档,去除水印并添加自定义水印,最终将文档转换为PDF格式。该方法适用于批量处理和自动化需求。 ... [详细]
  • 本文介绍如何使用Python进行文本处理,包括分词和生成词云图。通过整合多个文本文件、去除停用词并生成词云图,展示文本数据的可视化分析方法。 ... [详细]
  • 本文介绍如何使用 Python 获取文件和图片的创建、修改及拍摄日期。通过多种方法,如 PIL 库的 _getexif() 函数和 os 模块的 getmtime() 和 stat() 方法,详细讲解了这些技术的应用场景和注意事项。 ... [详细]
  • 自己用过的一些比较有用的css3新属性【HTML】
    web前端|html教程自己用过的一些比较用的css3新属性web前端-html教程css3刚推出不久,虽然大多数的css3属性在很多流行的浏览器中不支持,但我个人觉得还是要尽量开 ... [详细]
author-avatar
Rain雨露Dew
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有