热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch之Dataset和TensorDataset

DeepLearning系列@cxxDatasetv.s.

Deep Learning系列 @cxx

Dataset v.s. TensorDataset

使用PyTorch搭建过Neural Network的小伙伴们都知道,在数据准备步骤里,我们需要把训练集的x和y分装在dataset里,然后将dataset分装到DataLoader中去,便于之后在搭建好的模型中训练。
简言之,dataset是用来做打包和预处理(比如输入资料路径自动读取);DataLoader则是将整个资料集(dataset)按照batch进行迭代分装或者shuffle(可以得到一个iterator以利于for循环读取)。

Dataset

如果使用继承Dataset的方式,那么在自定义的dataset类中必须给予__len__和__getitem__的定义。
进行图片处理的时候,可以定义一个transforms来随机旋转训练图片,将图片格式变成tensor等
(这里有一个坑)

假设我们读取了一个有如下格式的图片
在这里插入图片描述
将图片分装到dataset里,再放到dataloader里

from torch.utils.data import TensorDataset
batch_size = 128
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),]
)
test_transform = transforms.Compose(
[transforms.ToPILImage(),
transforms.ToTensor(),]
) #测试集不需要翻转或旋转图片
#继承Dataset
class ImgDataset(Dataset):
def __init__(self, x, y=None, transform=None):
self.x = x
self.y = y
# label is required to be a LongTensor
if y is not None:
self.y = torch.LongTensor(y)
self.transform = transform
def __len__(self):
return len(self.x)
def __getitem__(self, index):
X = self.x[index]
if self.transform is not None:
X = self.transform(X)
if self.y is not None:
Y = self.y[index]
return X, Y
else:
return X


#将dataset分装到dataloader里
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

接下来我们可以输出一个batch看看图片的格式
在这里插入图片描述
我们发现一个batch的x[0]的shape由原先的(128, 128, 3)变成了(3, 128, 128)。
原因在于transformers.toTensor()方法有自动转换维度的功能,它会将channel变成第一维(夺么坑爹的功能,导致我排查了好久不知道是哪里出了问题==)
具体可以参照这篇博客transforms.ToTensor()本身有维度转换功能

TensorDataset

张量资料集tensrdataset是最常见的形式,因为PyTorch本身有提供方便的TensorDataset给我们使用

torch.utils.data.TensorDataset(data_tensor, target_tensor)

用TensorDataset写会少写很多东西

#将资料转换成tensor
tsr_x_train, tsr_y_train = torch.tensor(x_train), torch.tensor(y_train)
tsr_x_val, tsr_y_val = torch.tensor(x_val), torch.tensor(y_val)
tsr_x_testing = torch.tensor(x_test)
#然后只需要一行就可以啦
train_dataset = TensorDataset(tsr_x_train, tsr_y_train)
val_dataset = TensorDataset(tsr_x_val, tsr_y_val)
#装入dataloader的步骤同上
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

我们跑一个loop看看这次维度是否被转换了
在这里插入图片描述
答案是:这次没有!
这次的x[0]的shape同我们一开始设置的shape,TensorDataset并没有帮我们把channel数调成第一维
这里真的要注意呀。


版权声明:本文为qq_43611080原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/qq_43611080/article/details/113575167
推荐阅读
  • 利用决策树预测NBA比赛胜负的Python数据挖掘实践
    本文通过使用2013-14赛季NBA赛程与结果数据集以及2013年NBA排名数据,结合《Python数据挖掘入门与实践》一书中的方法,展示如何应用决策树算法进行比赛胜负预测。我们将详细讲解数据预处理、特征工程及模型评估等关键步骤。 ... [详细]
  • 探讨ChatGPT在法律和版权方面的潜在风险及影响,分析其作为内容创造工具的合法性和合规性。 ... [详细]
  • 本文详细介绍了如何在Linux系统上安装和配置Smokeping,以实现对网络链路质量的实时监控。通过详细的步骤和必要的依赖包安装,确保用户能够顺利完成部署并优化其网络性能监控。 ... [详细]
  • 本文详细介绍了 Dockerfile 的编写方法及其在网络配置中的应用,涵盖基础指令、镜像构建与发布流程,并深入探讨了 Docker 的默认网络、容器互联及自定义网络的实现。 ... [详细]
  • 本文介绍了如何使用JQuery实现省市二级联动和表单验证。首先,通过change事件监听用户选择的省份,并动态加载对应的城市列表。其次,详细讲解了使用Validation插件进行表单验证的方法,包括内置规则、自定义规则及实时验证功能。 ... [详细]
  • 自己用过的一些比较有用的css3新属性【HTML】
    web前端|html教程自己用过的一些比较用的css3新属性web前端-html教程css3刚推出不久,虽然大多数的css3属性在很多流行的浏览器中不支持,但我个人觉得还是要尽量开 ... [详细]
  • Unity编辑器插件:NGUI资源引用检测工具
    本文介绍了一款基于NGUI的资源引用检测工具,该工具能够帮助开发者快速查找和管理项目中的资源引用。其功能涵盖Atlas/Sprite、字库、UITexture及组件的引用检测,并提供了替换和修复功能。文末提供源码下载链接。 ... [详细]
  • 社交网络中的级联行为 ... [详细]
  • 本文探讨了如何在 F# Interactive (FSI) 中通过 AddPrinter 和 AddPrintTransformer 方法自定义类型(尤其是集合类型)的输出格式,提供了详细的指南和示例代码。 ... [详细]
  • 历经三十年的开发,Mathematica 已成为技术计算领域的标杆,为全球的技术创新者、教育工作者、学生及其他用户提供了一个领先的计算平台。最新版本 Mathematica 12.3.1 增加了多项核心语言、数学计算、可视化和图形处理的新功能。 ... [详细]
  • 本文将详细探讨 Java 中提供的不可变集合(如 `Collections.unmodifiableXXX`)和同步集合(如 `Collections.synchronizedXXX`)的实现原理及使用方法,帮助开发者更好地理解和应用这些工具。 ... [详细]
  • 理解与应用:独热编码(One-Hot Encoding)
    本文详细介绍了独热编码(One-Hot Encoding)与哑变量编码(Dummy Encoding)两种方法,用于将分类变量转换为数值形式,以便于机器学习算法处理。文章不仅解释了这两种编码方式的基本原理,还探讨了它们在实际应用中的差异及选择依据。 ... [详细]
  • 优化ListView性能
    本文深入探讨了如何通过多种技术手段优化ListView的性能,包括视图复用、ViewHolder模式、分批加载数据、图片优化及内存管理等。这些方法能够显著提升应用的响应速度和用户体验。 ... [详细]
  • 导航栏样式练习:项目实例解析
    本文详细介绍了如何创建一个具有动态效果的导航栏,包括HTML、CSS和JavaScript代码的实现,并附有详细的说明和效果图。 ... [详细]
  • 一个登陆界面
    预览截图html部分123456789101112用户登入1314邮箱名称邮箱为空15密码密码为空16登 ... [详细]
author-avatar
Rain雨露Dew
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有