热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

用Python实现人工神经网络(简易版)

用Python实现人工神经网络(简易版),Go语言社区,Golang程序员人脉社

人工神经网络

    • 人工神经网络简介
    • 代码
    • 神经网络的缺点
    • 程序优化
    • BP算法
    • 补充


人工神经网络简介

人工神经网络(Artificial Neural Network, ANN)是指一系列受生物学和神经学启发的数学模型. 在人工智能领域, 人工神经网络也常常简称为神经网络(Neural Network, NN)或神经模型(Neural Model). 一个简单的多层前馈神经网络如下图.

多层前馈神经网络


代码

构造一个神经网络类, 首先需要将一些变量进行初始化,其中各个层的权重矩阵以及偏置项分别存储在字典中, 键代表层数.

from scipy.special import expit
import numpy as np
class ANN(object):
def __init__(self, innum, outnum, lr, *hide_tuple):
self.innum = innum # 输入节点的个数
self.outnum = outnum # 输出节点的个数
self.lr = lr # 学习率
self.layernum = len(hide_tuple) + 1 # 神经网络的层数
self.Weight = {} # 权重矩阵
self.Bias = {} # 偏置项

# 对权重矩阵和偏置项进行初始化
self.Weight[1] = np.random.normal(0.0, pow(self.innum, -0.5), (hide_tuple[0], self.innum))
for i in range(1, self.layernum):
self.Bias[i] = np.random.randn(hide_tuple[i-1]).reshape(hide_tuple[i-1], 1)
if i>=2:
self.Weight[i] = np.random.normal(0.0, pow(hide_tuple[i-1], -0.5), (hide_tuple[i-1], hide_tuple[i-2]))
self.Weight[self.layernum] = np.random.normal(0.0, pow(hide_tuple[-1], -0.5), (self.outnum, hide_tuple[-1]))
self.Bias[self.layernum] = np.random.randn(self.outnum).reshape(self.outnum, 1)
self.ActiveFunction = lambda x: expit(x) # 激活函数为logistic函数

接下来通过BP算法反向求解误差,进而不断更新权重和偏置项, 其中将所有层的输入和输出分别存入相应的字典中, 键代表相应的层数.

def BPFit(self, input_list, target_list):
Input0 = np.array(input_list, ndmin=2).T
TargetValue = np.array(target_list, ndmin=2).T
Input = {} # 输入值
Output = {} # 输出值
Output[0] = Input0
for i in range(1, self.layernum+1):
Input[i] = np.dot(self.Weight[i], Output[i-1])
Output[i] = self.ActiveFunction(Input[i] + self.Bias[i])
Error={} # 误差项
Error[self.layernum] = Output[self.layernum] * (1 - Output[self.layernum]) * (-(TargetValue - Output[self.layernum]))
self.Weight[self.layernum] -= self.lr * Error[self.layernum] * Output[self.layernum-1].T
self.Bias[self.layernum] -= self.lr * Error[self.layernum]
for i in range(self.layernum-1, 0, -1): # 从倒数第二层的误差项开始
Error[i] = Output[i] * (1 - Output[i]) * np.dot(self.Weight[i+1].T, Error[i+1])
self.Weight[i] -= self.lr * Error[i] * Output[i-1].T
self.Bias[i] -= self.lr * Error[i]

接着再添加一个预测函数, 它是用来实现神经网络预测功能的成员函数.

def predict(self, input_list):
Input0 = np.array(input_list, ndmin=2).T
Input = {}
Output = {}
Output[0] = Input0
for i in range(1, self.layernum + 1):
Input[i] = np.dot(self.Weight[i], Output[i - 1])
Output[i] = self.ActiveFunction(Input[i] + self.Bias[i])
return Output[self.layernum]

最后来测试一下程序是否可以正确运行. 在测试中, 小编任选一个4维的输入向量, 1维的输出向量, 中间添加了四个隐藏层.

if __name__ == '__main__':
"""
测试样例
"""

