热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch练手项目三:模型微调

本文目的:基于kaggle上狗的种类识别项目,展示如何利用PyTorch来进行模型微调。PyTorch中torchvision是一个针对视觉领域的工具库,除了提供有大量的数据集,还

本文目的:基于kaggle上狗的种类识别项目,展示如何利用PyTorch来进行模型微调。

PyTorch中torchvision是一个针对视觉领域的工具库,除了提供有大量的数据集,还有许多预训练的经典模型。这里以官方训练好的resnet50为例,拿来参加kaggle上面的dog breed狗的种类识别。

1 导入相关库,设置一些超参

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit

print(torch.__version__)  #1.1.0
print(torchvision.__version__) #0.3.0


#定义一些超参
IMG_SIZE = 224 #模型要求的输入尺寸
IMG_MEAN = [0.485, 0.456, 0.406] #图像预处理中需要的均值和方差
IMG_STD = [0.229, 0.224, 0.225]
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #尽量使用GPU
BATCH_SIZE = 64  #每一个batch的大小
EPOCHS = 7  #训练轮数

2 准备数据

Pytorch中数据的读取通常需要封装成Dataset类对象和DataLoader类对象。

2.1 获取数据并整理

首先下载官方的数据并解压,只要保持数据的目录结构即可,这里指定一下目录的位置,并且看下内容。(注意:labels.csv文件中有10222条标签,对应的是train文件夹中图像。)

#DATA_ROOT = r'D:\KaggleDatasets\competitions\dog-breed-identification'
#注1:常用'/'表相对路径,'\'表绝对路径,网页网址和linux系统下一般用'/'
DATA_ROOT = '/KaggleDatasets/competitions/dog-breed-identification'
df = pd.read_csv(os.path.join(DATA_ROOT, 'labels.csv'))
df.head()

PyTorch练手项目三:模型微调

为了后续方便,这里定义两个字典,并将类别序号添加进DataFrame中。

#分别以标签字符串和序号为索引,定义两个字典
breeds = df.breed.unique()
breed2idx = dict((breed,idx) for idx,breed in enumerate(breeds)) 
idx2breed = dict((idx,breed) for idx,breed in enumerate(breeds))
len(breeds) #120

#将类别序号添加到df的列 
df['label_idx'] = pd.Series(breed2idx, index=df.breed).values  
#df.shape  #(10222, 3)
df.head()

PyTorch练手项目三:模型微调

将数据分割成训练集和验证集。这里只分割10%的数据作为训练时的验证数据。

#分割数据集
shuffle_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=0) #分层切割
train_idx, val_idx = next(iter(shuffle_split.split(df, df.breed))) #split方法返回迭代器
train_df = df.iloc[train_idx].reset_index(drop=True) #(9199, 3)
val_df = df.iloc[val_idx].reset_index(drop=True)  #(1023, 3)

注2:StratifiedShuffleSplit().split(X, y)

  • 含义:This cross-validation object is a merge of StratifiedKFold and ShuffleSplit
  • split(X, y)方法中,X实际上只用了 np.zeros(n_samples) 来产生占位符,而切分主要靠y
  • KFold(shuffle=True)与ShuffleSplit的区别是:前者只shuffle一次,后者每次切分之前都要shuffle

注3:sklearn中几种数据切分方法

  • train_test_split:普通切分
  • KFold:普通K折切分
  • StratifiedKFold:分层K折切分
  • StratifiedShuffleSplit:每次shuffle后分层切分

2.2 自定义Dataset

torch.utils.data.Dataset是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

  • __ len __ () :返回整个数据集的长度。
  • __ getitem __ () :每次怎么读取数据。

另外,transform过程也在此处传进来。

#自定义Dataset
class DogDataset(Dataset):
    def __init__(self, df, img_path, transform=None):
        self.df = df
        self.img_path = img_path
        self.transform = transform
    
    def __len__(self):
        return self.df.shape[0]  #返回数据集长度
    
    def __getitem__(self, idx):  #每次根据idx返回一个(image,label)数据对
        img_name = os.path.join(self.img_path, self.df.id[idx]) + '.jpg'
        img = Image.open(img_name)  #建议用PIL,而非skimage
        label = self.df.label_idx[idx]
        
        if self.transform:
            img = self.transform(img)
        return img, label
    

#自定义训练集和验证集的transform
train_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize(IMG_MEAN, IMG_STD),
])

test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE), #注4:传入一个int时,短边缩放到IMG_SIZE,长边按比例缩放
    transforms.CenterCrop(IMG_SIZE),  
    transforms.ToTensor(),
    transforms.Normalize(IMG_MEAN, IMG_STD),
])


#生成dataset
train_dataset = DogDataset(train_df, os.path.join(DATA_ROOT,'train'), train_transform)
val_dataset = DogDataset(val_df, os.path.join(DATA_ROOT,'train'), test_transform)

2.3 定义DataLoader

类定义为:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, ...)

