热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch常见预训练模型的下载链接及使用指南

本文提供了PyTorch框架中常用的预训练模型的下载链接及详细使用指南,涵盖ResNet、Inception、DenseNet、AlexNet、VGGNet等六大分类模型。每种模型的预训练参数均经过精心调优,适用于多种计算机视觉任务。文章不仅介绍了模型的下载方式,还详细说明了如何在实际项目中高效地加载和使用这些模型,为开发者提供全面的技术支持。

pytorch框架:常用模型的预训练参数

六大分类模型下载方式和使用方法:
Resnet
inception
Densenet
Alexnet
vggnet

Resnet:
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
inception:
model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
Densenet:
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
Alexnet:
model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
vggnet:
model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

学习内容:测试实现预训练模型的使用,并牢记该方式-拿为己用

关键步骤讲述:



  1. 默认已经安装好环境和pytorch框架,以及torchvision等需要的库。



  2. import torchvision.models as models 所有成熟网络模型几乎都在里面



  3. # 初始化模型 model = models.resnet18()此处应用ResNet18来分类。



  4. 修改尾巴,毕竟你的输出不一定和原版(1000)一模一样。
    # 修改网络结构,将fc层1000个输出改为9个输出。
    # 获取最后一层的输入特征层信息。 fc_input_feature = model.fc.in_features
    # 取代原来输出层为新的nn。 model.fc = nn.Linear(fc_input_feature, 9)到这里,网络就构建好了。



  5. 下载预训练参数,为己所用。# load除最后一层的预训练权重 pretrained_weight = torch.hub.load_state_dict_from_url( url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)到这里,下载的是原版的1000分类的参数,我们需要删除不需要的尾巴,并训练自己的尾巴。del pretrained_weight['fc.weight']
    del pretrained_weight['fc.bias']因为分类就是用的线性函数,包括权重w和偏移b,只需删除尾巴。



  6. 最后,将剩下的模型参数load到我们的模型上即可。model.load_state_dict(pretrained_weight, strict=False)模型准备完毕,剩下的操作和所有训练方法一样。参见详细训练代码。



