热门标签 | HotTags
当前位置:  开发笔记 > 运维 > 正文

pytorch从csv加载自定义数据模板的操作

这篇文章主要介绍了pytorch从csv加载自定义数据模板的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

整理了一套模板,全注释了,这个难点终于克服了

from PIL import Image
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
#放文件的路径
dir_path= './97/train/'
csv_path='./97/train.csv'
class Mydataset(Dataset):
 #传递数据路径,csv路径 ,数据增强方法
 def __init__(self, dir_path,csv, transform=None, target_transform=None):
  super(Mydataset, self).__init__()
  #一个个往列表里面加绝对路径
  self.path = []
  #读取csv
  self.data = pd.read_csv(csv)
  #对标签进行硬编码,例如0 1 2 3 4,把字母变成这个
  colorMap = {elem: index + 1 for index, elem in enumerate(set(self.data["label"]))}
  self.data['label'] = self.data['label'].map(colorMap)
  #创造空的label准备存放标签
  self.num = int(self.data.shape[0]) # 一共多少照片
  self.label = np.zeros(self.num, dtype=np.int32)
  #迭代得到数据路径和标签一一对应
  for index, row in self.data.iterrows():
   self.path.append(os.path.join(dir_path,row['filename']))
   self.label[index] = row['label'] # 将数据全部读取出来
  #训练数据增强
  self.transform = transform
  #验证数据增强在这里没用
  self.target_transform = target_transform
 #最关键的部分,在这里使用前面的方法
 def __getitem__(self, index):
  img =Image.open(self.path[index]).convert('RGB')
  labels = self.label[index]
  #在这里做数据增强
  if self.transform is not None:
   img = self.transform(img) # 转化tensor类型
  return img, labels
 def __len__(self):
  return len(self.data)
