热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch学习教程之自定义数据集

这篇文章主要给大家介绍了关于pytorch学习教程之自定义数据集的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着

自定义数据集

在训练深度学习模型之前,样本集的制作非常重要。在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,下面完整的试验自定义样本集的整个流程。

开发环境

  • Ubuntu 18.04
  • pytorch 1.0
  • pycharm

实验目的

  1. 掌握pytorch中数据集相关的API接口和类
  2. 熟悉数据集制作的整个流程

实验过程

1.收集图像样本

以简单的猫狗二分类为例,可以在网上下载一些猫狗图片。创建以下目录:

  • data-------------根目录
  • data/test-------测试集
  • data/train------训练集
  • data/val--------验证集

在test/train/val之下在校分别创建2个文件夹,dog, cat

cat, dog文件夹下分别存放2类图像:

标签

种类 标签
cat 0
dog 1

之后写一个简单的python脚本,生成txt文件,用于指明每个图像和标签的对应关系。

格式: /cat/1.jpg 0 dog/1.jpg 1 .....

如图:

至此,样本集的收集以及简单归类完成,下面将开始采用pytorch的数据集相关API和类。

2. 使用pytorch相关类,API对数据集进行封装

2.1 pytorch中数据集相关的类,接口

pytorch中数据集相关的类位于torch.utils.data package中。

https://pytorch.org/docs/stable/data.html

本次实验,主要使用以下类:

torch.utils.data.Dataset
torch.utils.data.DataLoader

Dataset类的使用: 所有的类都应该是此类的子类(也就是说应该继承该类)。 所有的子类都要重写(override) __len()__, __getitem()__ 这两个方法。

方法 作用
__len()__ 此方法应该提供数据集的大小(容量)
__getitem()__ 此方法应该提供支持下标索方式引访问数据集

这里和Java抽象类很相似,在抽象类abstract class中,一般会定义一些抽象方法abstract method,抽象方法:只有方法名没有方法的具体实现。如果一个子类继承于该抽象类,要重写(overrode)父类的抽象方法。

DataLoader类的使用:

2.2 实现

使用到的python package

python package 目的
numpy 矩阵操作,对图像进行转置
skimage 图像处理,图像I/O,图像变换
matplotlib 图像的显示,可视化
os 一些文件查找操作
torch pytorch
torvision pytorch

源码

导入python包

import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid

第一步:

定义一个子类,继承Dataset类, 重写 __len()__, __getitem()__ 方法。

细节:

1.数据集中一个一样的表示:采用字典的形式sample = {"image": image, "label": label}。

2.图像的读取:采用skimage.io进行读取,读取之后的结果为numpy.ndarray形式。

3.图像变换:transform参数

# step1: 定义MyDataset类, 继承Dataset, 重写抽象方法:__len()__, __getitem()__
class MyDataset(Dataset):

 def __init__(self, root_dir, names_file, transform=None):
 self.root_dir = root_dir
 self.names_file = names_file
 self.transform = transform
 self.size = 0
 self.names_list = []

 if not os.path.isfile(self.names_file):
  print(self.names_file + "does not exist!")
 file = open(self.names_file)
 for f in file:
  self.names_list.append(f)
  self.size += 1

 def __len__(self):
 return self.size

 def __getitem__(self, idx):
 image_path = self.root_dir + self.names_list[idx].split(" ")[0]
 if not os.path.isfile(image_path):
  print(image_path + "does not exist!")
  return None
 image = io.imread(image_path) # use skitimage
 label = int(self.names_list[idx].split(" ")[1])

 sample = {"image": image, "label": label}
 if self.transform:
  sample = self.transform(sample)

 return sample

第二步

实例化一个对象,并读取和显示数据集

train_dataset = MyDataset(root_dir="./data/train",
    names_file="./data/train/train.txt",
    transform=None)

plt.figure()
for (cnt,i) in enumerate(train_dataset):
 image = i["image"]
 label = i["label"]

 ax = plt.subplot(4, 4, cnt+1)
 ax.axis("off")
 ax.imshow(image)
 ax.set_title("label {}".format(label))
 plt.pause(0.001)

 if cnt == 15:
 break