import os
import torch
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
import torchvision.models as models
import time
# use res18
# from resnet.resnetmini import ClassificModel as Model
from datasets.read_data_sleep import PlayPhoneData
def train(data_path=r"E:\Datasets\sleep_traindata"):
# 设置超参数
batch_size = 1 # 每次训练的数据量
LR = 0.01 # 学习率
STEP_SIZE = 5 # 控制学习率变化
MAX_EPOCH = 20 # 总的训练次数
num_print = 100 # 每n个batch打印一次
playPhoneData = PlayPhoneData(data_path)
# 利用dataloader加载数据集
train_loader = torch.utils.data.DataLoader(playPhoneData, batch_size=batch_size, shuffle=True, drop_last=True)
# 生成驱动器
use_gpu = torch.cuda.is_available()
if use_gpu:
print('congratulation! You can use gpu to support acceleration')
else:
print('oppps, please use a small batch size')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 初始化模型
model = models.resnet18()
# 修改网络结构,将fc层1000个输出改为9个输出
fc_input_feature = model.fc.in_features
model.fc = nn.Linear(fc_input_feature, 9)
# load除最后一层的预训练权重
pretrained_weight = torch.hub.load_state_dict_from_url(
url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)
del pretrained_weight['fc.weight']
del pretrained_weight['fc.bias']
model.load_state_dict(pretrained_weight, strict=False)
model.to(device)
# net = Model(8).to(device) # class_num=8分八类:睡岗(趴着睡,躺着睡,仰着睡,低头睡),玩手机(俯视玩手机,平视玩手机,侧视玩手机),其他=[0,1,2,3,4,5,6,7]
# net = Model(9).to(device) # class_num=9分九类:睡岗(趴着睡,躺着睡,低头睡),站立,半蹲,坐(背坐,正坐,侧坐),其他=[0,1,2,3,4,5,6,7,8]
# 损失函数
get_loss = nn.CrossEntropyLoss() #交叉熵损失函数
# SGD优化器 第一个参数是输入需要优化的参数,第二个是学习率,第三个是动量,大致就是借助上一次导数结果,加快收敛速度。
'''
这一行代码里面实际上包含了多种优化:
一个是动量优化,增加了一个关于上一次迭代得到的系数的偏置,借助上一次的指导,减小梯度震荡,加快收敛速度
一个是权重衰减,通过对权重增加一个(正则项),该正则项会使得迭代公式中的权重按照比例缩减,这么做的原因是,过拟合的表现一般为参数浮动大,使用小参数可以防止过拟合
'''
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=0.001)
# optimizer = optim.Adam(net.parameters(), lr=learn_rate)
# 动态调整学习率 StepLR 是等间隔调整学习率,每step_size 令lr=lr*gamma
# 学习率衰减,随着训练的加深,目前的权重也越来越接近最优权重,原本的学习率会使得,loss上下震荡,逐步减小学习率能加快收敛速度。
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=0.5, last_epoch=-1)
# Step:设置学习率下降策略
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
loss_list = []
start = time.time()
for epoch in range(MAX_EPOCH):
running_loss = 0.0
# enumerate()是python自带的函数,用于迭代字典。参数1,是需要迭代的对象,第二参数是迭代的起始位置
for i, (inputs, labels) in enumerate(train_loader, 0):
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs) # 前向传播求出预测的值
optimizer.zero_grad() # 将梯度初始化为0
loss = get_loss(outputs, labels.long())
loss.backward() # 反向传播求梯度
optimizer.step() # 更新所有参数
running_loss += loss.item() # loss是张量,访问值时需要使用item()
loss_list.append(loss.item())
if i % num_print == num_print - 1: # 每num_print打印平均loss
print('[%d epoch, %d] loss: %.6f' % (epoch + 1, i + 1, running_loss / num_print))
running_loss = 0.0
lr = optimizer.param_groups[0]['lr'] # 查看目前的学习率
print('learn_rate : %.5f' % lr)
scheduler.step() # 根据迭代epoch更新学习率
end = time.time()
print('time:{}'.format(end - start))
torch.save(model, f'E:/model/playphone+sleepthepose/model_resnetmini_睡岗9分类{end}.pth')
if __name__ == "__main__":
train()

训练情况:

......
[3 epoch, 500] loss: 2.186424
[3 epoch, 600] loss: 2.192622
[3 epoch, 700] loss: 2.165229
[3 epoch, 800] loss: 2.125184
[3 epoch, 900] loss: 2.185377
learn_rate : 0.01000
[4 epoch, 100] loss: 2.138786
[4 epoch, 200] loss: 2.177925
[4 epoch, 300] loss: 2.103718
......

备注:代码只是讲解工具,并非可以运行的实例,因为里面的数据集需要有并自己写数据集的代码。




学习内容:进阶应用方法

直接拿来用固然不错,但自己分装一遍再用,显得更加标准,有水平。
比如封装如下:


class ResNet18forClassify(nn.Module):
def __init__(self, phase="train"):
super(ResNet18forClassify, self).__init__()
self.phase = phase
self.net = models.resnet18()
fc_input_feature = self.net.fc.in_features
self.net.fc = nn.Linear(fc_input_feature, 9)
pretrained_weight = torch.hub.load_state_dict_from_url(
url='https://download.pytorch.org/models/resnet18-5c106cde.pth', progress=True)
del pretrained_weight['fc.weight']
del pretrained_weight['fc.bias']
self.net.load_state_dict(pretrained_weight, strict=False)
self.softmax = nn.Softmax(dim=1)
def forward(self, input_img):
out = self.net(input_img)
if self.phase == "test":
return self.softmax(out)
return out

备注:封装成自己的网络模型,更加方便。
其中,if self.phase == "test": return self.softmax(out),分类时训练输出的是类别标签与实际标签做损失计算;测试时,预测结果由激活函数转换为–类型和该类型可能性概率。输出可能是该类别的概率值。


参考文献:

