热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch制作图片数据集

PyTorch制作图片数据集图片的处理os.listdir()函数path.datatraining_settraining_setcatsos.listdir

PyTorch制作图片数据集


图片的处理


os.listdir()函数

>>> path = './data/training_set/training_set/cats'
>>> os.listdir(path)
['cat.1.jpg', 'cat.10.jpg', 'cat.100.jpg', 'cat.1000.jpg', 'cat.1001.jpg', 'cat.1002.jpg', 'cat.1003.jpg', 'cat.1004.jpg', 'cat.1005.jpg', 'cat.1006.jpg', 'cat.1007.jpg', 'cat.1008.jpg', 'cat.1009.jpg', 'cat.101.jpg', 'cat.1010.jpg', 'cat.1011.jpg', 'cat.1012.jpg', 'cat.1013.jpg', 'cat.1014.jpg', 'cat.1015.jpg', '
cat.1016.jpg', 'cat.1017.jpg', 'cat.1018.jpg', 'cat.1019.jpg', 'cat.102.jpg', 'cat.1020.jpg', 'cat.1021.jpg', 'cat.1022.jpg', 'cat.1023.jpg', 'cat.1024.jpg', 'cat.1025.jpg', 'cat.1026.jpg', 'cat.1027.jpg', 'cat.1028.jpg', 'cat.1029.jpg', 'cat.103.jpg', 'cat.1030.jpg', 'cat.1031.jpg', 'cat.1032.jpg', 'cat.1033.jpg'
, 'cat.1034.jpg', 'cat.1035.jpg', 'cat.1036.jpg', 'cat.1037.jpg', 'cat.1038.jpg', 'cat.1039.jpg', 'cat.104.jpg', 'cat.1040.jpg', 'cat.1041.jpg', 'cat.1042.jpg', 'cat.1043.jpg', 'cat.1044.jpg', 'cat.1045.jpg', 'cat.1046.jpg', 'cat.1047.jpg', 'cat.1048.jpg', 'cat.1049.jpg', 'cat.105.jpg', 'cat.1050.jpg', 'cat.1051.j
pg', 'cat.1052.jpg', 'cat.1053.jpg', 'cat.1054.jpg', 'cat.1055.jpg', 'cat.1056.jpg', 'cat.1057.jpg', 'cat.1058.jpg', 'cat.1059.jpg', 'cat.106.jpg', 'cat.1060.jpg', 'cat.1061.jpg', 'cat.1062.jpg', 'cat.1063.jpg', 'cat.1064.jpg', 'cat.1065.jpg', 'cat.1066.jpg', 'cat.1067.jpg', 'cat.1068.jpg', 'cat.1069.jpg', 'cat.10
7.jpg', 'cat.1070.jpg', 'cat.1071.jpg', 'cat.1072.jpg', 'cat.1073.jpg', 'cat.1074.jpg', 'cat.1075.jpg', 'cat.1076.jpg', 'cat.1077.jpg', 'cat.1078.jpg', 'cat.1079.jpg', 'cat.108.jpg', 'cat.1080.jpg', 'cat.1081.jpg', 'cat.1082.jpg', 'cat.1083.jpg', 'cat.1084.jpg', 'cat.1085.jpg', 'cat.1086.jpg', 'cat.1087.jpg', 'cat
.1088.jpg', 'cat.1089.jpg', 'cat.109.jpg', 'cat.1090.jpg', 'cat.1091.jpg', 'cat.1092.jpg', 'cat.1093.jpg', 'cat.1094.jpg', 'cat.1095.jpg', 'cat.1096.jpg', 'cat.1097.jpg', 'cat.1098.jpg', 'cat.1099.jpg', 'cat.11.jpg', 'cat.110.jpg', 'cat.1100.jpg', 'cat.1101.jpg', 'cat.1102.jpg', 'cat.1103.jpg', 'cat.1104.jpg', '...]

可以看出作用是返回path路径文件夹下所有文件名的列表


torchvision的transform模块

