import torch
from torch.utils.data import Dataset, TensorDataset

# NOTE(review): `transforms` is used below but is not imported in this chunk;
# it is presumably `torchvision.transforms`, imported elsewhere — confirm.

batch_size = 128

# Training-time augmentation: random horizontal flip plus a random rotation
# drawn from [-15, +15] degrees, then conversion back to a tensor.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])

# The test set does not need flipped or rotated images; only convert them.
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])


class ImgDataset(Dataset):
    """Map-style dataset over in-memory samples with optional labels.

    Args:
        x: Sized, indexable collection of samples. ``x[index]`` is passed to
            ``transform`` when one is given. Because the transforms above
            start with ``ToPILImage``, each sample is presumably an HxWxC
            uint8 image array — TODO confirm against the caller.
        y: Optional labels, one per sample. They are stored as a
            ``torch.long`` tensor. When ``None``, ``__getitem__`` returns the
            sample alone (e.g. for an unlabeled test set).
        transform: Optional callable applied to each sample on access.
    """

    def __init__(self, x, y=None, transform=None):
        self.x = x
        self.y = y
        # Labels are required to be integer (int64 / "LongTensor") tensors,
        # presumably because they are used as class-index targets, e.g. for
        # nn.CrossEntropyLoss — confirm. `as_tensor` accepts lists, numpy
        # arrays of any integer dtype, and existing tensors, unlike the legacy
        # `torch.LongTensor(...)` constructor.
        if y is not None:
            self.y = torch.as_tensor(y, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        """Return the number of samples in ``x``."""
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(sample, label)`` if labels were given, else ``sample``.

        The sample is passed through ``transform`` when one is set.
        """
        X = self.x[index]
        if self.transform is not None:
            X = self.transform(X)
        if self.y is not None:
            Y = self.y[index]
            return X, Y
        return X