热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

聊聊Pytorch的dataloader

点击上方“机器学习与生成对抗网络”,关注星标获取有趣、好玩的前沿干货!来源:知乎—Mario编辑人工智能前沿讲习地址:htt

点击上方“机器学习与生成对抗网络”,关注星标

获取有趣、好玩的前沿干货!

来源:知乎—Mario 编辑 人工智能前沿讲习

地址:https://zhuanlan.zhihu.com/p/117270644

为啥突然要写一下pytorch的dataloader呢,首先来说说事情的来龙去脉。

起初,我最开始单独训练一个网络来完成landmark点回归任务和分类任务,训练的数据是txt格式,在训练之前对数据进行分析,发现分类任务中存在严重的数据样本不均衡的问题,那么我事先针对性的进行数据采样均衡操作,重新得到训练和测试的txt数据和标签,保证了整个训练和测试数据的样本均衡性。由于我的整个项目是检测+点回归+分类,起初检测和点回归+分类是分两步实现的,检测是通过读取XML格式来进行训练,现在要统一整个项目的训练和测试过程,要将点回归+分类的训练测试过程也按照读取XML格式来进行,那么就遇到一个问题,如何针对性的去给样本偏少的样本进行均衡,由于在dataset类中,返回的图像和标签都是针对每个index返回一个结果,在dataset类中进行操作似乎不太可行,那么就想到在dataloader中进行操作,通过dataloader中的参数sample来完成针对性采样。

还有一个问题是关于num_workers的设置,因为我有对比过,在我的单机RTX 2080Ti上和八卡服务器TITAN RTX上(仅使用单卡,其它卡有在跑其它任务),使用相同的num_workers,在单机上的训练速度反而更快,于是猜想可能和CPU或者内存有关系,下面会具体分析。

首先来看下下dataloader中的各个参数的含义。

类的定义为:torch.utils.data.DataLoader ,其中包含的参数有:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, \batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, \drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

dataset:定义的dataset类返回的结果。

batchsize:每个bacth要加载的样本数,默认为1。

shuffle:在每个epoch中对整个数据集data进行shuffle重排,默认为False。

sample:定义从数据集中加载数据所采用的策略,如果指定的话,shuffle必须为False;batch_sample类似,表示一次返回一个batch的index。

num_workers:表示开启多少个线程数去加载你的数据,默认为0,代表只使用主进程。

collate_fn:表示合并样本列表以形成小批量的Tensor对象。

pin_memory:表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False。

drop_last:当你的整个数据长度不能够整除你的batchsize,选择是否要丢弃最后一个不完整的batch,默认为False。

注:这里简单科普下pin_memory,通常情况下,数据在内存中要么以锁页的方式存在,要么保存在虚拟内存(磁盘)中,设置为True后,数据直接保存在锁页内存中,后续直接传入cuda;否则需要先从虚拟内存中传入锁页内存中,再传入cuda,这样就比较耗时了,但是对于内存的大小要求比较高。

下面针对num_workers,sample和collate_fn分别进行说明:

01

设置num_workers

pytorch中dataloader一次性创建num_workers个子线程,然后用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM,dataloader就可以直接从RAM中找本轮迭代要用的batch。如果num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮...迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是进行CPU复制)。如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤,只有当你需要的时候再加载相应的batch,当然速度就更慢。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些,对于单机来说,单跑一个任务的话,直接设置为CPU的核心数最好。

02

定义sample:(假设dataset类返回的是:data, label)

from torch.utils.data.sampler import WeightedRandomSampler
## 如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
weights = [2 if label == 1 else 1 for data, label in dataset]
sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)

PyTorch中提供的这个sampler模块,用来对数据进行采样。默认采用SequentialSampler,它会按顺序一个一个进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。这里使用另外一个很有用的采样方法:WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。

03

定义collate_fn

def detection_collate(batch):"""Custom collate fn for dealing with batches of images that have a differentnumber of associated object annotations (bounding boxes).Arguments:batch: (tuple) A tuple of tensor images and lists of annotationsReturn:A tuple containing:1) (tensor) batch of images stacked on their 0 dim2) (list of tensors) annotations for a given image are stacked on0 dim"""targets = []imgs = []for sample in batch:imgs.append(sample[0])targets.append(torch.FloatTensor(sample[1]))return torch.stack(imgs, 0), targets

使用dataloader时加入collate_fn参数,即可合并样本列表以形成小批量的Tensor对象,如果你的标签不止一个的话,还可以支持自定义,在上述方法中再额外添加对应的label即可。