#数据增强的具体内容
transform = transforms.Compose(
 [transforms.ToTensor(),
  transforms.Resize(150),
  transforms.CenterCrop(150),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
#加载数据
train_data = Mydataset(dir_path=dir_path,csv=csv_path, transform=transform)
trainloader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=0)
#迭代训练
for i_batch,batch_data in enumerate(trainloader):
 image,label=batch_data

补充:pytorch—定义自己的数据集及加载训练

笔记:pytorch Conv2d 的宽高公式理解,pytorch 使用自己的数据集并且加载训练

一、pypi 镜像使用帮助

pypi 镜像每 5 分钟同步一次。

临时使用

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple some-package

注意,simple 不能少, 是 https 而不是 http

设为默认

修改 ~/.config/pip/pip.conf (Linux), %APPDATA%\pip\pip.ini (Windows 10)$HOME/Library/Application Support/pip/pip.conf (macOS) (没有就创建一个), 修改 index-urltuna,例如

[global]
index-url = https://pypi.tuna.tsinghua.edu.cn/simple

pip 和 pip3 并存时,只需修改 ~/.pip/pip.conf。

二、pytorch Conv2d 的宽高公式理解

三、pytorch 使用自己的数据集并且加载训练

import os
import sys
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import time
import random
import csv
from PIL import Image
def createImgIndex(dataPath, ratio):
 '''
 读取目录下面的图片制作包含图片信息、图片label的train.txt和val.txt
 dataPath: 图片目录路径
 ratio: val占比
 return:label列表
 '''
 fileList = os.listdir(dataPath)
 random.shuffle(fileList)
 classList = [] # label列表
 # val 数据集制作
 with open('data/val_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList)*ratio)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectiOnName= fileInfo[0] + '_' + fileInfo[1] # 切面名+标准与否
    row.append(os.path.join(dataPath, fileList[i])) # 图片路径
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 # train 数据集制作
 with open('data/train_section1015.csv', 'w') as f:
  writer = csv.writer(f)
  for i in range(int(len(fileList) * ratio)+1, len(fileList)):
   row = []
   if '.jpg' in fileList[i]:
    fileInfo = fileList[i].split('_')
    sectiOnName= fileInfo[0] + '_' + fileInfo[1] # 切面名+标准与否
    row.append(os.path.join(dataPath, fileList[i])) # 图片路径
    if sectionName not in classList:
     classList.append(sectionName)
    row.append(classList.index(sectionName))
    writer.writerow(row)
  f.close()
 print(classList, len(classList))
 return classList
def default_loader(path):
 '''定义读取文件的格式'''
 return Image.open(path).resize((128, 128),Image.ANTIALIAS).convert('RGB')
class MyDataset(Dataset):
 '''Dataset类是读入数据集数据并且对读入的数据进行索引'''
 def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
  super(MyDataset, self).__init__() #对继承自父类的属性进行初始化
  fh = open(txt, 'r') #按照传入的路径和txt文本参数,以只读的方式打开这个文本
  reader = csv.reader(fh)
  imgs = []
  for row in reader:
   imgs.append((row[0], int(row[1]))) # (图片信息,lable)
  self.imgs = imgs
  self.transform = transform
  self.target_transform = target_transform
  self.loader = loader
 
 def __getitem__(self, index):
  '''用于按照索引读取每个元素的具体内容'''
  # fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中row[0]和row[1]的信息
  fn, label = self.imgs[index]
  img = self.loader(fn)
  if self.transform is not None:
   img = self.transform(img) #数据标签转换为Tensor
  return img, label
 
 def __len__(self):
  '''返回数据集的长度'''
  return len(self.imgs)
class Model(nn.Module):
 def __init__(self, classNum=31):
  super(Model, self).__init__()
  # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
  # torch.nn.MaxPool2d(kernel_size, stride, padding)
  # input 维度 [3, 128, 128]
  self.cnn = nn.Sequential(
   nn.Conv2d(3, 64, 3, 1, 1), # [64, 128, 128]
   nn.BatchNorm2d(64),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [64, 64, 64]
   nn.Conv2d(64, 128, 3, 1, 1), # [128, 64, 64]
   nn.BatchNorm2d(128),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [128, 32, 32]
   nn.Conv2d(128, 256, 3, 1, 1), # [256, 32, 32]
   nn.BatchNorm2d(256),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [256, 16, 16]
   nn.Conv2d(256, 512, 3, 1, 1), # [512, 16, 16]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 8, 8]
   nn.Conv2d(512, 512, 3, 1, 1), # [512, 8, 8]
   nn.BatchNorm2d(512),
   nn.ReLU(),
   nn.MaxPool2d(2, 2, 0), # [512, 4, 4]
  )
  self.fc = nn.Sequential(
   nn.Linear(512 * 4 * 4, 1024),
   nn.ReLU(),
   nn.Linear(1024, 512),
   nn.ReLU(),
   nn.Linear(512, classNum)
  )
 def forward(self, x):
  out = self.cnn(x)
  out = out.view(out.size()[0], -1)
  return self.fc(out)
def train(train_set, train_loader, val_set, val_loader):
 model = Model()
 loss = nn.CrossEntropyLoss() # 因为是分类任务,所以loss function使用 CrossEntropyLoss
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # optimizer 使用 Adam
 num_epoch = 10
 # 开始训练
 for epoch in range(num_epoch):
  epoch_start_time = time.time()
  train_acc = 0.0
  train_loss = 0.0
  val_acc = 0.0
  val_loss = 0.0
  model.train() # train model会开放Dropout和BN
  for i, data in enumerate(train_loader):
   optimizer.zero_grad() # 用 optimizer 將 model 參數的 gradient 歸零
   train_pred = model(data[0]) # 利用 model 的 forward 函数返回预测结果
   batch_loss = loss(train_pred, data[1]) # 计算 loss
   batch_loss.backward() # tensor(item, grad_fn=)
   optimizer.step() # 以 optimizer 用 gradient 更新参数
   train_acc += np.sum(np.argmax(train_pred.data.numpy(), axis=1) == data[1].numpy())
   train_loss += batch_loss.item()
  model.eval()
  with torch.no_grad(): # 不跟踪梯度
   for i, data in enumerate(val_loader):
    # data = [imgData, labelList]
    val_pred = model(data[0])
    batch_loss = loss(val_pred, data[1])
    val_acc += np.sum(np.argmax(val_pred.data.numpy(), axis=1) == data[1].numpy())
    val_loss += batch_loss.item()
   # 打印结果
   print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' % \
     (epoch + 1, num_epoch, time.time() - epoch_start_time, \
     train_acc / train_set.__len__(), train_loss / train_set.__len__(), val_acc / val_set.__len__(),
     val_loss / val_set.__len__()))