可以看到主要参数有这么几个:

  • dataset:即上面自定义的dataset;
  • batch_size:一个batch中样本个数;
  • shuffle:划分batch前是否打乱顺序;
  • sampler:定义抽样的策略;
  • batch_sampler:定义批次抽样的策略;
  • num_worker:定义多线程方法,默认为0。
#生成dataloader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

3 准备模型

使用Pytorch中torchvision.models.resnet50。由于ImageNet是识别1000个物体,这里狗的分类一共只有120,所以需要对模型的最后一层全连接层进行微调,将输出从1000改为120。

#准备模型
model = models.resnet50(pretrained=True) #可用dir(model)查看属性及方法

#将所有参数冻结
for param in model.parameters(): 
    param.requires_grad = False
print(model.fc)

#修改fc层。可用model.named_parameters()迭代查看具体名称和参数
num_feature = model.fc.in_features  #获取fc层的输入个数
model.fc = nn.Linear(num_feature, len(breeds))  #重新定义fc层
print(model.fc)
#print(model)

#将model移至GPU
model.to(DEVICE)

PyTorch练手项目三:模型微调

注5:关于预训练模型的使用,需要

  • 传入pretrained=True,可加载预训练权重;
  • 模型使用前需要调用model.train(),或者model.eval()来开启或关闭BN和Dropout等;
  • 传给预训练模型的图像应符合:(可见2.2中定义的transform)
    • 3通道RGB格式;
    • shape为(3,H,W),其中H和W至少为224,若不够则需要Resize;
    • 以[0,1]范围加载后用mean=[0.485,0.456,0.406]和std=[0.229, 0.224, 0.225]来Normalize

4 训练

4.1 定义训练参数和函数

训练需要定义损失函数和优化器。另外也打包定义了训练和验证函数。

#指定损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  #注6:默认的reduction为mean,即求平均损失
#optimizer = torch.optim.Adam([{'params':model.fc.parameters()}], lr=0.001) #定义fc层学习率
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)


#定义训练函数
#注7:训练5部曲:梯度清零,前向传播,计算损失,反向传播,梯度更新。
def train(model, train_loader, device, epoch):
    model.train()  #注8:开启训练模型,即开启BN和Dropout等
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device) #注9:模型和数据均要移至GPU
        #data和target的size分别为torch.Size([64, 3, 224, 224])、torch.Size([64])
        optimizer.zero_grad() #梯度清零
        yhat = model(data) #前向传播 torch.Size([64, 120])
        loss = loss_fn(yhat, target) #计算损失
        loss.backward() #反向传播
        optimizer.step() #更新梯度
    print('Train epoch {}\t Loss {:.6f}'.format(epoch, loss.item()))
    
    
#定义测试函数
def test(model, val_loader, device):
    model.eval()
    test_loss = 0  #记录测试损失
    correct = 0  #记录预测正确个数
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_loader):
            data, target = data.to(device), target.to(device)
            yhat = model(data)
            test_loss += loss_fn(yhat, target).item() #每次加上一个batch的平均损失值
            pred = torch.max(yhat, dim=1, keepdim=True)[1]  #注10:找到概率最大的下标
            correct += pred.eq(target.view_as(pred)).sum().item() #累加正确的样本个数
            
    test_loss /= len(val_loader) #注意此处是除以batch个数,而非len(val_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))

4.2 开始训练

#开始训练
for epoch in range(1, EPOCHS+1):
    %time train(model, train_loader, DEVICE, epoch)
    test(model, val_loader, DEVICE)

从结果可以看出,运行几轮之后准确率大约在80%左右,比随机猜测(0.83%)要好很多。

Train epoch 1	 Loss 1.935438
Wall time: 3min 26s

Test set: Average loss: 1.2672, Accuracy: 723/1023 (70.7%)

Train epoch 2	 Loss 1.673698
Wall time: 1min 41s

Test set: Average loss: 0.8607, Accuracy: 782/1023 (76.4%)

Train epoch 3	 Loss 1.657430
Wall time: 1min 41s

Test set: Average loss: 0.7643, Accuracy: 795/1023 (77.7%)

Train epoch 4	 Loss 1.463368
Wall time: 1min 40s

Test set: Average loss: 0.7109, Accuracy: 806/1023 (78.8%)

Train epoch 5	 Loss 1.849077
Wall time: 1min 40s

Test set: Average loss: 0.7227, Accuracy: 803/1023 (78.5%)

Train epoch 6	 Loss 1.442590
Wall time: 1min 40s

Test set: Average loss: 0.7080, Accuracy: 796/1023 (77.8%)

Train epoch 7	 Loss 1.540823
Wall time: 1min 41s

Test set: Average loss: 0.6738, Accuracy: 822/1023 (80.4%)

5 小结

  • 普通任务的过程:准备数据、准备模型、训练、评估或预测;
  • 如何对预训练模型进行微调;
  • 利用Pandas和sklearn工具处理数据;
  • 标注的10个注意事项。

Reference

  • https://github.com/zergtant/pytorch-handbook
  • https://pytorch.org/docs/stable/index.html

