作者:曹衡斌_307 | 来源:互联网 | 2023-08-31 13:51
【1】本例程通过对一个二次函数进行拟合,包含数据集的生成,神经网络的搭建,神经网络的训练。
【2】代码
# -*- coding: utf-8 -*-##-------------------------------------------------------------------------------
# Name: huiguitest
# Description:
# Author: Administrator
# Date: 2020/11/28
#-------------------------------------------------------------------------------import torch
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F # 激励函数都在这#linspace函数的作用是,返回一个一维的tensor(张量),这个张量包含了从start到end,分成steps个线段得到的向量。
#torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度
#(1,3)变为3行
#torch.unsqueeze()这个函数主要是对数据维度进行扩充
#3行变为(1,3)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
#plt.show()#建立神经网络
#创建一个类
class Net(torch.nn.Module): # 继承 torch 的 Moduledef __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__() # 继承 __init__ 功能# 定义每层用什么样的形式self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层线性输出self.predict = torch.nn.Linear(n_hidden, n_output) # 输出层线性输出def forward(self, x): # 这同时也是 Module 中的 forward 功能# 正向传播输入值, 神经网络分析出输出值x = F.relu(self.hidden(x)) # 激励函数(隐藏层的线性值)x = self.predict(x) # 输出值return xnet = Net(n_feature=1, n_hidden=10, n_output=1)print(net) # net 的结构
"""
Net ((hidden): Linear (1 -> 10)(predict): Linear (10 -> 1)
)
"""#训练网络
# optimizer 是训练的工具
#定义损失函数和梯度下降方法
optimizer = torch.optim.SGD(net.parameters(), lr=0.2) # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss() # 预测值和真实值的误差计算公式 (均方差)for t in range(300):prediction = net(x) # 喂给 net 训练数据 x, 输出预测值loss = loss_func(prediction, y) # 计算两者的误差optimizer.zero_grad() # 清空上一步的残余更新参数值loss.backward() # 误差反向传播, 计算参数更新值optimizer.step() # 将参数更新值施加到 net 的 parameters 上if t % 5 == 0:# plot and show learning processplt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fOntdict={'size': 20, 'color': 'red'})plt.pause(0.1)plt.ion() # 画图
plt.show()
plt.pause(10000)
【3】效果图
![](https://img2.php1.cn/3cdc5/3984/807/95d7ea10e666f196.png)
![](https://img2.php1.cn/3cdc5/3984/807/c0e0a3aa7d90f2cf.png)