热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

利用PyTorch快速实现分类任务

关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P

关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习PyTorch,找了很多自定义数据加载的方法,还是使用torch中封装的库函数好用,而且快捷,会根据路径自动返回对应的标签,下面的代码每一行都给了注释。import torchfrom torchvision import transforms, utilsfrom torchvision import datasetsimport torch.utils.dataimport matplotlib.pyplot as plt# 定义图像预处理transform1 = tranhttps://blog.csdn.net/weixin_55737425/article/details/122958584

这里给出一个模板,适合想要快速实现的朋友们(想要快速做出效果),不需要多少理论知识,只需要将文中的文件地址更改为自己的电脑上的地址即可。(注意图片的保存方式有一定的格式,详细可以查阅ImageFolder函数的用法)

此处每一行的代码都已经标记缘由和作用,如果还有疑惑,欢迎垂询问题!

import random
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from imutils import paths
import torch.nn as nn
from torch import optim
import numpy
from torchvision import transforms, utils
import torch
from torchvision import datasets
import matplotlib.pyplot as pltdef load_data():transform1 = transforms.Compose([ # 这里最好加上一个中括号,否则会被认为是意外实参transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转,概率为0.3transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转,概率为0.3# transforms.CenterCrop((400, 400)),transforms.ToTensor(), # 转换成Tensor类型transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.255)) # 这里是为了和官方文档保持一致])batch_size = 8train_data = datasets.ImageFolder(r"C:\Users\asus\Desktop\cnn_data\cnn_data\data\training_data", transform=transform1)# print(train_data.imgs)# 加载数据train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)return train_data# def im_convert(tensor): # 这里可以不用理睬,我是想要显示原来图片的
# image = tensor
# image = image.numpy().squeeze()
# image = image.transpose(1, 2, 0)
# iamge = image*numpy.array(0.229, 0.224, 0.255) + numpy.array(0.485, 0.456, 0.406)
# image = image.clip(0, 1)
# return imagedef train(train_data):lr = 0.0001EPOCH = 12 # 可以自己调整,多一点会更好,但十分耗时间device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 利用GPU进行训练model = resnet50(pretrained=True).to(device) # 此处使用迁移学习的方法预加载权重,此处会下载一段时间model.train() # 设置运行模式in_channel = model.fc.in_features # 获取全连接层中输入的维数model.fc = nn.Linear(in_channel, 2) # 重新赋值全连接层criterion = nn.CrossEntropyLoss().to(device) # 分类问题使用交叉熵的方法optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) # 也可以使用Adam,效果也好,momentum根据文献资料,0.9为最优选择scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) # 每经历2个epoch就衰减十分之一,也可以自己选择running_loss = 0for epoch in range(0, EPOCH):correct = 0for i, (data, target) in enumerate(train_data, 1):data = torch.autograd.Variable(data).to(device)target = torch.autograd.Variable(target).to(device)optimizer.zero_grad() # 清空上一次的梯度值output = model(data)loss = criterion(output, target)running_loss = loss.item()loss.backward()optimizer.step()prediction = torch.argmax(output, dim=1) # 返回维度为dim上最大值的索引correct += (prediction == target).sum().item() # 当prediction==target时会返回“1”,predicton和target在此处都是tensor类型,所以返回的是“tensor(1)”,之后通过item返回数值if i % 2 == 0:print("第{}个EPOCH,第{}个batch,当前损失为{}".format(epoch+1, i, running_loss))print("本轮训练的准确率为{:}".format(correct/len(train_data)))
if __name__ == '__main__':train_data = load_data()train(train_data)


推荐阅读
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 在稀疏直接法视觉里程计中,通过优化特征点并采用基于光度误差最小化的灰度图像线性插值技术,提高了定位精度。该方法通过对空间点的非齐次和齐次表示进行处理,利用RGB-D传感器获取的3D坐标信息,在两帧图像之间实现精确匹配,有效减少了光度误差,提升了系统的鲁棒性和稳定性。 ... [详细]
  • Java中不同类型的常量池(字符串常量池、Class常量池和运行时常量池)的对比与关联分析
    在研究Java虚拟机的过程中,笔者发现存在多种类型的常量池,包括字符串常量池、Class常量池和运行时常量池。通过查阅CSDN、博客园等相关资料,对这些常量池的特性、用途及其相互关系进行了详细探讨。本文将深入分析这三种常量池的差异与联系,帮助读者更好地理解Java虚拟机的内部机制。 ... [详细]
  • Android中将独立SO库封装进JAR包并实现SO库的加载与调用
    在Android开发中,将独立的SO库封装进JAR包并实现其加载与调用是一个常见的需求。本文详细介绍了如何将SO库嵌入到JAR包中,并确保在外部应用调用该JAR包时能够正确加载和使用这些SO库。通过这种方式,开发者可以更方便地管理和分发包含原生代码的库文件,提高开发效率和代码复用性。文章还探讨了常见的问题及其解决方案,帮助开发者避免在实际应用中遇到的坑。 ... [详细]
  • 本文探讨了 Kafka 集群的高效部署与优化策略。首先介绍了 Kafka 的下载与安装步骤,包括从官方网站获取最新版本的压缩包并进行解压。随后详细讨论了集群配置的最佳实践,涵盖节点选择、网络优化和性能调优等方面,旨在提升系统的稳定性和处理能力。此外,还提供了常见的故障排查方法和监控方案,帮助运维人员更好地管理和维护 Kafka 集群。 ... [详细]
  • Python 实战:异步爬虫(协程技术)与分布式爬虫(多进程应用)深入解析
    本文将深入探讨 Python 异步爬虫和分布式爬虫的技术细节,重点介绍协程技术和多进程应用在爬虫开发中的实际应用。通过对比多进程和协程的工作原理,帮助读者理解两者在性能和资源利用上的差异,从而在实际项目中做出更合适的选择。文章还将结合具体案例,展示如何高效地实现异步和分布式爬虫,以提升数据抓取的效率和稳定性。 ... [详细]
  • 在过去,我曾使用过自建MySQL服务器中的MyISAM和InnoDB存储引擎(也曾尝试过Memory引擎)。今年初,我开始转向阿里云的关系型数据库服务,并深入研究了其高效的压缩存储引擎TokuDB。TokuDB在数据压缩和处理大规模数据集方面表现出色,显著提升了存储效率和查询性能。通过实际应用,我发现TokuDB不仅能够有效减少存储成本,还能显著提高数据处理速度,特别适用于高并发和大数据量的场景。 ... [详细]
  • 在Windows环境下离线安装PyTorch GPU版时,首先需确认系统配置,例如本文作者使用的是Win8、CUDA 8.0和Python 3.6.5。用户应根据自身Python和CUDA版本,在PyTorch官网查找并下载相应的.whl文件。此外,建议检查系统环境变量设置,确保CUDA路径正确配置,以避免安装过程中可能出现的兼容性问题。 ... [详细]
  • 表面缺陷检测数据集综述及GitHub开源项目推荐
    本文综述了表面缺陷检测领域的数据集,并推荐了多个GitHub上的开源项目。通过对现有文献和数据集的系统整理,为研究人员提供了全面的资源参考,有助于推动该领域的发展和技术进步。 ... [详细]
  • 在Java Web服务开发中,Apache CXF 和 Axis2 是两个广泛使用的框架。CXF 由于其与 Spring 框架的无缝集成能力,以及更简便的部署方式,成为了许多开发者的首选。本文将详细介绍如何使用 CXF 框架进行 Web 服务的开发,包括环境搭建、服务发布和客户端调用等关键步骤,为开发者提供一个全面的实践指南。 ... [详细]
  • NOIP2000的单词接龙问题与常见的成语接龙游戏有异曲同工之妙。题目要求在给定的一组单词中,从指定的起始字母开始,构建最长的“单词链”。每个单词在链中最多可出现两次。本文将详细解析该题目的解法,并分享学习过程中的心得体会。 ... [详细]
  • 本文介绍了如何利用ObjectMapper实现JSON与JavaBean之间的高效转换。ObjectMapper是Jackson库的核心组件,能够便捷地将Java对象序列化为JSON格式,并支持从JSON、XML以及文件等多种数据源反序列化为Java对象。此外,还探讨了在实际应用中如何优化转换性能,以提升系统整体效率。 ... [详细]
  • 开发日志:201521044091 《Java编程基础》第11周学习心得与总结
    开发日志:201521044091 《Java编程基础》第11周学习心得与总结 ... [详细]
  • NVIDIA最新推出的Ampere架构标志着显卡技术的一次重大突破,不仅在性能上实现了显著提升,还在能效比方面进行了深度优化。该架构融合了创新设计与技术改进,为用户带来更加流畅的图形处理体验,同时降低了功耗,提升了计算效率。 ... [详细]
author-avatar
漫湾镇团委
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有