Author: ReMadrism_FaithlU9D_1990 | Source: Internet | 2023-08-26 20:42
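Many of you who are just getting started with deep learning probably want to train a network on your own data but aren't sure how to write the code for it. The code below should clear that up. Adapt it to your own situation to train on your own image dataset.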
Below I use the ants and bees dataset as an example. My data is organized as shown in the screenshots below:
![](https://img5.php1.cn/3cdc5/92e2/3b4/fb0e400e01dc55f3.png)
![](https://img5.php1.cn/3cdc5/92e2/3b4/6e0b87bd3084e57a.png)
![](https://img5.php1.cn/3cdc5/92e2/3b4/0628ee15ed2f5b05.png)
Each class folder contains the images for that class.
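Judging from the dataset code, the layout is `datasets/<split>/<class_name>/<image file>`, and each class folder's position in the listing becomes its label. The stdlib-only sketch below isolates just that indexing step so you can check which label each folder gets before involving PyTorch. `index_image_folder` is a hypothetical helper name, and the sketch sorts the folder names, an assumption on my part so that `ants` always maps to 0 and `bees` to 1 regardless of the order `os.listdir()` happens to return.

```python
import os
import tempfile
from typing import Dict, List, Tuple


def index_image_folder(root: str) -> Tuple[Dict[str, int], List[Tuple[str, int]]]:
    """Hypothetical helper: walk root/<class_name>/<file> and return (class_to_idx, samples).

    Mirrors the indexing loop in Mydataset.__init__: every subfolder of `root`
    is one class, and every file inside it is one (path, label) sample.
    """
    # Sort class folders so labels do not depend on os.listdir() order
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return class_to_idx, samples


if __name__ == "__main__":
    # Build a throwaway datasets/train tree with empty placeholder files
    with tempfile.TemporaryDirectory() as tmp:
        train_dir = os.path.join(tmp, "datasets", "train")
        for cls, files in {"bees": ["b1.jpg"], "ants": ["a1.jpg", "a2.jpg"]}.items():
            os.makedirs(os.path.join(train_dir, cls))
            for f in files:
                open(os.path.join(train_dir, cls, f), "w").close()
        class_to_idx, samples = index_image_folder(train_dir)
        print(class_to_idx)  # {'ants': 0, 'bees': 1}
        print([(os.path.basename(p), y) for p, y in samples])
        # [('a1.jpg', 0), ('a2.jpg', 0), ('b1.jpg', 1)]
```

Files sitting directly in the split folder are skipped, since only subdirectories count as classes.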
```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Image preprocessing: resize, random horizontal flip, convert to tensor, normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize([500, 500]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])

# Used later at prediction time: map a predicted label straight back to its class name
class_list = {0: "ants", 1: "bees"}


class Mydataset(Dataset):
    def __init__(self, file_path="D:/PycharmProjects/pythonProject/classification-pytorch-main1/datasets",
                 formate="train", transform=None):
        self.transforms = transform  # a callable such as the Compose above, or None
        self.file_path = file_path
        self.formate = formate  # split subfolder, e.g. "train"
        self.file_train = os.path.join(self.file_path, self.formate)
        print(self.file_train)
        # Sort so label indices are stable across systems (ants -> 0, bees -> 1)
        files_class = sorted(os.listdir(self.file_train))
        self.imgs = []
        for i, j in enumerate(files_class):
            data = os.path.join(self.file_train, j)
            print(data)
            data_1 = os.listdir(data)
            # Each entry is [image path, class index]
            data_all = [[os.path.join(data, k), i] for k in data_1]
            self.imgs += data_all
        print(self.imgs)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path, label = self.imgs[index]  # pick the file path
        pil_img = Image.open(img_path).convert('RGB')  # open the file with PIL
        if self.transforms:
            img = self.transforms(pil_img)
        else:
            # No transform: return the raw HWC uint8 image as a tensor
            # (images of different sizes cannot be batched this way)
            img = torch.from_numpy(np.array(pil_img))
        return img, label


if __name__ == '__main__':
    train_data = Mydataset(transform=transform)
    print(train_data[0][0])
    print(train_data[0][1])
    # Check that batches can be fed into a model
    train_dataloder = DataLoader(train_data, batch_size=8, shuffle=True)
    for data in train_dataloder:
        print(data[0].shape)
        print(data[1])
        # Results vary between runs; one example:
        # torch.Size([8, 3, 500, 500])
        # tensor([1, 0, 0, 1, 1, 0, 1, 1])
        break
```
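The comment on `class_list` says it turns predicted labels back into class names. Here is a minimal, framework-free sketch of that step, assuming the model outputs one row of two scores per image (ants, bees). `decode_predictions` is a hypothetical helper; with a real PyTorch model you would typically get the per-row indices from `logits.argmax(dim=1)` and then look them up in `class_list` the same way.

```python
from typing import Dict, List, Sequence

class_list = {0: "ants", 1: "bees"}  # same mapping as in the dataset code


def decode_predictions(logits: Sequence[Sequence[float]], classes: Dict[int, str]) -> List[str]:
    """Hypothetical helper: take the highest-scoring index in each row and return its class name."""
    names = []
    for row in logits:
        # Per-row index of the largest score (what logits.argmax(dim=1) gives in PyTorch)
        best = max(range(len(row)), key=lambda i: row[i])
        names.append(classes[best])
    return names


print(decode_predictions([[2.1, -0.3], [0.2, 1.7]], class_list))  # ['ants', 'bees']
```

Keeping the mapping in a plain dict means it only has to agree with the label order the dataset assigns, which is why sorting the class folders matters.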
If you found this helpful, please give it a like. Good luck with your studies!