>>> import torch
>>> from PIL import Image
>>> from torchvision import transforms
>>> img = Image.open('./data/training_set/training_set/cats/cat.1.jpg')
>>> img.size
(300, 280)
>>> transforms.Resize(256)(img).size # 比例改变大小
(274, 256)
>>> transforms.Resize([256,256])(img).size # 结果为图-比例缩小
(256, 256)
>>> transforms.RandomResizedCrop(256)(img).size # 随机切为目标大小
(256, 256)
>>> transforms.RandomSizedCrop([256,200])(img).size # 随机切2
(200, 256)
>>> transforms.Pad(20,1)(img).size # 第一个参数为填充大小,第二个参数为填充值
(340, 320)
>>> transforms.CenterCrop(256)(img).size # 中心切
(256, 256)

原图:
在这里插入图片描述
比例缩小:
在这里插入图片描述

随机切:
在这里插入图片描述
随机切2:
在这里插入图片描述
中心切:
在这里插入图片描述

填充:
在这里插入图片描述

>>> transforms.ToTensor()(img).shape
torch.Size([3, 280, 300])

transforms.ToTensor()函数可以将PIL.Image对象转换为tensor对象
维度由[H, W, C]转为[C, H, W]
H:Height
W:Width
C:Channel
上面的函数一般在ToTensor函数之前进行,下面的函数对tensor进行操作,一般在ToTensor之后进行

>>> imgt = transforms.ToTensor()(img)
>>> imgt.shape
torch.Size([3, 280, 300])
>>> transforms.ToPILImage()(imgt).size # tensor转化为PILImage
(300, 280)
>>> m = imgt.mean(axis = [1,2]) # 求均值,axis=[1,2]可以看作关闭了后两个维度,即在后两个维度上求均值
>>> m
tensor([0.3089, 0.2677, 0.2665])
>>> s = imgt.std(axis = [1,2]) # 求标准差
>>> s
tensor([0.1619, 0.1375, 0.1380])
>>> imgtn = transforms.Normalize(m,s)(imgt) # 对tensor标准化
>>> imgtn.shape
torch.Size([3, 280, 300])
>>> imgtnt = torch.zeros(3,280,300)
>>> imgtnt[0] = (imgt[0]-m[0])/s[0]
>>> imgtnt[1] = (imgt[1]-m[1])/s[1]
>>> imgtnt[2] = (imgt[2]-m[2])/s[2]
>>> imgtnt == imgtn # 标准化过程就是对每个通道output = (input-mean)/std
tensor([[[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],...,[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True]],[[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],...,[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True]],[[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],...,[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True],[True, True, True, ..., True, True, True]]])

多种变化的组合transforms.Compose()

data_transforms = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(), # 以0.5的概率随机水平翻转transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

Dataset类的重写

Dataset类源码

class Dataset(object):"""An abstract class representing a Dataset.All other datasets should subclass it. All subclasses should override``__len__``, that provides the size of the dataset, and ``__getitem__``,supporting integer indexing in range from 0 to len(self) exclusive."""def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])

由上可以看出__getitem__和__len__方法必须重写
__len__方法用于使用 len(Dataset) 函数
__getitem__方法用于 Dataset[n] 操作
模板可以参考

from torch.utils.data import Dataset
import pandas as pd # pandas库提供了读取csv文件的函数read_csv()class myDataset(Dataset): # 定义自己的数据类myDataset,继承的抽象类Datasetdef __init__(self, csv_file, txt_file,root_dir,other_file): # csv_file:抽象的表示.csv文件;txt_file:抽象的表示txt文件;# root_dir:地址,这些参数放在初始化函数里self.csv_data= pd.read_csv(csv_file) # 读取csv文件,并且赋给他本身with open(txt_file,'r') as f: # 读取txt文件,并且赋给他本身,读取的方式为:with open(...) as f:data_list = f.readlines() # 读取每一行数据,并且放到data_list里self.txt_data = data_listself.root_dir = root_dir# 实现下面这个方法: def __len__(self): # 定义自己的数据类,必须重写这个方法(函数)return len(self.csv_data) # 返回的数据的长度def __getitem__(self, idx): # 定义自己的数据类,必须重写这个方法(函数)data = (self.csv_data[idx],self.txt_data[idx]) # 获取数据的方式,按照索引进行的 return data