1.https://github.com/pytorch/vision/tree/master/torchvision/models
2.环境搭建:NVIDIA+CUDA+cudaNN的配置与Anaconda虚拟环境的搭建–深度学习第一步
3.Parallax:常用预训练模型下载地址



来源:柏常青



推荐阅读
  • 对象自省自省在计算机编程领域里,是指在运行时判断一个对象的类型和能力。dir能够返回一个列表,列举了一个对象所拥有的属性和方法。my_list[ ... [详细]
  • 在项目部署后,Node.js 进程可能会遇到不可预见的错误并崩溃。为了及时通知开发人员进行问题排查,我们可以利用 nodemailer 插件来发送邮件提醒。本文将详细介绍如何配置和使用 nodemailer 实现这一功能。 ... [详细]
  • Coursera ML 机器学习
    2019独角兽企业重金招聘Python工程师标准线性回归算法计算过程CostFunction梯度下降算法多变量回归![选择特征](https:static.oschina.n ... [详细]
  • 历经三十年的开发,Mathematica 已成为技术计算领域的标杆,为全球的技术创新者、教育工作者、学生及其他用户提供了一个领先的计算平台。最新版本 Mathematica 12.3.1 增加了多项核心语言、数学计算、可视化和图形处理的新功能。 ... [详细]
  • CSS高级技巧:动态高亮当前页面导航
    本文介绍了如何使用CSS实现网站导航栏中当前页面的高亮显示,提升用户体验。通过为每个页面的body元素添加特定ID,并结合导航项的类名,可以轻松实现这一功能。 ... [详细]
  • 深入解析Spring启动过程
    本文详细介绍了Spring框架的启动流程,帮助开发者理解其内部机制。通过具体示例和代码片段,解释了Bean定义、工厂类、读取器以及条件评估等关键概念,使读者能够更全面地掌握Spring的初始化过程。 ... [详细]
  • 探讨ChatGPT在法律和版权方面的潜在风险及影响,分析其作为内容创造工具的合法性和合规性。 ... [详细]
  • 本文将详细介绍如何在没有显示器的情况下,使用Raspberry Pi Imager为树莓派4B安装操作系统,并进行基本配置,包括设置SSH、WiFi连接以及更新软件源。 ... [详细]
  • 本文详细介绍了如何通过RPM包在Linux系统(如CentOS)上安装MySQL 5.6。涵盖了检查现有安装、下载和安装RPM包、配置MySQL以及设置远程访问和开机自启动等步骤。 ... [详细]
  • 反向投影技术主要用于在大型输入图像中定位特定的小型模板图像。通过直方图对比,它能够识别出最匹配的区域或点,从而确定模板图像在输入图像中的位置。 ... [详细]
  • Java 实现二维极点算法
    本文介绍了一种使用 Java 编程语言实现的二维极点算法。该算法用于从一组二维坐标中筛选出极点,适用于需要处理几何图形和空间数据的应用场景。文章不仅详细解释了算法的工作原理,还提供了完整的代码示例。 ... [详细]
  • 本文介绍了SVD(奇异值分解)和QR分解的基本原理及其在Python中的实现方法。通过具体代码示例,展示了如何使用这两种矩阵分解技术处理图像数据和计算特征值。 ... [详细]
  • 本文介绍如何从字符串中移除大写、小写、特殊、数字和非数字字符,并提供了多种编程语言的实现示例。 ... [详细]
  • Linux环境下C语言实现定时向文件写入当前时间
    本文介绍如何在Linux系统中使用C语言编程,实现在每秒钟向指定文件中写入当前时间戳。通过此示例,读者可以了解基本的文件操作、时间处理以及循环控制。 ... [详细]
  • 离线安装Grafana Cloudera Manager插件并监控CDH集群
    本文详细介绍如何离线安装Cloudera Manager (CM) 插件,并通过Grafana监控CDH集群的健康状况和资源使用情况。该插件利用CM提供的API接口进行数据获取和展示。 ... [详细]
author-avatar
UUUUUUUUUU8
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有