热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

利用PyTorch快速实现分类任务

关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P

关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习PyTorch,找了很多自定义数据加载的方法,还是使用torch中封装的库函数好用,而且快捷,会根据路径自动返回对应的标签,下面的代码每一行都给了注释。import torchfrom torchvision import transforms, utilsfrom torchvision import datasetsimport torch.utils.dataimport matplotlib.pyplot as plt# 定义图像预处理transform1 = tranhttps://blog.csdn.net/weixin_55737425/article/details/122958584

这里给出一个模板,适合想要快速实现的朋友们(想要快速做出效果),不需要多少理论知识,只需要将文中的文件地址更改为自己的电脑上的地址即可。(注意图片的保存方式有一定的格式,详细可以查阅ImageFolder函数的用法)

此处每一行的代码都已经标记缘由和作用,如果还有疑惑,欢迎垂询问题!

import random
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from imutils import paths
import torch.nn as nn
from torch import optim
import numpy
from torchvision import transforms, utils
import torch
from torchvision import datasets
import matplotlib.pyplot as pltdef load_data():transform1 = transforms.Compose([ # 这里最好加上一个中括号,否则会被认为是意外实参transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转,概率为0.3transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转,概率为0.3# transforms.CenterCrop((400, 400)),transforms.ToTensor(), # 转换成Tensor类型transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.255)) # 这里是为了和官方文档保持一致])batch_size = 8train_data = datasets.ImageFolder(r"C:\Users\asus\Desktop\cnn_data\cnn_data\data\training_data", transform=transform1)# print(train_data.imgs)# 加载数据train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)return train_data# def im_convert(tensor): # 这里可以不用理睬,我是想要显示原来图片的
# image = tensor
# image = image.numpy().squeeze()
# image = image.transpose(1, 2, 0)
# iamge = image*numpy.array(0.229, 0.224, 0.255) + numpy.array(0.485, 0.456, 0.406)
# image = image.clip(0, 1)
# return imagedef train(train_data):lr = 0.0001EPOCH = 12 # 可以自己调整,多一点会更好,但十分耗时间device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 利用GPU进行训练model = resnet50(pretrained=True).to(device) # 此处使用迁移学习的方法预加载权重,此处会下载一段时间model.train() # 设置运行模式in_channel = model.fc.in_features # 获取全连接层中输入的维数model.fc = nn.Linear(in_channel, 2) # 重新赋值全连接层criterion = nn.CrossEntropyLoss().to(device) # 分类问题使用交叉熵的方法optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) # 也可以使用Adam,效果也好,momentum根据文献资料,0.9为最优选择scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) # 每经历2个epoch就衰减十分之一,也可以自己选择running_loss = 0for epoch in range(0, EPOCH):correct = 0for i, (data, target) in enumerate(train_data, 1):data = torch.autograd.Variable(data).to(device)target = torch.autograd.Variable(target).to(device)optimizer.zero_grad() # 清空上一次的梯度值output = model(data)loss = criterion(output, target)running_loss = loss.item()loss.backward()optimizer.step()prediction = torch.argmax(output, dim=1) # 返回维度为dim上最大值的索引correct += (prediction == target).sum().item() # 当prediction==target时会返回“1”,predicton和target在此处都是tensor类型,所以返回的是“tensor(1)”,之后通过item返回数值if i % 2 == 0:print("第{}个EPOCH,第{}个batch,当前损失为{}".format(epoch+1, i, running_loss))print("本轮训练的准确率为{:}".format(correct/len(train_data)))
if __name__ == '__main__':train_data = load_data()train(train_data)


推荐阅读
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • OBS Studio自动化实践:利用脚本批量生成录制场景
    本文探讨了如何利用OBS Studio进行高效录屏,并通过脚本实现场景的自动生成。适合对自动化办公感兴趣的读者。 ... [详细]
  • CoreData 表关联详解
    在企业中,通常会有多个部门,每个员工隶属于某个部门。这种情况下,员工表和部门表之间就会形成关联关系。本文将详细介绍如何在CoreData中实现表关联,并通过示例代码展示如何添加和查询关联数据。 ... [详细]
  • Spring Boot + RabbitMQ 消息确认机制详解
    本文详细介绍如何在 Spring Boot 项目中使用 RabbitMQ 的消息确认机制,包括消息发送确认和消息接收确认,帮助开发者解决在实际操作中可能遇到的问题。 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • Ihavetwomethodsofgeneratingmdistinctrandomnumbersintherange[0..n-1]我有两种方法在范围[0.n-1]中生 ... [详细]
  • 本文介绍了如何利用 `matplotlib` 库中的 `FuncAnimation` 类将 Python 中的动态图像保存为视频文件。通过详细解释 `FuncAnimation` 类的参数和方法,文章提供了多种实用技巧,帮助用户高效地生成高质量的动态图像视频。此外,还探讨了不同视频编码器的选择及其对输出文件质量的影响,为读者提供了全面的技术指导。 ... [详细]
  • 在Java基础中,私有静态内部类是一种常见的设计模式,主要用于防止外部类的直接调用或实例化。这种内部类仅服务于其所属的外部类,确保了代码的封装性和安全性。通过分析JDK源码,我们可以发现许多常用类中都包含了私有静态内部类,这些内部类虽然功能强大,但其复杂性往往让人感到困惑。本文将深入探讨私有静态内部类的作用、实现方式及其在实际开发中的应用,帮助读者更好地理解和使用这一重要的编程技巧。 ... [详细]
  • 对于初学者而言,搭建一个高效稳定的 Python 开发环境是入门的关键一步。本文将详细介绍如何利用 Anaconda 和 Jupyter Notebook 来构建一个既易于管理又功能强大的开发环境。 ... [详细]
  • 二维码的实现与应用
    本文介绍了二维码的基本概念、分类及其优缺点,并详细描述了如何使用Java编程语言结合第三方库(如ZXing和qrcode.jar)来实现二维码的生成与解析。 ... [详细]
  • 图像处理学习笔记:噪声分析与去除策略
    本文详细探讨了不同类型的图像噪声及其对应的降噪技术,旨在帮助读者理解各种噪声的本质,并掌握有效的降噪方法。文章不仅介绍了高斯噪声、瑞利噪声、伽马噪声、指数噪声、均匀噪声和椒盐噪声等常见噪声类型,还特别讨论了周期噪声的特性及处理技巧。 ... [详细]
  • 小编给大家分享一下Vue3中如何提高开发效率,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获, ... [详细]
  • HTML:  将文件拖拽到此区域 ... [详细]
  • 本文探讨了如何通过优化SOAP服务调用和多线程处理来减少生成的事件数量,并提高加载大量实体的效率。 ... [详细]
  • 在机器学习中,我们经常需要对训练数据进行随机打乱以提高模型的泛化能力。本文介绍如何使用 numpy.random.permutation 函数在打乱数据的同时保持 x 和 y 的原始映射关系。 ... [详细]
author-avatar
漫湾镇团委
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有