data_loader = torch.utils.data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers, sampler=sampler, shuffle=False,     collate_fn=detection_collate, pin_memory=True, drop_last=True)
参考链接:
https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader
猜您喜欢:
超100篇!CVPR 2020最全GAN论文梳理汇总!拆解组新的GAN:解耦表征MixNMatchStarGAN第2版:多域多样性图像生成
附下载 | 《可解释的机器学习》中文版附下载 |《TensorFlow 2.0 深度学习算法实战》附下载 |《计算机视觉中的数学方法》分享《基于深度学习的表面缺陷检测方法综述》《零样本图像分类综述: 十年进展》《基于深度神经网络的少样本学习综述》


推荐阅读
  • 在Conda环境中高效配置并安装PyTorch和TensorFlow GPU版的方法如下:首先,创建一个新的Conda环境以避免与基础环境发生冲突,例如使用 `conda create -n pytorch_gpu python=3.7` 命令。接着,激活该环境,确保所有依赖项都正确安装。此外,建议在安装过程中指定CUDA版本,以确保与GPU兼容性。通过这些步骤,可以确保PyTorch和TensorFlow GPU版的顺利安装和运行。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • RTThread线程间通信
    线程中通信在裸机编程中,经常会使用全局变量进行功能间的通信,如某些功能可能由于一些操作而改变全局变量的值,另一个功能对此全局变量进行读取& ... [详细]
  • 1、形成邻居条件:1)区域ID相同;2)hello,dead时间一致;3)认证&# ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • JUC(三):深入解析AQS
    本文详细介绍了Java并发工具包中的核心类AQS(AbstractQueuedSynchronizer),包括其基本概念、数据结构、源码分析及核心方法的实现。 ... [详细]
  • 用阿里云的免费 SSL 证书让网站从 HTTP 换成 HTTPS
    HTTP协议是不加密传输数据的,也就是用户跟你的网站之间传递数据有可能在途中被截获,破解传递的真实内容,所以使用不加密的HTTP的网站是不 ... [详细]
  • 最详尽的4K技术科普
    什么是4K?4K是一个分辨率的范畴,即40962160的像素分辨率,一般用于专业设备居多,目前家庭用的设备,如 ... [详细]
  • 本文详细介绍了 InfluxDB、collectd 和 Grafana 的安装与配置流程。首先,按照启动顺序依次安装并配置 InfluxDB、collectd 和 Grafana。InfluxDB 作为时序数据库,用于存储时间序列数据;collectd 负责数据的采集与传输;Grafana 则用于数据的可视化展示。文中提供了 collectd 的官方文档链接,便于用户参考和进一步了解其配置选项。通过本指南,读者可以轻松搭建一个高效的数据监控系统。 ... [详细]
  • 利用REM实现移动端布局的高效适配技巧
    在移动设备上实现高效布局适配时,使用rem单位已成为一种流行且有效的技术。本文将分享过去一年中使用rem进行布局适配的经验和心得。rem作为一种相对单位,能够根据根元素的字体大小动态调整,从而确保不同屏幕尺寸下的布局一致性。通过合理设置根元素的字体大小,开发者可以轻松实现响应式设计,提高用户体验。此外,文章还将探讨一些常见的问题和解决方案,帮助开发者更好地掌握这一技术。 ... [详细]
  • 在Windows系统中安装TensorFlow GPU版的详细指南与常见问题解决
    在Windows系统中安装TensorFlow GPU版是许多深度学习初学者面临的挑战。本文详细介绍了安装过程中的每一个步骤,并针对常见的问题提供了有效的解决方案。通过本文的指导,读者可以顺利地完成安装并避免常见的陷阱。 ... [详细]
  • 能够感知你情绪状态的智能机器人即将问世 | 科技前沿观察
    本周科技前沿报道了多项重要进展,包括美国多所高校在机器人技术和自动驾驶领域的最新研究成果,以及硅谷大型企业在智能硬件和深度学习技术上的突破性进展。特别值得一提的是,一款能够感知用户情绪状态的智能机器人即将问世,为未来的人机交互带来了全新的可能性。 ... [详细]
  • 在运行于MS SQL Server 2005的.NET 2.0 Web应用中,我偶尔会遇到令人头疼的SQL死锁问题。过去,我们主要通过调整查询来解决这些问题,但这既耗时又不可靠。我希望能找到一种确定性的查询模式,确保从设计上彻底避免SQL死锁。 ... [详细]
  • 深入理解Java多线程与并发机制
    本文探讨了Java多线程和并发机制的核心概念,包括多线程类的分类、执行器框架、并发容器及控制工具。通过详细解析这些组件,帮助开发者更好地理解和应用多线程技术。 ... [详细]
  • 为什么多数程序员难以成为架构师?
    探讨80%的程序员为何难以晋升为架构师,涉及技术深度、经验积累和综合能力等方面。本文将详细解析Tomcat的配置和服务组件,帮助读者理解其内部机制。 ... [详细]
author-avatar
思南sn99
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有