热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

图像分类训练方案优化设计

针对图像分类任务的训练方案进行了优化设计。通过引入PyTorch等深度学习框架,利用其丰富的工具包和模块,如`torch.nn`和`torch.nn.functional`,提升了模型的训练效率和分类准确性。优化方案包括数据预处理、模型架构选择和损失函数的设计等方面,旨在提高图像分类任务的整体性能。

图像分类训练设计图像分类训练设计

在这里插入图片描述

# == 引入工具包 ==
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
# == step 0 参数配置 ==# == step 1 数据处理 ==
norm_mean = [0.33424968,0.33424437, 0.33428448]
norm_std = [0.24796878, 0.24796101, 0.24801227]train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(), # 0-255 归一化到0-1 转Tensortransforms.Normalize(norm_mean, norm_std),
])train_data_path = r"D:\PycharmProjects\AI_Easy_Demo\MyData\split_data\train"
from MyDataset.Cifar10_Dataset import LoadDataset
train_dataset = LoadDataset(data_dir=r"D:\PycharmProjects\AI_Easy_Demo\MyData\split_data\train",transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True) # shuffle训练时打乱样本
# == step 2 模型 ==
from MyNet.ResNet import ResNet34
net = ResNet34(num_classes=10, num_linear=512)# == step 3 损失函数 ==
criterion = nn.NLLLoss()# == step 4 优化器 ==
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略,每过step_size个epoch,做一次更新
# == step 5 评测函数==# == step 6 训练 ==
MAX_EPOCH=200
best = [0] # 存储最优指标,用于Early Stoppingfor i in range(MAX_EPOCH):print("当前轮转次数:",i+1)for idx, data_info in enumerate(train_loader):# print("训练数据索引",idx)inputs, labels = data_infooutputs = net(inputs)outputs = F.log_softmax(outputs, dim=1)optimizer.zero_grad() # 梯度置零,设置在loss之前loss = criterion(outputs,labels)loss.backward() # loss反向传播,梯度的计算print(loss)# update weightsoptimizer.step() # 更新所有的参数 根据误差和梯度进行权重的更新# if best_loss torch.save(net.state_dict(), "best.pth")scheduler.step() # 更新学习率# == step 7 训练可视化 ==# == inference ==


推荐阅读
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文详细解析了Python中的os和sys模块,介绍了它们的功能、常用方法及其在实际编程中的应用。 ... [详细]
  • 利用决策树预测NBA比赛胜负的Python数据挖掘实践
    本文通过使用2013-14赛季NBA赛程与结果数据集以及2013年NBA排名数据,结合《Python数据挖掘入门与实践》一书中的方法,展示如何应用决策树算法进行比赛胜负预测。我们将详细讲解数据预处理、特征工程及模型评估等关键步骤。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 本文介绍如何使用阿里云的fastjson库解析包含时间戳、IP地址和参数等信息的JSON格式文本,并进行数据处理和保存。 ... [详细]
  • 对象自省自省在计算机编程领域里,是指在运行时判断一个对象的类型和能力。dir能够返回一个列表,列举了一个对象所拥有的属性和方法。my_list[ ... [详细]
  • 本教程详细介绍了如何使用 TensorFlow 2.0 构建和训练多层感知机(MLP)网络,涵盖回归和分类任务。通过具体示例和代码实现,帮助初学者快速掌握 TensorFlow 的核心概念和操作。 ... [详细]
  • 本文将探讨2015年RCTF竞赛中的一道PWN题目——shaxian,重点分析其利用Fastbin和堆溢出的技巧。通过详细解析代码流程和漏洞利用过程,帮助读者理解此类题目的破解方法。 ... [详细]
  • 深入理解 H5C3 和 JavaScript 核心问题
    本文详细探讨了 H5C3 和 JavaScript 中的一些核心编程问题,通过实例解析和代码示例,帮助开发者更好地理解和应用这些技术。 ... [详细]
  • 本文介绍了在Windows环境下使用pydoc工具的方法,并详细解释了如何通过命令行和浏览器查看Python内置函数的文档。此外,还提供了关于raw_input和open函数的具体用法和功能说明。 ... [详细]
  • 本文探讨了如何在不重新加载URL的情况下,触发WebView的PictureListener.onNewPicture()方法,以实现页面的重新绘制或渲染。 ... [详细]
  • 本文详细介绍了Python中函数的基本概念,包括函数的定义与调用、文档注释、参数传递(形参与实参)、返回值以及函数嵌套。通过具体示例和解释,帮助读者掌握函数在编程中的应用。 ... [详细]
author-avatar
小石头
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有