热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch之Dataset和TensorDataset

DeepLearning系列@cxxDatasetv.s.

Deep Learning系列 @cxx

Dataset v.s. TensorDataset

使用PyTorch搭建过Neural Network的小伙伴们都知道,在数据准备步骤里,我们需要把训练集的x和y分装在dataset里,然后将dataset分装到DataLoader中去,便于之后在搭建好的模型中训练。
简言之,dataset是用来做打包和预处理(比如输入资料路径自动读取);DataLoader则是将整个资料集(dataset)按照batch进行迭代分装或者shuffle(可以得到一个iterator以利于for循环读取)。

Dataset

如果使用继承Dataset的方式,那么在自定义的dataset类中必须给予__len__和__getitem__的定义。
进行图片处理的时候,可以定义一个transforms来随机旋转训练图片,将图片格式变成tensor等
(这里有一个坑)

假设我们读取了一个有如下格式的图片
在这里插入图片描述
将图片分装到dataset里,再放到dataloader里

from torch.utils.data import TensorDataset
batch_size = 128
train_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),]
)
test_transform = transforms.Compose(
[transforms.ToPILImage(),
transforms.ToTensor(),]
) #测试集不需要翻转或旋转图片
#继承Dataset
class ImgDataset(Dataset):
def __init__(self, x, y=None, transform=None):
self.x = x
self.y = y
# label is required to be a LongTensor
if y is not None:
self.y = torch.LongTensor(y)
self.transform = transform
def __len__(self):
return len(self.x)
def __getitem__(self, index):
X = self.x[index]
if self.transform is not None:
X = self.transform(X)
if self.y is not None:
Y = self.y[index]
return X, Y
else:
return X


#将dataset分装到dataloader里
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

接下来我们可以输出一个batch看看图片的格式
在这里插入图片描述
我们发现一个batch的x[0]的shape由原先的(128, 128, 3)变成了(3, 128, 128)。
原因在于transformers.toTensor()方法有自动转换维度的功能,它会将channel变成第一维(夺么坑爹的功能,导致我排查了好久不知道是哪里出了问题==)
具体可以参照这篇博客transforms.ToTensor()本身有维度转换功能

TensorDataset

张量资料集tensrdataset是最常见的形式,因为PyTorch本身有提供方便的TensorDataset给我们使用

torch.utils.data.TensorDataset(data_tensor, target_tensor)

用TensorDataset写会少写很多东西

#将资料转换成tensor
tsr_x_train, tsr_y_train = torch.tensor(x_train), torch.tensor(y_train)
tsr_x_val, tsr_y_val = torch.tensor(x_val), torch.tensor(y_val)
tsr_x_testing = torch.tensor(x_test)
#然后只需要一行就可以啦
train_dataset = TensorDataset(tsr_x_train, tsr_y_train)
val_dataset = TensorDataset(tsr_x_val, tsr_y_val)
#装入dataloader的步骤同上
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True
)
test_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False
)

我们跑一个loop看看这次维度是否被转换了
在这里插入图片描述
答案是:这次没有!
这次的x[0]的shape同我们一开始设置的shape,TensorDataset并没有帮我们把channel数调成第一维
这里真的要注意呀。


版权声明:本文为qq_43611080原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/qq_43611080/article/details/113575167
推荐阅读
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 使用Ubuntu中的Python获取浏览器历史记录原文: ... [详细]
  • 本文介绍了计算机网络的定义和通信流程,包括客户端编译文件、二进制转换、三层路由设备等。同时,还介绍了计算机网络中常用的关键词,如MAC地址和IP地址。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文介绍了使用readlink命令获取文件的完整路径的简单方法,并提供了一个示例命令来打印文件的完整路径。共有28种解决方案可供选择。 ... [详细]
  • 本文介绍了C#中生成随机数的三种方法,并分析了其中存在的问题。首先介绍了使用Random类生成随机数的默认方法,但在高并发情况下可能会出现重复的情况。接着通过循环生成了一系列随机数,进一步突显了这个问题。文章指出,随机数生成在任何编程语言中都是必备的功能,但Random类生成的随机数并不可靠。最后,提出了需要寻找其他可靠的随机数生成方法的建议。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 在重复造轮子的情况下用ProxyServlet反向代理来减少工作量
    像不少公司内部不同团队都会自己研发自己工具产品,当各个产品逐渐成熟,到达了一定的发展瓶颈,同时每个产品都有着自己的入口,用户 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • ASP.NET2.0数据教程之十四:使用FormView的模板
    本文介绍了在ASP.NET 2.0中使用FormView控件来实现自定义的显示外观,与GridView和DetailsView不同,FormView使用模板来呈现,可以实现不规则的外观呈现。同时还介绍了TemplateField的用法和FormView与DetailsView的区别。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
author-avatar
Rain雨露Dew
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有