热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch中自定义数据集的读取方法

显然我们在学习深度学习时,不能只局限于通过使用官方提供的MNSIT、CIFAR-10、CIFAR-100这样的数据集,很多时候我们还是需要根据自己遇到的实际问题自己去搜集数据,然后

显然我们在学习深度学习时,不能只局限于通过使用官方提供的MNSIT、CIFAR-10、CIFAR-100这样的数据集,很多时候我们还是需要根据自己遇到的实际问题自己去搜集数据,然后制作数据集(收集数据集的方法有很多,这里就不过多的展开了)。这里只介绍数据集的读取。




  1. 自定义数据集的方法

    首先创建一个Dataset类

    在这里插入图片描述

    在代码中:

    def init() 一些初始化的过程写在这个函数下

    def len() 返回所有数据的数量,比如我们这里将数据划分好之后,这里仅仅返回的是被处理后的关系

    def getitem() 回数据和标签



  2. 补充代码

    上述已经将框架打出来了,接下来就是将框架填充完整就行了,下面是完整的代码,代码的解释说明我也已经写在其中了



# -*- coding: utf-8 -*-
# @Author : 胡子旋
# @Email :1017190168@qq.com
import torch
import os,glob
import visdom
import time
import torchvision
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
class pokemom(Dataset):
def __init__(self,root,resize,mode,):
super(pokemom,self).__init__()
# 保存参数
self.root=root
self.resize=resize
# 给每一个类做映射
self.name2label={} # "squirtle":0 ,"pikachu":1……
for name in sorted(os.listdir(os.path.join(root))):
# 过滤掉文件夹
if not os.path.isdir(os.path.join(root,name)):
continue
# 保存在表中;将最长的映射作为最新的元素的label的值
self.name2label[name]=len(self.name2label.keys())
print(self.name2label)
# 加载文件
self.images,self.labels=self.load_csv('images.csv')
# 裁剪数据
if mode=='train':
self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合
self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合
elif mode=='val':
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else:
self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾
self.labels = self.labels[int(0.8 * len(self.labels)):]
# image+label 的路径
def load_csv(self,filename):
# 将所有的图片加载进来
# 如果不存在的话才进行创建
if not os.path.exists(os.path.join(self.root,filename)):
images=[]
for name in self.name2label.keys():
images+=glob.glob(os.path.join(self.root,name,'*.png'))
images+=glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
print(len(images),images)
# 1167 'pokeman\\bulbasaur\