inode = 4 # 输入节点个数
hnode1 = 4 # 第1层隐节点个数
hnode2 = 5 # 第2层隐节点个数
hnode3 = 10 # 第3层隐节点个数
hnode4 = 5 # 第4层隐节点个数
onodenum = 1 # 输出节点个数
learningrate = 0.3 # 学习率
ann = ANN(inode, onodenum, learningrate, hnode1, hnode2, hnode3, hnode4)
TrainValue = [1, 3, 2, 4]
TargetValue = [0]
ann.BPFit(TrainValue, TargetValue)
TestValue = [2, 4, 2, 4]
predict = ann.predict(TestValue)
print(predict) # 输出预测结果

预测结果为:
在这里插入图片描述这样就实现了简易版神经网络的搭建.


神经网络的缺点


  • 可解释性差

    神经网络相当于一个黑箱模型, 不知道能产生什么结果, 也不知道为什么产生这种结果. 但是决策树可以遵循一定的逻辑, 如果出问题也能想出来原因, 比如银行就不会用神经网络预测一个人的信誉.

  • 耗时耗力

    训练大型的神经网络需要花费大量的时间以及需要大量的内存来对神经网络进行训练

  • 样本量大

    为了得到一个预测能力强的神经网络, 前提是需要大量的样本

  • 信息丢失

    因为在训练神经网络时, 需要把样本都转换成数值型, 在转化的过程中就会用信息丢失.

程序优化

虽然实现了简易版的神经网络, 并且可以任意设置层数以及神经元的个数, 但是为了提高预测的准确率和运行速度, 还可以从许多方面对该程序进行优化. 比如:


  1. 初始权重矩阵以及偏置项

  2. 样本量

  3. 优化器的种类

  4. 优化器的学习率

  5. 神经网络的层数

  6. 神经元的个数

  7. 激活函数

  8. 损失函数

  9. 训练次数

BP算法

本文中的神经网络是基于BP算法来搭建的. 在本文结尾引用的书籍中, 对于该算法有详细的介绍, 在这里小编就不再详细介绍该算法了, 感兴趣的小伙伴们可以了解一下哈~


补充

这是小编的第一篇博文, 也是小编初入互联网大家庭的一个标志. 因为本篇文章主要靠小编自己的理解去写的, 所以可能存在一些不是很恰当的词语, 还请大家多多包容与理解. 最后立个flag, 小编会继续加油, 希望可以早日甩掉技术小白的称号!!!(ps: Python源码已经上传到小编的github)


[1]: 邱锡鹏. 神经网络与深度学习.





推荐阅读
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 关键词:Golang, Cookie, 跟踪位置, net/http/cookiejar, package main, golang.org/x/net/publicsuffix, io/ioutil, log, net/http, net/http/cookiejar ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • 本文介绍了一个Java猜拳小游戏的代码,通过使用Scanner类获取用户输入的拳的数字,并随机生成计算机的拳,然后判断胜负。该游戏可以选择剪刀、石头、布三种拳,通过比较两者的拳来决定胜负。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文详细介绍了Java中vector的使用方法和相关知识,包括vector类的功能、构造方法和使用注意事项。通过使用vector类,可以方便地实现动态数组的功能,并且可以随意插入不同类型的对象,进行查找、插入和删除操作。这篇文章对于需要频繁进行查找、插入和删除操作的情况下,使用vector类是一个很好的选择。 ... [详细]
  • Java学习笔记之面向对象编程(OOP)
    本文介绍了Java学习笔记中的面向对象编程(OOP)内容,包括OOP的三大特性(封装、继承、多态)和五大原则(单一职责原则、开放封闭原则、里式替换原则、依赖倒置原则)。通过学习OOP,可以提高代码复用性、拓展性和安全性。 ... [详细]
  • STL迭代器的种类及其功能介绍
    本文介绍了标准模板库(STL)定义的五种迭代器的种类和功能。通过图表展示了这几种迭代器之间的关系,并详细描述了各个迭代器的功能和使用方法。其中,输入迭代器用于从容器中读取元素,输出迭代器用于向容器中写入元素,正向迭代器是输入迭代器和输出迭代器的组合。本文的目的是帮助读者更好地理解STL迭代器的使用方法和特点。 ... [详细]
author-avatar
手机用户2502852661
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有