if __name__ == '__main__':
 dirPath = '/data/Matt/QC_images/test0916' # 图片文件目录
 createImgIndex(dirPath, 0.2)    # 创建train.txt, val.txt
 root = os.getcwd() + '/data/'
 train_data = MyDataset(txt=root+'train_section1015.csv', transform=transforms.ToTensor())
 val_data = MyDataset(txt=root+'val_section1015.csv', transform=transforms.ToTensor())
 train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True, num_workers = 4)
 val_loader = DataLoader(dataset=val_data, batch_size=6, shuffle=False, num_workers = 4)
 # 开始训练模型
 train(train_data, train_loader, val_data, val_loader)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持。如有错误或未考虑完全的地方,望不吝赐教。


推荐阅读
  • 2023年1月28日网络安全热点
    涵盖最新的网络安全动态,包括OpenSSH和WordPress的安全更新、VirtualBox提权漏洞、以及谷歌推出的新证书验证机制等内容。 ... [详细]
  • 本文详细介绍了如何在PHP中使用Memcached进行数据缓存,包括服务器连接、数据操作、高级功能等。 ... [详细]
  • 本文探讨了Linux环境下线程私有数据(Thread-Specific Data, TSD)的概念及其重要性,介绍了如何通过TSD技术避免多线程间全局变量冲突的问题,并提供了具体的实现方法和示例代码。 ... [详细]
  • 在使用 PyInstaller 将 Python 应用程序打包成独立的可执行文件时,若项目中包含动态加载的库或插件,需要正确配置 --hidden-import 和 --add-binary 参数,以确保所有依赖项均能被正确识别和打包。 ... [详细]
  • 分布式计算助力链力实现毫秒级安全响应,确保100%数据准确性
    随着分布式计算技术的发展,其在数据存储、文件传输、在线视频、社交平台及去中心化金融等多个领域的应用日益广泛。国际知名企业如Firefox、Google、Opera、Netflix、OpenBazaar等均已采用该技术,推动了技术创新和服务升级。 ... [详细]
  • 本文详细介绍了在PHP中如何获取和处理HTTP头部信息,包括通过cURL获取请求头信息、使用header函数发送响应头以及获取客户端HTTP头部的方法。同时,还探讨了PHP中$_SERVER变量的使用,以获取客户端和服务器的相关信息。 ... [详细]
  • 本文概述了在GNU/Linux系统中,动态库在链接和运行阶段的搜索路径及其指定方法,包括通过编译时参数、环境变量及系统配置文件等方式来控制动态库的查找路径。 ... [详细]
  • 本文概述了作者在2014年的几项目标与愿望,包括职业发展、个人成长及家庭幸福等方面的具体计划。 ... [详细]
  • 使用 ModelAttribute 实现页面数据自动填充
    本文介绍了如何利用 Spring MVC 中的 ModelAttribute 注解,在页面跳转后自动填充表单数据。主要探讨了两种实现方法及其背后的原理。 ... [详细]
  • Docker基础入门与环境配置指南
    本文介绍了Docker——一款用Go语言编写的开源应用程序容器引擎。通过Docker,用户能够将应用及其依赖打包进容器内,实现高效、轻量级的虚拟化。容器之间采用沙箱机制,确保彼此隔离且资源消耗低。 ... [详细]
  • Linux内核中的内存反碎片技术解析
    本文深入探讨了Linux内核中实现的内存反碎片技术,包括其历史发展、关键概念如虚拟可移动区域以及具体的内存碎片整理策略。旨在为开发者提供全面的技术理解。 ... [详细]
  • selenium通过JS语法操作页面元素
    做过web测试的小伙伴们都知道,web元素现在很多是JS写的,那么既然是JS写的,可以通过JS语言去操作页面,来帮助我们操作一些selenium不能覆盖的功能。问题来了我们能否通过 ... [详细]
  • 2019年独角兽企业招聘Python工程师标准课程概览
    本文详细介绍了2019年独角兽企业在招聘Python工程师时的标准课程内容,包括Shell脚本中的逻辑判断、文件属性判断、if语句的特殊用法及case语句的应用。 ... [详细]
  • 解决Linux中wget无法解析主机的问题
    本文介绍了如何通过修改/etc/resolv.conf文件来解决Linux系统中wget命令无法解析主机名的问题,通过添加Google的公共DNS服务器地址作为解决方案。 ... [详细]
  • 本文探讨了服务器系统架构的性能评估方法,包括性能评估的目的、步骤以及如何选择合适的度量标准。文章还介绍了几种常用的基准测试程序及其应用,并详细说明了Web服务器性能评估的关键指标与测试方法。 ... [详细]
author-avatar
手机用户2502881415
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有