热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch(网络模型训练)

上一篇目录标题网络模型训练小插曲训练模型数据训练GPU训练第一种方式方式二:查看GPU信息完整模型验证网络模型训练小插曲区别importtorchatorch

上一篇

目录标题

  • 网络模型训练
    • 小插曲
    • 训练模型
    • 数据训练
    • GPU 训练
      • 第一种方式
      • 方式二:
      • 查看GPU信息
    • 完整模型验证


网络模型训练

小插曲

区别

import torch
a=torch.tensor(5)
print(a)
print(a.item())

在这里插入图片描述

import torchoutput=torch.tensor([[0.1,0.2],[0.05,0.4]])
print(output.argmax(1))# 为1选取每一行最大值的索引,为0选取每一列最大值的索引preds=output.argmax(1)
target=torch.tensor([0,1])
print(preds==target)
print((preds==target).sum())

在这里插入图片描述

训练模型

import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
# 搭建神经网络
class Dun(nn.Module):def __init__(self):super().__init__()# 2.self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x=self.model1(x)return x
if __name__=='__main__':dun=Dun()input=torch.ones((64,3,32,32))print(dun(input).shape)

数据训练


import torchvision
# 准备数据集
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoadertrain_data=torchvision.datasets.CIFAR10(root="./data_set_train",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10(root="./data_set_test",train=False,transform=torchvision.transforms.ToTensor(),download=True)#长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print("train_data_size:{}",format(train_data_size))
print("test_data_size:{}",format(test_data_size))# 加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloder=DataLoader(test_data,batch_size=64)#创建网络模型
dun=Dun()
#损失函数
loss_fn=nn.CrossEntropyLoss()
# 优化器
learning_rate=1e-2
optimizerr=torch.optim.SGD(dun.parameters(),lr=learning_rate)
#设置训练网络参数
# 记录训练次数
total_train_step=0
# 记录测试次数
total_test_step=0
#训练次数
epoch=10# 追加tensorboard
writer=SummaryWriter("./logs")for i in range(epoch):print("----------第{}轮训练------".format(i+1))# 训练开始dun.train()# 网络模型中,对dropout、BatchNorm层等起作用,进入训练状态for data in train_dataloader:img,target=dataoutput=dun(img)loss=loss_fn(output,target)#优化器优化optimizerr.zero_grad()loss.backward()optimizerr.step()total_train_step+=1print("训练次数:{},loss:{}".format(total_train_step,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤total_test_loss=0# 使用正确率判断模型的好坏total_accuracy=0dun.eval()# 网络模型中,对dropout、BatchNorm层等起作用,进入验证状态with torch.no_grad():for data in test_dataloder:img,target=dataoutput=dun(img)total_test_loss+=loss_fn(output,target).item()accuracy=(output.argmax(1)==target).sum()total_accuracy+=accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))writer.add_scalar("test_loss",total_test_loss,total_test_step)print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step+=1#保存模型torch.save(dun,"dun{}.pth".format(i))print("保存模型")writer.close()

GPU 训练


第一种方式

在这里插入图片描述
将以上的三部分调用cuda方法,以上面训练数据的代码为例

# 模型
if torch.cuda.is_available():# 判断是否可以使用gpudun=dun.cuda()
#损失函数
if torch.cuda.is_available():loss_fn=loss_fn.cuda()
# 数据(包含训练和测试的)if torch.cuda.is_available():img = img.cuda()target = target.cuda()

方式二:

# 定义训练的设备
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")# 参数分为cpu和cuda,当显卡多个时cuda:0

将方式一的代码换成

dun=dun.to(device)
# 其他数据、loss类似

查看GPU信息

在这里插入图片描述

完整模型验证

查看数据集CIFAR10的类别
在这里插入图片描述

import torchvision
from PIL import Image
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
# 搭建神经网络
class Dun(nn.Module):def __init__(self):super().__init__()# 2.self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x=self.model1(x)return ximage_path="./img/1.png"
image=Image.open(image_path)
print(image)
# 类型转换
transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
image=transform(image)
print(image.shape)# 加载网络模型注意加载的模型和现在验证的要么使用cpu要么gpu一致,否则需要map——location映射本地的cpu
model=torch.load("dun0.pth",map_location=torch.device("cpu"))
print(model)
# 类型转换
image=torch.reshape(image,(1,3,32,32))model.eval()# 模型转换测试类型
# 执行模型
with torch.no_grad():output=model(image)
print(output)print(output.argmax(1))


推荐阅读
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 本文介绍如何使用 Java 编程语言来判断一个给定的年份是否为闰年,并提供两种不同的实现方法。 ... [详细]
  • 图神经网络模型综述
    本文综述了图神经网络(Graph Neural Networks, GNN)的发展,从传统的数据存储模型转向图和动态模型,探讨了模型中的显性和隐性结构,并详细介绍了GNN的关键组件及其应用。 ... [详细]
  • 尤洋:夸父AI系统——大规模并行训练的深度学习解决方案
    自从AlexNet等模型在计算机视觉领域取得突破以来,深度学习技术迅速发展。近年来,随着BERT等大型模型的广泛应用,AI模型的规模持续扩大,对硬件提出了更高的要求。本文介绍了新加坡国立大学尤洋教授团队开发的夸父AI系统,旨在解决大规模模型训练中的并行计算挑战。 ... [详细]
  • 本文介绍了如何使用 Google Colab 的免费 GPU 资源进行深度学习应用开发。Google Colab 是一个无需配置即可使用的云端 Jupyter 笔记本环境,支持多种深度学习框架,并且提供免费的 GPU 计算资源。 ... [详细]
  • 本教程详细介绍了如何使用 TensorFlow 2.0 构建和训练多层感知机(MLP)网络,涵盖回归和分类任务。通过具体示例和代码实现,帮助初学者快速掌握 TensorFlow 的核心概念和操作。 ... [详细]
  • 二维几何变换矩阵解析
    本文详细介绍了二维平面上的三种常见几何变换:平移、缩放和旋转。通过引入齐次坐标系,使得这些变换可以通过统一的矩阵乘法实现,从而简化了计算过程。文中不仅提供了理论推导,还附有Python代码示例,帮助读者更好地理解这些概念。 ... [详细]
  • 本文将探讨2015年RCTF竞赛中的一道PWN题目——shaxian,重点分析其利用Fastbin和堆溢出的技巧。通过详细解析代码流程和漏洞利用过程,帮助读者理解此类题目的破解方法。 ... [详细]
  • yikesnews第11期:微软Office两个0day和一个提权0day
    点击阅读原文可点击链接根据法国大选被黑客干扰,发送了带漏洞的文档Trumps_Attack_on_Syria_English.docx而此漏洞与ESET&FireEy ... [详细]
  • 本文介绍了 Python 的 Pmagick 库中用于图像处理的木炭滤镜方法,探讨其功能和用法,并通过实例演示如何应用该方法。 ... [详细]
  • 本文详细介绍了 Python 中的条件语句和循环结构。主要内容包括:1. 分支语句(if...elif...else);2. 循环语句(for, while 及嵌套循环);3. 控制循环的语句(break, continue, else)。通过具体示例,帮助读者更好地理解和应用这些语句。 ... [详细]
  • vivo Y5s配备了联发科Helio P65八核处理器,这款处理器采用12纳米工艺制造,具备两颗高性能Cortex-A75核心和六颗高效能Cortex-A55核心。此外,它还集成了先进的图像处理单元和语音唤醒功能,为用户提供卓越的性能体验。 ... [详细]
  • 新手指南:在Windows 10上搭建深度学习与PyTorch开发环境
    本文详细记录了一名新手在Windows 10操作系统上搭建深度学习环境的过程,包括安装必要的软件和配置环境变量等步骤,旨在帮助同样初入该领域的读者避免常见的错误。 ... [详细]
  • 在Win10上利用VS2015构建Caffe2环境
    本文详细介绍如何在Windows 10操作系统上通过Visual Studio 2015编译Caffe2深度学习框架的过程。包括必要的软件安装、环境配置以及常见问题的解决方法。 ... [详细]
  • 浪潮AI服务器NF5488A5在MLPerf基准测试中刷新多项纪录
    近日,国际权威AI基准测试平台MLPerf发布了最新的推理测试结果,浪潮AI服务器NF5488A5在此次测试中创造了18项性能纪录,显著提升了数据中心AI推理性能。 ... [详细]
author-avatar
Stupid锋_891
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有