热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【人工智能作业】3例题程序复现PyTorch版

1.使用pytorch复现课上例题代码如下importtorchx1,x2torch.Tensor([0.5]),torch.Tensor([0.3])y1,y2torch.

1.使用pytorch复现课上例题

代码如下

import torchx1, x2 = torch.Tensor([0.5]), torch.Tensor([0.3])
y1, y2 = torch.Tensor([0.23]), torch.Tensor([-0.07])
print("=====输入值:x1, x2;真实输出值:y1, y2=====")
print(x1, x2, y1, y2)
w1, w2, w3, w4, w5, w6, w7, w8 = torch.Tensor([0.2]), torch.Tensor([-0.4]), torch.Tensor([0.5]), torch.Tensor([0.6]), torch.Tensor([0.1]), torch.Tensor([-0.5]), torch.Tensor([-0.3]), torch.Tensor([0.8]) # 权重初始值
w1.requires_grad = True
w2.requires_grad = True
w3.requires_grad = True
w4.requires_grad = True
w5.requires_grad = True
w6.requires_grad = True
w7.requires_grad = True
w8.requires_grad = Truedef sigmoid(z):a = 1 / (1 + torch.exp(-z))return adef forward_propagate(x1, x2):in_h1 = w1 * x1 + w3 * x2out_h1 = sigmoid(in_h1) # out_h1 = torch.sigmoid(in_h1)in_h2 = w2 * x1 + w4 * x2out_h2 = sigmoid(in_h2) # out_h2 = torch.sigmoid(in_h2)in_o1 = w5 * out_h1 + w7 * out_h2out_o1 = sigmoid(in_o1) # out_o1 = torch.sigmoid(in_o1)in_o2 = w6 * out_h1 + w8 * out_h2out_o2 = sigmoid(in_o2) # out_o2 = torch.sigmoid(in_o2)print("正向计算:o1 ,o2")print(out_o1.data, out_o2.data)return out_o1, out_o2def loss_fuction(x1, x2, y1, y2): # 损失函数y1_pred, y2_pred = forward_propagate(x1, x2) # 前向传播loss = (1 / 2) * (y1_pred - y1) ** 2 + (1 / 2) * (y2_pred - y2) ** 2 # 考虑 : t.nn.MSELoss()print("损失函数(均方误差):", loss.item())return lossdef update_w(w1, w2, w3, w4, w5, w6, w7, w8):# 步长step = 1w1.data = w1.data - step * w1.grad.dataw2.data = w2.data - step * w2.grad.dataw3.data = w3.data - step * w3.grad.dataw4.data = w4.data - step * w4.grad.dataw5.data = w5.data - step * w5.grad.dataw6.data = w6.data - step * w6.grad.dataw7.data = w7.data - step * w7.grad.dataw8.data = w8.data - step * w8.grad.dataw1.grad.data.zero_() # 注意:将w中所有梯度清零w2.grad.data.zero_()w3.grad.data.zero_()w4.grad.data.zero_()w5.grad.data.zero_()w6.grad.data.zero_()w7.grad.data.zero_()w8.grad.data.zero_()return w1, w2, w3, w4, w5, w6, w7, w8if __name__ == "__main__":print("=====更新前的权值=====")print(w1.data, w2.data, w3.data, w4.data, w5.data, w6.data, w7.data, w8.data)for i in range(1000):print("=====第" + str(i) + "轮=====")L = loss_fuction(x1, x2, y1, y2) # 前向传播,求 Loss,构建计算图L.backward() # 自动求梯度,不需要人工编程实现。反向传播,求出计算图中所有梯度存入w中print("\tgrad W: ", round(w1.grad.item(), 2), round(w2.grad.item(), 2), round(w3.grad.item(), 2),round(w4.grad.item(), 2), round(w5.grad.item(), 2), round(w6.grad.item(), 2), round(w7.grad.item(), 2),round(w8.grad.item(), 2))w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w1, w2, w3, w4, w5, w6, w7, w8)print("更新后的权值")print(w1.data, w2.data, w3.data, w4.data, w5.data, w6.data, w7.data, w8.data)

训练前参数为(图片文字过小,使用文本)



 =====输入值:x1, x2;真实输出值:y1, y2=====
tensor([0.5000]) tensor([0.3000]) tensor([0.2300]) tensor([-0.0700])
=====更新前的权值=====
tensor([0.2000]) tensor([-0.4000]) tensor([0.5000]) tensor([0.6000]) tensor([0.1000]) tensor([-0.5000]) tensor([-0.3000]) tensor([0.8000])



 第一轮运行结果如下


作业二中的第一轮运行结果


第一千轮运行结果如下(图片文字过小,使用文本) 



 =====第999轮=====
正向计算:o1 ,o2
tensor([0.2296]) tensor([0.0098])
损失函数(均方误差): 0.0031851977109909058
    grad W:  -0.0 -0.0 -0.0 -0.0 -0.0 0.0 -0.0 0.0
更新后的权值
tensor([1.6515]) tensor([0.1770]) tensor([1.3709]) tensor([0.9462]) tensor([-0.7798]) tensor([-4.2741]) tensor([-1.0236]) tensor([-2.1999])




 作业二中第一千轮运行结果



2.对比【作业3】和【作业2】的程序,观察两种方法结果是否相同?如果不同,哪个正确?

结果不相同,作业三正确。 


