发现 DataLoader 在不同的 pytorch 版本上,执行 dataset 的__item__ 会返回不同的效果。
pytorch 在1.12.1 上,每一次迭代会返回BatchEncoding这个类型(可能会比这个版本低也返回)
class MyDataSet(Dataset):
#1.12.1 版本
def __init__(self, path):
self.text_list = []
self.label_list = []
self.tokenizer = BertTokenizer.from_pretrained(bert_path)
pass
def __getitem__(self, item):
return self.text_list[item], self.label_list[item]
def __len__(self):
return len(self.text_list)
而在老旧的1.8 版本中,会返回为一个dict对象。这是我在不同平台上我唯一能发现的不同的地方--pytorch的版本号。
这个BatchEncoding 是 HuggingFace 的东西,是一个dict的子类,但是它可以直接被送入cuda加速,而dict不能。那么为了兼容老旧的平台,把它当做一个dict来用,取出它的值再送入cuda中。做开发嘛,遇到不兼容的问题会很常见。