只显示了部分数据,前部分全是cat

第三步(可选 optional)

对数据集进行变换:一般收集到的图像大小尺寸,亮度等存在差异,变换的目的就是使得数据归一化。另一方面,可以通过变换进行数据增加data argument

关于pytorch中的变换transforms,请参考该系列之前的文章

由于数据集中样本采用字典dicts形式表示。 因此不能直接调用torchvision.transofrms中的方法。

本实验只进行尺寸归一化Resize, 数据类型变换ToTensor操作。

Resize

# # 变换Resize
class Resize(object):

 def __init__(self, output_size: tuple):
 self.output_size = output_size

 def __call__(self, sample):
 # 图像
 image = sample["image"]
 # 使用skitimage.transform对图像进行缩放
 image_new = transform.resize(image, self.output_size)
 return {"image": image_new, "label": sample["label"]}

ToTensor

# # 变换ToTensor
class ToTensor(object):

 def __call__(self, sample):
 image = sample["image"]
 image_new = np.transpose(image, (2, 0, 1))
 return {"image": torch.from_numpy(image_new),
  "label": sample["label"]}

第四步: 对整个数据集应用变换

细节: transformers.Compose() 将不同的几个组合起来。先进行Resize, 再进行ToTensor

# 对原始的训练数据集进行变换
transformed_trainset = MyDataset(root_dir="./data/train",
    names_file="./data/train/train.txt",
    transform=transforms.Compose(
    [Resize((224,224)),
    ToTensor()]
    ))

第五步: 使用DataLoader进行包装

为何要使用DataLoader?

① 深度学习的输入是mini_batch形式

② 样本加载时候可能需要随机打乱顺序,shuffle操作

③ 样本加载需要采用多线程

pytorch提供的DataLoader封装了上述的功能,这样使用起来更方便。

# 使用DataLoader可以利用多线程,batch,shuffle等
trainset_dataloader = DataLoader(dataset=transformed_trainset,
     batch_size=4,
     shuffle=True,
     num_workers=4)

可视化:

def show_images_batch(sample_batched):
 images_batch, labels_batch = 
 sample_batched["image"], sample_batched["label"]
 grid = make_grid(images_batch)
 plt.imshow(grid.numpy().transpose(1, 2, 0))


# sample_batch: Tensor , NxCxHxW
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
 show_images_batch(sample_batch)
 plt.axis("off")
 plt.ioff()
 plt.show()


plt.show()

通过DataLoader包装之后,样本以min_batch形式输出,而且进行了随机打乱顺序。

至此,自定义数据集的完整流程已实现,test, val集只需要改路径即可。

补充

更简单的方法

上述继承Dataset, 重写 __len()__, __getitem() 是通用的方法,过程相对繁琐。对于简单的分类数据集,pytorch中提供了更简便的方式――ImageFolder。

如果每种类别的样本放在各自的文件夹中,则可以直接使用ImageFolder。

仍然以cat, dog 二分类数据集为例:

文件结构:



Code

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np


# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

# data_transform = transforms.Compose([
#  transforms.RandomResizedCrop(224),
#  transforms.RandomHorizontalFlip(),
#  transforms.ToTensor(),
#  transforms.Normalize(mean=[0.485, 0.456, 0.406],
#       std=[0.229, 0.224, 0.225])
# ])

data_transform = transforms.Compose([
 transforms.Resize((224,224)),
 transforms.RandomHorizontalFlip(),
 transforms.ToTensor(),

])

train_dataset = datasets.ImageFolder(root="./data/train",transform=data_transform)
train_dataloader = DataLoader(dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4)


def show_batch_images(sample_batch):
 labels_batch = sample_batch[1]
 images_batch = sample_batch[0]

 for i in range(4):
  label_ = labels_batch[i].item()
  image_ = np.transpose(images_batch[i], (1, 2, 0))
  ax = plt.subplot(1, 4, i + 1)
  ax.imshow(image_)
  ax.set_title(str(label_))
  ax.axis("off")
  plt.pause(0.01)


plt.figure()
for i_batch, sample_batch in enumerate(train_dataloader):
 show_batch_images(sample_batch)

 plt.show()