3.【作业2】程序更新(保留【作业2中】的错误答案,留作对比。新程序到作业3)

 更新后的反向传播函数代码

def back_propagate(out_o1, out_o2, out_h1, out_h2):# 反向传播d_o1 = out_o1 - y1d_o2 = out_o2 - y2d_w5 = d_o1 * out_o1 * (1 - out_o1) * out_h1d_w7 = d_o1 * out_o1 * (1 - out_o1) * out_h2d_w6 = d_o2 * out_o2 * (1 - out_o2) * out_h1d_w8 = d_o2 * out_o2 * (1 - out_o2) * out_h2d_w1 = (d_o1 * out_h1 * (1 - out_h1) * w5 + d_o2 * out_o2 * (1 - out_o2) * w6) * out_h1 * (1 - out_h1) * x1d_w3 = (d_o1 * out_h1 * (1 - out_h1) * w5 + d_o2 * out_o2 * (1 - out_o2) * w6) * out_h1 * (1 - out_h1) * x2d_w2 = (d_o1 * out_h1 * (1 - out_h1) * w7 + d_o2 * out_o2 * (1 - out_o2) * w8) * out_h2 * (1 - out_h2) * x1d_w4 = (d_o1 * out_h1 * (1 - out_h1) * w7 + d_o2 * out_o2 * (1 - out_o2) * w8) * out_h2 * (1 - out_h2) * x2print("w的梯度:", round(d_w1, 2), round(d_w2, 2), round(d_w3, 2), round(d_w4, 2), round(d_w5, 2), round(d_w6, 2),round(d_w7, 2), round(d_w8, 2))return d_w1, d_w2, d_w3, d_w4, d_w5, d_w6, d_w7, d_w8

4. 对比【作业2】与【作业3】的反向传播的实现方法。总结并陈述

作业2中的方法通过手动计算,得到反向传播过程中各参数梯度。作业3通过张量Tensor求Loss,构建计算图,在最后通过backword()自动求梯度。由于使用了Pytorch当中的Tensor张量,使用前向传播建立计算图,即可根据传播结果利用back_propagate自动计算梯度。激活函数使用Pytorch自带的torch.sigmoid()。


5.激活函数Sigmoid用PyTorch自带函数torch.sigmoid(),观察、总结并陈述

二者使用同一个公式

sigmoid(x)=\frac{1}{1+e^{-x}}


10.权值w1-w8初始值换为随机数,对比【作业2】指定权值结果,观察、总结并陈述

 随机数代码如下

w1, w2, w3, w4, w5, w6, w7, w8 = torch.rand(1, 1), torch.rand(1, 1), torch.rand(1, 1), torch.rand(1, 1), torch.rand(1, 1), torch.rand(1, 1), torch.rand(1, 1), torch.rand(1, 1)

观测结果,前后改变不大。


推荐阅读
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 在springmvc框架中,前台ajax调用方法,对图片批量下载,如何弹出提示保存位置选框?Controller方法 ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • Webpack5内置处理图片资源的配置方法
    本文介绍了在Webpack5中处理图片资源的配置方法。在Webpack4中,我们需要使用file-loader和url-loader来处理图片资源,但是在Webpack5中,这两个Loader的功能已经被内置到Webpack中,我们只需要简单配置即可实现图片资源的处理。本文还介绍了一些常用的配置方法,如匹配不同类型的图片文件、设置输出路径等。通过本文的学习,读者可以快速掌握Webpack5处理图片资源的方法。 ... [详细]
  • Commit1ced2a7433ea8937a1b260ea65d708f32ca7c95eintroduceda+Clonetraitboundtom ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文介绍了基于c语言的mcs51单片机定时器计数器的应用教程,包括定时器的设置和计数方法,以及中断函数的使用。同时介绍了定时器应用的举例,包括定时器中断函数的编写和频率值的计算方法。主函数中设置了T0模式和T1计数的初值,并开启了T0和T1的中断,最后启动了CPU中断。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • Spring常用注解(绝对经典),全靠这份Java知识点PDF大全
    本文介绍了Spring常用注解和注入bean的注解,包括@Bean、@Autowired、@Inject等,同时提供了一个Java知识点PDF大全的资源链接。其中详细介绍了ColorFactoryBean的使用,以及@Autowired和@Inject的区别和用法。此外,还提到了@Required属性的配置和使用。 ... [详细]
  • 本文讨论了如何使用GStreamer来删除H264格式视频文件中的中间部分,而不需要进行重编码。作者提出了使用gst_element_seek(...)函数来实现这个目标的思路,并提到遇到了一个解决不了的BUG。文章还列举了8个解决方案,希望能够得到更好的思路。 ... [详细]
  • 基于移动平台的会展导游系统APP设计与实现的技术介绍与需求分析
    本文介绍了基于移动平台的会展导游系统APP的设计与实现过程。首先,对会展经济和移动互联网的概念进行了简要介绍,并阐述了将会展引入移动互联网的意义。接着,对基础技术进行了介绍,包括百度云开发环境、安卓系统和近场通讯技术。然后,进行了用户需求分析和系统需求分析,并提出了系统界面运行流畅和第三方授权等需求。最后,对系统的概要设计进行了详细阐述,包括系统前端设计和交互与原型设计。本文对基于移动平台的会展导游系统APP的设计与实现提供了技术支持和需求分析。 ... [详细]
author-avatar
D调肥仔
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有