热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch——图像预处理模块(Transforms)

transforms运行机制torchvision是pytorch的计算机视觉工具包,在torchvision中有三个主要的模块:torchvisi

transforms运行机制

torchvision是pytorch的计算机视觉工具包,在torchvision中有三个主要的模块:


  • torchvision.transforms,常用的图像预处理方法,在transforms中提供了一系列的图像预处理方法,例如数据的标准化,中心化,旋转,翻转等等;
  • torchvision.datasets,定义了一系列常用的公开数据集的datasets,比如常用的MNIST,CIFAR-10,ImageNet等等;
  • torchvision.model,提供大量常用的预训练模型,例如AlexNet,VGG,ResNet,GoogLeNet等等;

transforms

torchvision.transforms:常用的图像预处理方法


  • 数据中心化
  • 数据标准化
  • 缩放
  • 裁剪
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换
  • 亮度、饱和度及对比度变换

深度学习是由数据驱动的,数据的数量以及分布对模型的优劣起到决定性作用,所以需要对数据进行一定的预处理以及数据增强,用来提升模型的泛化能力;

观察下面这个图,这是经过数据增强之后生成的一系列数据,一共有64张图片,这64张图片都来源于一张原始图片,经过一系列的缩放、裁剪、平移、变换等等操作的组合,生成了64张图片;对图片进行数据增强的原因是为了提高模型的泛化能力,类似于5年高考,3年模拟的卷子;5年高考的真题卷就类似于原始训练数据,3年模拟就相当于做一些数据增强,去丰富训练数据;假如在三年模拟的卷子中出现了当年的高考题,那么分数自然有所提高;同样的,如果我们做数据增强,生成了与测试样本很相似的图片,那么模型的泛化能力自然可以得到提高,这就是做数据增强的原因;
在这里插入图片描述
看一下代码,这里使用上一篇博客介绍的人民币二分类实验的代码的数据预处理部分,
数据标准化——transforms.normalize

# ============================ step 1/5 数据 ============================
# 这部分设置数据的路径
split_dir = os.path.join("C:/Users/10530/Desktop/pytorch/rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")#设置数据标准化的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]# transforms.Compose的功能是将一系列的transforms方法进行有序的组合包装,在具体实现的时候,会依次按顺序对图像进行操作
train_transform = transforms.Compose([transforms.Resize((32, 32)), #Resize,将图像缩放到32*32的大小transforms.RandomCrop(32, padding=4), #RandomCrop,对数据进行随机的裁剪transforms.ToTensor(), #ToTensor,将图片转成张量的形式同时会进行归一化操作,把像素值的区间从0-255归一化到0-1transforms.Normalize(norm_mean, norm_std), #标准化操作,将数据的均值变为0,标准差变为1
]) # Resize的功能是缩放,RandomCrop的功能是裁剪,ToTensor的功能是把图片变为张量#验证集的预处理的方法,对比训练集,少了RandomCrop这一部分,因为在验证集中是不需要对数据进行数据增强的
valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例,MyDataset必须是用户自己构建的
train_data = RMBDataset(data_dir=train_dir, transform=train_transform) # data_dir是数据的路径,transform是数据预处理
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform) # 一个用于训练,一个用于验证# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # shuffle=True,每一个epoch中样本都是乱序的
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

同样,在模型训练中设置断点,断点位置位于如下代码处:

for i, data in enumerate(train_loader):

进行debug,并点击step into进行操作,在跳转后的代码中进行一个是否采用多进程的判断:

def __iter__(self):if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)else:return _MultiProcessingDataLoaderIter(self)

选择单进程的运行机制,进入dataloader.py界面,找到def init(self)方法,点击Run to Cursor,程序就会运行到光标所在的行,具体如下***的代码:

def __next__(self):****index = self._next_index() # may raise StopIterationdata = self.dataset_fetcher.fetch(index) # may raise StopIterationif self.pin_memory:data = _utils.pin_memory.pin_memory(data)return data

这一步的作用是获取Index,也就是要读取哪些数据。得到Index就可以进入dataset_fetcher.fetch(index),根据索引去获取数据;进入到fetch函数:

class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]return self.collate_fn(data)