数据的批量读取

torch.utils.data已经提供的类:Dataset,但是通过这种方式只能一个个的数据的把数据全部读出来,定义了数据读取的方式,不能实现** 批量**的把数据读取出来,为此pytorch有提供了一个方法:DataLoader(),它的参数如下:

from torch.utils.data import DataLoaderdataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=default_collate)

myDatase:上面自己定义的数据类
batch_size=32:实现批量读取数据,比如一次取32个数据
shuffle=True:将顺序打乱
collate_fn:表示的是如何读取样本,可以自己定义函数来准确的说明想要实现的功能。
该段来源于该博客

以下是kaggle平台上猫狗识别数据集的制作

import numpy as np
from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import osclass catdogDataset(Dataset):def __init__(self, file_list, dir, mode='train', transform=None):self.file_list = file_listself.dir = dirself.mode = modeself.transform = transformif 'dog' in self.file_list[0]:self.label = 1 # 标签为1时为狗,标签为0是为猫else:self.label = 0def __len__(self):return len(self.file_list)def __getitem__(self, idx):img = Image.open(os.path.join(self.dir, self.file_list[idx]))if self.transform:img = self.transform(img)if self.mode == 'train':img = np.array(img)return img.astype('float32'), self.label # 此处当模式为训练时,返回图片的32位浮点array和它的标签else:img = np.array(img)return img.astype('float32'), self.file_list[idx]datapathtra = './data/training_set/training_set'
cat_file_list = os.listdir(os.path.join(datapathtra, 'cats'))
dog_file_list = os.listdir(os.path.join(datapathtra, 'dogs'))data_transform = transforms.Compose([transforms.Resize([256, 256]),transforms.RandomHorizontalFlip(),transforms.ToTensor()
])catDataset = catdogDataset(cat_file_list, os.path.join(datapathtra, 'cats'), transform=data_transform)
dogDataset = catdogDataset(dog_file_list, os.path.join(datapathtra, 'dogs'), transform=data_transform)totalDataset = ConcatDataset([catDataset, dogDataset]) # 聚合两个数据集
print(totalDataset[5]) # 此处为__getitem__的用法
print(len(totalDataset)) # 此处为__len__的用法

推荐阅读
  • 很多时候在注册一些比较重要的帐号,或者使用一些比较重要的接口的时候,需要使用到随机字符串,为了方便,我们设计这个脚本需要注意 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • IjustinheritedsomewebpageswhichusesMooTools.IneverusedMooTools.NowIneedtoaddsomef ... [详细]
  • Python教学练习二Python1-12练习二一、判断季节用户输入月份,判断这个月是哪个季节?3,4,5月----春 ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文介绍了为什么要使用多进程处理TCP服务端,多进程的好处包括可靠性高和处理大量数据时速度快。然而,多进程不能共享进程空间,因此有一些变量不能共享。文章还提供了使用多进程实现TCP服务端的代码,并对代码进行了详细注释。 ... [详细]
  • 本文介绍了如何使用python从列表中删除所有的零,并将结果以列表形式输出,同时提供了示例格式。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • Day2列表、字典、集合操作详解
    本文详细介绍了列表、字典、集合的操作方法,包括定义列表、访问列表元素、字符串操作、字典操作、集合操作、文件操作、字符编码与转码等内容。内容详实,适合初学者参考。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • Python使用Pillow包生成验证码图片的方法
    本文介绍了使用Python中的Pillow包生成验证码图片的方法。通过随机生成数字和符号,并添加干扰象素,生成一幅验证码图片。需要配置好Python环境,并安装Pillow库。代码实现包括导入Pillow包和随机模块,定义随机生成字母、数字和字体颜色的函数。 ... [详细]
author-avatar
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有