推荐阅读
  • 本文介绍了如何使用php限制数据库插入的条数并显示每次插入数据库之间的数据数目,以及避免重复提交的方法。同时还介绍了如何限制某一个数据库用户的并发连接数,以及设置数据库的连接数和连接超时时间的方法。最后提供了一些关于浏览器在线用户数和数据库连接数量比例的参考值。 ... [详细]
  • 图解redis的持久化存储机制RDB和AOF的原理和优缺点
    本文通过图解的方式介绍了redis的持久化存储机制RDB和AOF的原理和优缺点。RDB是将redis内存中的数据保存为快照文件,恢复速度较快但不支持拉链式快照。AOF是将操作日志保存到磁盘,实时存储数据但恢复速度较慢。文章详细分析了两种机制的优缺点,帮助读者更好地理解redis的持久化存储策略。 ... [详细]
  • 基于事件驱动的并发编程及其消息通信机制的同步与异步、阻塞与非阻塞、IO模型的分类
    本文介绍了基于事件驱动的并发编程中的消息通信机制,包括同步和异步的概念及其区别,阻塞和非阻塞的状态,以及IO模型的分类。同步阻塞IO、同步非阻塞IO、异步阻塞IO和异步非阻塞IO等不同的IO模型被详细解释。这些概念和模型对于理解并发编程中的消息通信和IO操作具有重要意义。 ... [详细]
  • 本文讨论了在openwrt-17.01版本中,mt7628设备上初始化启动时eth0的mac地址总是随机生成的问题。每次随机生成的eth0的mac地址都会写到/sys/class/net/eth0/address目录下,而openwrt-17.01原版的SDK会根据随机生成的eth0的mac地址再生成eth0.1、eth0.2等,生成后的mac地址会保存在/etc/config/network下。 ... [详细]
  • MyBatis多表查询与动态SQL使用
    本文介绍了MyBatis多表查询与动态SQL的使用方法,包括一对一查询和一对多查询。同时还介绍了动态SQL的使用,包括if标签、trim标签、where标签、set标签和foreach标签的用法。文章还提供了相关的配置信息和示例代码。 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • MySQL多表数据库操作方法及子查询详解
    本文详细介绍了MySQL数据库的多表操作方法,包括增删改和单表查询,同时还解释了子查询的概念和用法。文章通过示例和步骤说明了如何进行数据的插入、删除和更新操作,以及如何执行单表查询和使用聚合函数进行统计。对于需要对MySQL数据库进行操作的读者来说,本文是一个非常实用的参考资料。 ... [详细]
  • OpenMap教程4 – 图层概述
    本文介绍了OpenMap教程4中关于地图图层的内容,包括将ShapeLayer添加到MapBean中的方法,OpenMap支持的图层类型以及使用BufferedLayer创建图像的MapBean。此外,还介绍了Layer背景标志的作用和OMGraphicHandlerLayer的基础层类。 ... [详细]
  • 通过Anaconda安装tensorflow,并安装运行spyder编译器的完整教程
    本文提供了一个完整的教程,介绍了如何通过Anaconda安装tensorflow,并安装运行spyder编译器。文章详细介绍了安装Anaconda、创建tensorflow环境、安装GPU版本tensorflow、安装和运行Spyder编译器以及安装OpenCV等步骤。该教程适用于Windows 8操作系统,并提供了相关的网址供参考。通过本教程,读者可以轻松地安装和配置tensorflow环境,以及运行spyder编译器进行开发。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 本文讨论了在数据库打开和关闭状态下,重新命名或移动数据文件和日志文件的情况。针对性能和维护原因,需要将数据库文件移动到不同的磁盘上或重新分配到新的磁盘上的情况,以及在操作系统级别移动或重命名数据文件但未在数据库层进行重命名导致报错的情况。通过三个方面进行讨论。 ... [详细]
  • 网址:https:vue.docschina.orgv2guideforms.html表单input绑定基础用法可以通过使用v-model指令,在 ... [详细]
  • Gitlab接入公司内部单点登录的安装和配置教程
    本文介绍了如何将公司内部的Gitlab系统接入单点登录服务,并提供了安装和配置的详细教程。通过使用oauth2协议,将原有的各子系统的独立登录统一迁移至单点登录。文章包括Gitlab的安装环境、版本号、编辑配置文件的步骤,并解决了在迁移过程中可能遇到的问题。 ... [详细]
  • 本文介绍了利用ARMA模型对平稳非白噪声序列进行建模的步骤及代码实现。首先对观察值序列进行样本自相关系数和样本偏自相关系数的计算,然后根据这些系数的性质选择适当的ARMA模型进行拟合,并估计模型中的位置参数。接着进行模型的有效性检验,如果不通过则重新选择模型再拟合,如果通过则进行模型优化。最后利用拟合模型预测序列的未来走势。文章还介绍了绘制时序图、平稳性检验、白噪声检验、确定ARMA阶数和预测未来走势的代码实现。 ... [详细]
  • 我用Tkinter制作了一个图形用户界面,有两个主按钮:“开始”和“停止”。请您就如何使用“停止”按钮终止“开始”按钮为以下代码调用的已运行功能提供建议 ... [详细]
author-avatar
zhengxing
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有