热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【PyTorch】使用DataLoader自定义数据集读取

【PyTorch】使用DataLoader自定义数据集读取为了方便之后使用PyTorch的distributed部署,加速训练,将数据读取的方式改为适
【PyTorch】使用DataLoader自定义数据集读取

为了方便之后使用PyTorch的distributed部署,加速训练,将数据读取的方式改为适配pytorch提供的Dataset和DataLoader的方式。这里记录一下修改的要点:

1. 涉及的import库:

import torch
from torch.utils.data import Dataset, DataLoader

2. 自定义一个Dataset类:


  • 该类继承Dataset;

  • 可以定义若干个数据预处理的函数,关键的两个函数是:__len__()__getitem__();

  • __getitem__()实际是python支持的一个迭代器函数,编写时每次返回一个sample,不需要定义batch size,之后的DataLoader会自动帮忙读取数据组成batch的;

  • 举个栗子:

    class MyDataset(Dataset):def __init__(self,data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self):return self.datadef output(self):print('output')


3. 初始化Dataset和DataLoader类:


  • DataLoader的参数可参考:https://blog.csdn.net/zyq12345678/article/details/90268668

  • 注意,如果在Dataset中每次返回的是自己定义的数据类型,或者是字典类型,有时要自己编写collate_fn()函数,告诉系统如何返回一个batch。

  • 举个栗子:

    dataset = MyDataset(data)
    dataloader = DataLoader(dataset,batch_size = 2,num_workers = 8,collate_fn = collate_fn,pin_memory = True
    )
    # 返回数据结构较复杂,包括自定义数据类型或字典时
    def collate_fn(batch):data = list(batch)return (data)

  • 如果遇到类似报错:

    TypeError: can't pickle _thread._local objects

    请将DataLoader中的num_workers参数设置为0,关闭多线程。原因可能是无法自动多线程处理复杂的数据类型。


4. 访问Dataloader内的Dataset类函数


  • 举个栗子:

for step, batch in enumerate(dataloader):dataloader.dataset.output()


推荐阅读
author-avatar
只爱裙装
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有