热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【2】pytorch回归

【1】本例程通过对一个二次函数进行拟合,包含数据集的生成,神经网络的搭建,神经网络的训练。【2】代码#-*-coding:utf-8-

【1】本例程通过对一个二次函数进行拟合,包含数据集的生成,神经网络的搭建,神经网络的训练。

【2】代码

# -*- coding: utf-8 -*-##-------------------------------------------------------------------------------
# Name: huiguitest
# Description:
# Author: Administrator
# Date: 2020/11/28
#-------------------------------------------------------------------------------import torch
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F # 激励函数都在这#linspace函数的作用是,返回一个一维的tensor(张量),这个张量包含了从start到end,分成steps个线段得到的向量。
#torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度
#(1,3)变为3行
#torch.unsqueeze()这个函数主要是对数据维度进行扩充
#3行变为(1,3)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
#plt.show()#建立神经网络
#创建一个类
class Net(torch.nn.Module): # 继承 torch 的 Moduledef __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__() # 继承 __init__ 功能# 定义每层用什么样的形式self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层线性输出self.predict = torch.nn.Linear(n_hidden, n_output) # 输出层线性输出def forward(self, x): # 这同时也是 Module 中的 forward 功能# 正向传播输入值, 神经网络分析出输出值x = F.relu(self.hidden(x)) # 激励函数(隐藏层的线性值)x = self.predict(x) # 输出值return xnet = Net(n_feature=1, n_hidden=10, n_output=1)print(net) # net 的结构
"""
Net ((hidden): Linear (1 -> 10)(predict): Linear (10 -> 1)
)
"""#训练网络
# optimizer 是训练的工具
#定义损失函数和梯度下降方法
optimizer = torch.optim.SGD(net.parameters(), lr=0.2) # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss() # 预测值和真实值的误差计算公式 (均方差)for t in range(300):prediction = net(x) # 喂给 net 训练数据 x, 输出预测值loss = loss_func(prediction, y) # 计算两者的误差optimizer.zero_grad() # 清空上一步的残余更新参数值loss.backward() # 误差反向传播, 计算参数更新值optimizer.step() # 将参数更新值施加到 net 的 parameters 上if t % 5 == 0:# plot and show learning processplt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fOntdict={'size': 20, 'color': 'red'})plt.pause(0.1)plt.ion() # 画图
plt.show()
plt.pause(10000)

【3】效果图


推荐阅读
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • SpringMVC接收请求参数的方式总结
    本文总结了在SpringMVC开发中处理控制器参数的各种方式,包括处理使用@RequestParam注解的参数、MultipartFile类型参数和Simple类型参数的RequestParamMethodArgumentResolver,处理@RequestBody注解的参数的RequestResponseBodyMethodProcessor,以及PathVariableMapMethodArgumentResol等子类。 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • 本文讨论了在Windows 8上安装gvim中插件时出现的错误加载问题。作者将EasyMotion插件放在了正确的位置,但加载时却出现了错误。作者提供了下载链接和之前放置插件的位置,并列出了出现的错误信息。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 移动端常用单位——rem的使用方法和注意事项
    本文介绍了移动端常用的单位rem的使用方法和注意事项,包括px、%、em、vw、vh等其他常用单位的比较。同时还介绍了如何通过JS获取视口宽度并动态调整rem的值,以适应不同设备的屏幕大小。此外,还提到了rem目前在移动端的主流地位。 ... [详细]
  • 树莓派语音控制的配置方法和步骤
    本文介绍了在树莓派上实现语音控制的配置方法和步骤。首先感谢博主Eoman的帮助,文章参考了他的内容。树莓派的配置需要通过sudo raspi-config进行,然后使用Eoman的控制方法,即安装wiringPi库并编写控制引脚的脚本。具体的安装步骤和脚本编写方法在文章中详细介绍。 ... [详细]
  • 如何在HTML中获取鼠标的当前位置
    本文介绍了在HTML中获取鼠标当前位置的三种方法,分别是相对于屏幕的位置、相对于窗口的位置以及考虑了页面滚动因素的位置。通过这些方法可以准确获取鼠标的坐标信息。 ... [详细]
  • 本文介绍了Python语言程序设计中文件和数据格式化的操作,包括使用np.savetext保存文本文件,对文本文件和二进制文件进行统一的操作步骤,以及使用Numpy模块进行数据可视化编程的指南。同时还提供了一些关于Python的测试题。 ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
author-avatar
曹衡斌_307
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有