在fetch函数中,代码

data = [self.dataset[idx] for idx in possibly_batched_index]

调用了dataset,接着进入dataset所在的代码位置,如下所示:

def __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB') # 0~255if self.transform is not None:img = self.transform(img) # 在这里做transform,转为tensor等等return img, label

dataest代码位于类RMBDataset(Dataset)中的def getitem()函数,在getitem()中根据索引去获取图片的路径以及标签;然后采用代码

img = Image.open(path_img).convert('RGB') # 0~255

打开图片,读取进来的图片是一个PIL的数据类型,然后在getitem中调用transform()进行图像预处理操作,通过step_into进入transform()代码位置进行分析,代码位于transform中的def call()函数

def __call__(self, img):for t in self.transforms:img = t(img)return img

call()函数是一个for循环,也就是依次有序地从compose中去调用预处理方法,第一个预处理方法是t(img),其功能是是Resize缩放;第二个功能是裁剪,第三个功能是进行张量操作,第四个功能是进行归一化;对compose的四个功能循环结束之后,就会返回transform。

transform是在__getitem__()中调用,并且在__getitem__()中实现数据预处理,然后通过__getitem__返回一个样本;

def __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB') # 0~255if self.transform is not None:img = self.transform(img) # 在这里做transform,转为tensor等等

执行step out操作返回fetch()函数,接着就是不断地循环index获取一个batch_size大小的数据,最后在return的时候调用collate_fn()函数,将数据整理成一个batch_data的形式。

然后执行step out操作返回到dataloader.py中的__next__()函数中,然后跳出dataloader.py回到主代码当中,接着数据就读取进来了。这就是pytorch数据读取和transforms的运行机制。
在这里插入图片描述
回顾上面的数据读取流程图,transforms是在getitem中使用的;在getitem中读取一张图片,然后对这一张图片进行一系列预处理,然后返回图片以及标签。

了解了transforms的机制,现在学习一个比较常用的预处理方法,数据的标准化transforms.Normalize;

transforms.Normalize


  • 功能:逐channel的对图像进行标准化,即数据的均值变为0,标准差变为1
  • 标准化的计算公式为 output=(input−mean)/stdoutput = (input - mean) /stdoutput=(inputmean)/std
  • mean:各通道的均值
  • std:各通道的标准差
  • inplace:是否原位操作

transform.Normalize(mean,std,inplace=False)

回到代码中看一下normalize的具体实现方法,transform是在dataset的getitem中实现的,所以可以直接去dataset的getitem函数中设置断点,具体如下:

def __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB') # 0~255if self.transform is not None:***img = self.transform(img) # 在这里做transform,转为tensor等等return img, label

代码中***标注的地方就是断点的设置位置,进行debug操作,点击step into进入详细代码环境,进入了transforms.py中的call()函数中,在call函数中循环transforms。

def __call__(self, tensor):"""Args:tensor (Tensor): Tensor image of size (C, H, W) to be normalized.Returns:Tensor: Normalized Tensor image."""return F.normalize(tensor, self.mean, self.std, self.inplace)

接着进入transforms中查看normalize的实现,来到了normalize()类中的__call__()函数中,代码只有一行,实际上这行代码是调用了pytorch中的function中normalize方法;pytorch的function提供了很多常用的函数,使用step into查看normalize中的具体实现。

if not _is_tensor_image(tensor): #输入的合法性判断raise TypeError('tensor is not a torch image.')if not inplace: #判断是否需要原地操作tensor = tensor.clone()dtype = tensor.dtypemean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)std = torch.as_tensor(std, dtype=dtype, device=tensor.device)tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) #归一化公式return tensor

首先是输入的合法性判断,输入的是tensor,也就是原始的图像,接着判断是否要原地操作,如果不是inplace就需要将张量复制一份到新的内存空间中。下面的代码就是获取数据的均值和标准差,并将数据转换为张量。注意在sub_和div_后面有下划线,意思是进行原位操作,这样就完成了数据标准化的操作。

对数据进行标准化之后可以加快模型的收敛,具体可以看百面机器学习的第一章。


推荐阅读
author-avatar
曹莹888淑女
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有