【PyTorch】使用DataLoader自定义数据集读取
为了方便之后使用PyTorch的distributed部署,加速训练,将数据读取的方式改为适配pytorch提供的Dataset和DataLoader的方式。这里记录一下修改的要点:
1. 涉及的import库:
import torch
from torch.utils.data import Dataset, DataLoader
2. 自定义一个Dataset类:
-
该类继承Dataset;
-
可以定义若干个数据预处理的函数,关键的两个函数是:__len__()
和__getitem__()
;
-
__getitem__()
实际是python支持的一个迭代器函数,编写时每次返回一个sample,不需要定义batch size,之后的DataLoader会自动帮忙读取数据组成batch的;
-
举个栗子:
class MyDataset(Dataset):def __init__(self,data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self):return self.datadef output(self):print('output')
3. 初始化Dataset和DataLoader类:
-
DataLoader的参数可参考:https://blog.csdn.net/zyq12345678/article/details/90268668
-
注意,如果在Dataset中每次返回的是自己定义的数据类型,或者是字典类型,有时要自己编写collate_fn()
函数,告诉系统如何返回一个batch。
-
举个栗子:
dataset = MyDataset(data)
dataloader = DataLoader(dataset,batch_size = 2,num_workers = 8,collate_fn = collate_fn,pin_memory = True
)
def collate_fn(batch):data = list(batch)return (data)
-
如果遇到类似报错:
TypeError: can't pickle _thread._local objects
请将DataLoader中的num_workers
参数设置为0,关闭多线程。原因可能是无法自动多线程处理复杂的数据类型。
4. 访问Dataloader内的Dataset类函数
for step, batch in enumerate(dataloader):dataloader.dataset.output()