热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

开发笔记:AI基础之全连接

如上篇文章所讲,将我们需用的环境搭建完成以后,我们就可以开始AI之路了,下面就让我们来看看第一个网络框架结构——全连接吧。impo

如上篇文章所讲,将我们需用的环境搭建完成以后,我们就可以开始AI之路了,下面就让我们来看看第一个网络框架结构——全连接吧。

import torch.nn as nn
#导入所需库

class Net(nn.Module):
#初始化网络结构(设计神经网络)
def __init__(self):
super().__init__()
#设计一个多层结构的神经网络
self.layers = nn.Sequential(
nn.Linear(28*28,512),
#设计一层神经网络,有512个神经元,接受748个
nn.ReLU(),
nn.Linear(512,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,10),
nn.Softmax(dim=1)
)
# 前向计算(使用神经网络),将数据x输入到网络中,返回结果
def forward(self, x):
return self.layers(x)

***************************************************************************************************************************************

import torch
import torchvision
import torch.nn as nn
from PIL import Image
import torch.utils.data as data
from my_net import Net
import numpy as np
import os
save_path = "module/net_ps.pth"

train_data = torchvision.datasets.MNIST(
root="MNIST_data",#单通道28*28黑白图片(0-9数字)
train=True,
transform=torchvision.transforms.ToTensor(),
download=True
)
test_data = torchvision.datasets.MNIST(
root="MNIST_data",
train=False,
transform=torchvision.transforms.ToTensor(),
download=False
)

if __name__ == ‘__main__‘:
#创建数据加载器,每次从train_data里面取100张数据,打乱
train = data.DataLoader(dataset=train_data,batch_size=100,shuffle=True)#用数据加载器从train中每次加载100张图片并打乱
#实例化网络对象
net = Net()
#判断本地是否已经有网络的参数,如果有,那就加载之前的参数
if os.path.isfile(save_path):
net = torch.load(save_path)
#定义损失函数
loss_fun = nn.MSELoss()#对(h-y)^2求平均
#定义优化器,用这个优化器来优化网络内部的参数
optimizer = torch.optim.Adam(net.parameters())
#取数据,训练网络
for epoch in range(1000000):
for i,(x,y) in enumerate(train):#N C H W形状
#将图片变为100,784
x = x.reshape(-1,28*28)
#将图片输入到网络,得到结果
out = net(x)
#将标签y进行one-hot编码
target = torch.zeros(y.size()[0],10).scatter_(1,y.view(-1,1),1)
#将网络的结果和标签拿来做损失
loss = loss_fun(target,out)
#优化损失
optimizer.zero_grad()#清空梯度
loss.backward()#根据损失进行反向求导
optimizer.step()#更新梯度
#每训练10次,进行一次测试
if i%10 == 0:
out_put = torch.argmax(out,dim=1)
# print("target:",y)
# print("out:",out_put)
print("loss:",loss.item())

#计算准确度
acc = np.mean(np.array(out_put==y,dtype=np.float32))
print("精度:",acc)
#保存网络参数
torch.save(net,save_path)

推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文讨论了如何使用IF函数从基于有限输入列表的有限输出列表中获取输出,并提出了是否有更快/更有效的执行代码的方法。作者希望了解是否有办法缩短代码,并从自我开发的角度来看是否有更好的方法。提供的代码可以按原样工作,但作者想知道是否有更好的方法来执行这样的任务。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • 本文介绍了使用readlink命令获取文件的完整路径的简单方法,并提供了一个示例命令来打印文件的完整路径。共有28种解决方案可供选择。 ... [详细]
  • 读手语图像识别论文笔记2
    文章目录一、前言二、笔记1.名词解释2.流程分析上一篇快速门:读手语图像识别论文笔记1(手语识别背景和方法)一、前言一句:“做完了&#x ... [详细]
  • unigine中存在如下两个函数。mat4lookAt(vec3position,vec3target,vec3up)dmat4lookAt(dvec3position,dvec ... [详细]
  • 逻辑回归_训练二元分类器#训练一个二元分类器fromsklearn.linear_modelimportLogisticRegressionfromsklearnimport ... [详细]
author-avatar
手机用户2702936867
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有