# -*- coding: utf-8 -*-
# @Author : 胡子旋
# @Email :1017190168@qq.com
import torch
import os,glob
import visdom
import time
import torchvision
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
class pokemom(Dataset):
def __init__(self,root,resize,mode,):
super(pokemom,self).__init__()
# 保存参数
self.root=root
self.resize=resize
# 给每一个类做映射
self.name2label={} # "squirtle":0 ,"pikachu":1……
for name in sorted(os.listdir(os.path.join(root))):
# 过滤掉文件夹
if not os.path.isdir(os.path.join(root,name)):
continue
# 保存在表中;将最长的映射作为最新的元素的label的值
self.name2label[name]=len(self.name2label.keys())
print(self.name2label)
# 加载文件
self.images,self.labels=self.load_csv('images.csv')
# 裁剪数据
if mode=='train':
self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合
self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合
elif mode=='val':
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else:
self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾
self.labels = self.labels[int(0.8 * len(self.labels)):]
# image+label 的路径
def load_csv(self,filename):
# 将所有的图片加载进来
# 如果不存在的话才进行创建
if not os.path.exists(os.path.join(self.root,filename)):
images=[]
for name in self.name2label.keys():
images+=glob.glob(os.path.join(self.root,name,'*.png'))
images+=glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
print(len(images),images)
# 1167 'pokeman\\bulbasaur\\00000000.png'
# 将文件以上述的格式保存在csv文件内
random.shuffle(images)
with open(os.path.join(self.root,filename),mode='w',newline='') as f:
writer=csv.writer(f)
for img in images: # 'pokeman\\bulbasaur\\00000000.png'
name=img.split(os.sep)[-2]
label=self.name2label[name]
writer.writerow([img,label])
print("write into csv into :",filename)
# 如果存在的话就直接的跳到这个地方
images,labels=[],[]
with open(os.path.join(self.root, filename)) as f:
reader=csv.reader(f)
for row in reader:
# 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象
img,label=row
# 将label转码为int类型
label=int(label)
images.append(img)
labels.append(label)
# 保证images和labels的长度是一致的
assert len(images)==len(labels)
return images,labels
# 返回数据的数量
def __len__(self):
return len(self.images) # 返回的是被裁剪之后的关系
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
# 返回idx的数据和当前图片的label
def __getitem__(self,idx):
# idex-[0-总长度]
# retrun images,labels
# 将图片,label的路径取出来
# 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png'
# 然而label得到的则是 0,1,2 这样的整形的格式
img,label=self.images[idx],self.labels[idx]
tf=transforms.Compose([
lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据
# 进行数据加强
transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
# 随机旋转
transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度
# 中心裁剪
transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
img=tf(img)
label=torch.tensor(label)
return img,label
def main():
# 验证工作
viz=visdom.Visdom()
db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看
# 可视化样本
x,y=next(iter(db))
print('sample:',x.shape,y.shape,y)
viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
# 加载batch_size的数据
loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
for x,y in loader:
viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
# 每一次加载后,休息10s
time.sleep(10)
if __name__ == '__main__':
main()
000000.png'
# 将文件以上述的格式保存在csv文件内
random.shuffle(images)
with open(os.path.join(self.root,filename),mode='w',newline='') as f:
writer=csv.writer(f)
for img in images: # 'pokeman\\bulbasaur\

# -*- coding: utf-8 -*-
# @Author : 胡子旋
# @Email :1017190168@qq.com
import torch
import os,glob
import visdom
import time
import torchvision
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
class pokemom(Dataset):
def __init__(self,root,resize,mode,):
super(pokemom,self).__init__()
# 保存参数
self.root=root
self.resize=resize
# 给每一个类做映射
self.name2label={} # "squirtle":0 ,"pikachu":1……
for name in sorted(os.listdir(os.path.join(root))):
# 过滤掉文件夹
if not os.path.isdir(os.path.join(root,name)):
continue
# 保存在表中;将最长的映射作为最新的元素的label的值
self.name2label[name]=len(self.name2label.keys())
print(self.name2label)
# 加载文件
self.images,self.labels=self.load_csv('images.csv')
# 裁剪数据
if mode=='train':
self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合
self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合
elif mode=='val':
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else:
self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾
self.labels = self.labels[int(0.8 * len(self.labels)):]
# image+label 的路径
def load_csv(self,filename):
# 将所有的图片加载进来
# 如果不存在的话才进行创建
if not os.path.exists(os.path.join(self.root,filename)):
images=[]
for name in self.name2label.keys():
images+=glob.glob(os.path.join(self.root,name,'*.png'))
images+=glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
print(len(images),images)
# 1167 'pokeman\\bulbasaur\\00000000.png'
# 将文件以上述的格式保存在csv文件内
random.shuffle(images)
with open(os.path.join(self.root,filename),mode='w',newline='') as f:
writer=csv.writer(f)
for img in images: # 'pokeman\\bulbasaur\\00000000.png'
name=img.split(os.sep)[-2]
label=self.name2label[name]
writer.writerow([img,label])
print("write into csv into :",filename)
# 如果存在的话就直接的跳到这个地方
images,labels=[],[]
with open(os.path.join(self.root, filename)) as f:
reader=csv.reader(f)
for row in reader:
# 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象
img,label=row
# 将label转码为int类型
label=int(label)
images.append(img)
labels.append(label)
# 保证images和labels的长度是一致的
assert len(images)==len(labels)
return images,labels
# 返回数据的数量
def __len__(self):
return len(self.images) # 返回的是被裁剪之后的关系
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
# 返回idx的数据和当前图片的label
def __getitem__(self,idx):
# idex-[0-总长度]
# retrun images,labels
# 将图片,label的路径取出来
# 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png'
# 然而label得到的则是 0,1,2 这样的整形的格式
img,label=self.images[idx],self.labels[idx]
tf=transforms.Compose([
lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据
# 进行数据加强
transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
# 随机旋转
transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度
# 中心裁剪
transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
img=tf(img)
label=torch.tensor(label)
return img,label
def main():
# 验证工作
viz=visdom.Visdom()
db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看
# 可视化样本
x,y=next(iter(db))
print('sample:',x.shape,y.shape,y)
viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
# 加载batch_size的数据
loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
for x,y in loader:
viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
# 每一次加载后,休息10s
time.sleep(10)
if __name__ == '__main__':
main()
000000.png'
name=img.split(os.sep)[-2]
label=self.name2label[name]
writer.writerow([img,label])
print("write into csv into :",filename)
# 如果存在的话就直接的跳到这个地方
images,labels=[],[]
with open(os.path.join(self.root, filename)) as f:
reader=csv.reader(f)
for row in reader:
# 接下来就会得到 'pokeman\\bulbasaur\

# -*- coding: utf-8 -*-
# @Author : 胡子旋
# @Email :1017190168@qq.com
import torch
import os,glob
import visdom
import time
import torchvision
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
class pokemom(Dataset):
def __init__(self,root,resize,mode,):
super(pokemom,self).__init__()
# 保存参数
self.root=root
self.resize=resize
# 给每一个类做映射
self.name2label={} # "squirtle":0 ,"pikachu":1……
for name in sorted(os.listdir(os.path.join(root))):
# 过滤掉文件夹
if not os.path.isdir(os.path.join(root,name)):
continue
# 保存在表中;将最长的映射作为最新的元素的label的值
self.name2label[name]=len(self.name2label.keys())
print(self.name2label)
# 加载文件
self.images,self.labels=self.load_csv('images.csv')
# 裁剪数据
if mode=='train':
self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合
self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合
elif mode=='val':
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else:
self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾
self.labels = self.labels[int(0.8 * len(self.labels)):]
# image+label 的路径
def load_csv(self,filename):
# 将所有的图片加载进来
# 如果不存在的话才进行创建
if not os.path.exists(os.path.join(self.root,filename)):
images=[]
for name in self.name2label.keys():
images+=glob.glob(os.path.join(self.root,name,'*.png'))
images+=glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
print(len(images),images)
# 1167 'pokeman\\bulbasaur\\00000000.png'
# 将文件以上述的格式保存在csv文件内
random.shuffle(images)
with open(os.path.join(self.root,filename),mode='w',newline='') as f:
writer=csv.writer(f)
for img in images: # 'pokeman\\bulbasaur\\00000000.png'
name=img.split(os.sep)[-2]
label=self.name2label[name]
writer.writerow([img,label])
print("write into csv into :",filename)
# 如果存在的话就直接的跳到这个地方
images,labels=[],[]
with open(os.path.join(self.root, filename)) as f:
reader=csv.reader(f)
for row in reader:
# 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象
img,label=row
# 将label转码为int类型
label=int(label)
images.append(img)
labels.append(label)
# 保证images和labels的长度是一致的
assert len(images)==len(labels)
return images,labels
# 返回数据的数量
def __len__(self):
return len(self.images) # 返回的是被裁剪之后的关系
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
# 返回idx的数据和当前图片的label
def __getitem__(self,idx):
# idex-[0-总长度]
# retrun images,labels
# 将图片,label的路径取出来
# 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png'
# 然而label得到的则是 0,1,2 这样的整形的格式
img,label=self.images[idx],self.labels[idx]
tf=transforms.Compose([
lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据
# 进行数据加强
transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
# 随机旋转
transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度
# 中心裁剪
transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
img=tf(img)
label=torch.tensor(label)
return img,label
def main():
# 验证工作
viz=visdom.Visdom()
db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看
# 可视化样本
x,y=next(iter(db))
print('sample:',x.shape,y.shape,y)
viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
# 加载batch_size的数据
loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
for x,y in loader:
viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
# 每一次加载后,休息10s
time.sleep(10)
if __name__ == '__main__':
main()
000000.png' 0 的对象
img,label=row
# 将label转码为int类型
label=int(label)
images.append(img)
labels.append(label)
# 保证images和labels的长度是一致的
assert len(images)==len(labels)
return images,labels
# 返回数据的数量
def __len__(self):
return len(self.images) # 返回的是被裁剪之后的关系
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
# 返回idx的数据和当前图片的label
def __getitem__(self,idx):
# idex-[0-总长度]
# retrun images,labels
# 将图片,label的路径取出来
# 得到的img是这样的一个类型:'pokeman\\bulbasaur\

# -*- coding: utf-8 -*-
# @Author : 胡子旋
# @Email :1017190168@qq.com
import torch
import os,glob
import visdom
import time
import torchvision
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
class pokemom(Dataset):
def __init__(self,root,resize,mode,):
super(pokemom,self).__init__()
# 保存参数
self.root=root
self.resize=resize
# 给每一个类做映射
self.name2label={} # "squirtle":0 ,"pikachu":1……
for name in sorted(os.listdir(os.path.join(root))):
# 过滤掉文件夹
if not os.path.isdir(os.path.join(root,name)):
continue
# 保存在表中;将最长的映射作为最新的元素的label的值
self.name2label[name]=len(self.name2label.keys())
print(self.name2label)
# 加载文件
self.images,self.labels=self.load_csv('images.csv')
# 裁剪数据
if mode=='train':
self.images=self.images[:int(0.6*len(self.images))] # 将数据集的60%设置为训练数据集合
self.labels=self.labels[:int(0.6*len(self.labels))] # label的60%分配给训练数据集合
elif mode=='val':
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] # 从60%-80%的地方
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else:
self.images = self.images[int(0.8 * len(self.images)):] # 从80%的地方到最末尾
self.labels = self.labels[int(0.8 * len(self.labels)):]
# image+label 的路径
def load_csv(self,filename):
# 将所有的图片加载进来
# 如果不存在的话才进行创建
if not os.path.exists(os.path.join(self.root,filename)):
images=[]
for name in self.name2label.keys():
images+=glob.glob(os.path.join(self.root,name,'*.png'))
images+=glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
print(len(images),images)
# 1167 'pokeman\\bulbasaur\\00000000.png'
# 将文件以上述的格式保存在csv文件内
random.shuffle(images)
with open(os.path.join(self.root,filename),mode='w',newline='') as f:
writer=csv.writer(f)
for img in images: # 'pokeman\\bulbasaur\\00000000.png'
name=img.split(os.sep)[-2]
label=self.name2label[name]
writer.writerow([img,label])
print("write into csv into :",filename)
# 如果存在的话就直接的跳到这个地方
images,labels=[],[]
with open(os.path.join(self.root, filename)) as f:
reader=csv.reader(f)
for row in reader:
# 接下来就会得到 'pokeman\\bulbasaur\\00000000.png' 0 的对象
img,label=row
# 将label转码为int类型
label=int(label)
images.append(img)
labels.append(label)
# 保证images和labels的长度是一致的
assert len(images)==len(labels)
return images,labels
# 返回数据的数量
def __len__(self):
return len(self.images) # 返回的是被裁剪之后的关系
def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
# print(mean.shape, std.shape)
x = x_hat * std + mean
return x
# 返回idx的数据和当前图片的label
def __getitem__(self,idx):
# idex-[0-总长度]
# retrun images,labels
# 将图片,label的路径取出来
# 得到的img是这样的一个类型:'pokeman\\bulbasaur\\00000000.png'
# 然而label得到的则是 0,1,2 这样的整形的格式
img,label=self.images[idx],self.labels[idx]
tf=transforms.Compose([
lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据
# 进行数据加强
transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
# 随机旋转
transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度
# 中心裁剪
transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
img=tf(img)
label=torch.tensor(label)
return img,label
def main():
# 验证工作
viz=visdom.Visdom()
db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看
# 可视化样本
x,y=next(iter(db))
print('sample:',x.shape,y.shape,y)
viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
# 加载batch_size的数据
loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
for x,y in loader:
viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
# 每一次加载后,休息10s
time.sleep(10)
if __name__ == '__main__':
main()
000000.png'
# 然而label得到的则是 0,1,2 这样的整形的格式
img,label=self.images[idx],self.labels[idx]
tf=transforms.Compose([
lambda x:Image.open(x).convert('RGB'), # 将t图片的路径转换可以处理图片数据
# 进行数据加强
transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
# 随机旋转
transforms.RandomRotation(15), # 设置旋转的度数小一些,否则的话会增加网络的学习难度
# 中心裁剪
transforms.CenterCrop(self.resize), # 此时:既旋转了又不至于导致图片变得比较的复杂
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
img=tf(img)
label=torch.tensor(label)
return img,label

def main():
# 验证工作
viz=visdom.Visdom()
db=pokemom('pokeman',64,'train') # 这里可以改变大小 224->64,可以通过visdom进行查看
# 可视化样本
x,y=next(iter(db))
print('sample:',x.shape,y.shape,y)
viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
# 加载batch_size的数据
loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)
for x,y in loader:
viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
# 每一次加载后,休息10s
time.sleep(10)
if __name__ == '__main__':
main()


推荐阅读
  • 在本教程中,我们将看到如何使用FLASK制作第一个用于机器学习模型的RESTAPI。我们将从创建机器学习模型开始。然后,我们将看到使用Flask创建AP ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 动量|收益率_基于MT策略的实战分析
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了基于MT策略的实战分析相关的知识,希望对你有一定的参考价值。基于MT策略的实战分析 ... [详细]
  • 基于词向量计算文本相似度1.测试数据:链接:https:pan.baidu.coms1fXJjcujAmAwTfsuTg2CbWA提取码:f4vx2.实验代码:imp ... [详细]
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • 如何在mysqlshell命令中执行sql命令行本文介绍MySQL8.0shell子模块Util的两个导入特性importTableimport_table(JS和python版本 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 本文总结了使用不同方式生成 Dataframe 的方法,包括通过CSV文件、Excel文件、python dictionary、List of tuples和List of dictionary。同时介绍了一些注意事项,如使用绝对路径引入文件和安装xlrd包来读取Excel文件。 ... [详细]
  • python3 nmap函数简介及使用方法
    本文介绍了python3 nmap函数的简介及使用方法,python-nmap是一个使用nmap进行端口扫描的python库,它可以生成nmap扫描报告,并帮助系统管理员进行自动化扫描任务和生成报告。同时,它也支持nmap脚本输出。文章详细介绍了python-nmap的几个py文件的功能和用途,包括__init__.py、nmap.py和test.py。__init__.py主要导入基本信息,nmap.py用于调用nmap的功能进行扫描,test.py用于测试是否可以利用nmap的扫描功能。 ... [详细]
  • 一、死锁现象与递归锁进程也是有死锁的所谓死锁:是指两个或两个以上的进程或线程在执行过程中,因争夺资源而造成的一种互相等待的现象,若无外力作 ... [详细]
author-avatar
潜水的飞机537
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有