由于 train 目录下只有2个文件夹,分别为cat, dog, 因此ImageFolder安装顺序对cat使用标签0, dog使用标签1。

End

参考:

https://pytorch.org/docs/stable/data.html

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

到此这篇关于pytorch学习教程之自定义数据集的文章就介绍到这了,更多相关pytorch自定义数据集内容请搜索编程笔记以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程笔记!

原文链接:https://www.jianshu.com/p/2d9927a70594


推荐阅读
  • PyCharm 安装与首个 Python 程序实践
    本文将指导您如何安装 PyCharm,并通过创建一个简单的 'Hello, World' 程序来初步体验这一强大的 Python 集成开发环境。 ... [详细]
  • Android 中的布局方式之线性布局
    nsitionalENhttp:www.w3.orgTRxhtml1DTDxhtml1-transitional.dtd ... [详细]
  • Django与Python及其他Web框架的对比
    本文详细介绍了Django与其他Python Web框架(如Flask和Tornado)的区别,并探讨了Django的基本使用方法及与其他语言(如PHP)的比较。 ... [详细]
  • 第二十五天接口、多态
    1.java是面向对象的语言。设计模式:接口接口类是从java里衍生出来的,不是python原生支持的主要用于继承里多继承抽象类是python原生支持的主要用于继承里的单继承但是接 ... [详细]
  • 对于初学者而言,搭建一个高效稳定的 Python 开发环境是入门的关键一步。本文将详细介绍如何利用 Anaconda 和 Jupyter Notebook 来构建一个既易于管理又功能强大的开发环境。 ... [详细]
  • 如何在Django框架中实现对象关系映射(ORM)
    本文介绍了Django框架中对象关系映射(ORM)的实现方式,通过ORM,开发者可以通过定义模型类来间接操作数据库表,从而简化数据库操作流程,提高开发效率。 ... [详细]
  • JUnit下的测试和suite
    nsitionalENhttp:www.w3.orgTRxhtml1DTDxhtml1-transitional.dtd ... [详细]
  • Requests库的基本使用方法
    本文介绍了Python中Requests库的基础用法,包括如何安装、GET和POST请求的实现、如何处理Cookies和Headers,以及如何解析JSON响应。相比urllib库,Requests库提供了更为简洁高效的接口来处理HTTP请求。 ... [详细]
  • OBS Studio自动化实践:利用脚本批量生成录制场景
    本文探讨了如何利用OBS Studio进行高效录屏,并通过脚本实现场景的自动生成。适合对自动化办公感兴趣的读者。 ... [详细]
  • Web动态服务器Python基本实现
    Web动态服务器Python基本实现 ... [详细]
  • Jenkins API当前未直接提供获取任务构建队列长度的功能,因此需要通过解析HTML页面来间接实现这一需求。 ... [详细]
  • Bootstrap Paginator 分页插件详解与应用
    本文深入探讨了Bootstrap Paginator这款流行的JavaScript分页插件,提供了详细的使用指南和示例代码,旨在帮助开发者更好地理解和利用该工具进行高效的数据展示。 ... [详细]
  • HTML前端开发:UINavigationController与页面间数据传递详解
    本文详细介绍了如何在HTML前端开发中利用UINavigationController进行页面管理和数据传递,适合初学者和有一定基础的开发者学习。 ... [详细]
  • Python 3 Scrapy 框架执行流程详解
    本文详细介绍了如何在 Python 3 环境下安装和使用 Scrapy 框架,包括常用命令和执行流程。Scrapy 是一个强大的 Web 抓取框架,适用于数据挖掘、监控和自动化测试等多种场景。 ... [详细]
  • 本项目通过Python编程实现了一个简单的汇率转换器v1.02。主要内容包括:1. Python的基本语法元素:(1)缩进:用于表示代码的层次结构,是Python中定义程序框架的唯一方式;(2)注释:提供开发者说明信息,不参与实际运行,通常每个代码块添加一个注释;(3)常量和变量:用于存储和操作数据,是程序执行过程中的重要组成部分。此外,项目还涉及了函数定义、用户输入处理和异常捕获等高级特性,以确保程序的健壮性和易用性。 ... [详细]
author-avatar
手机用户2602930391
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有