热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

神经网络字符识别matlab程序,用matlab实现神经网络识别数字|学步园

AndrewNg机器学习第四周的编程练习是用matlab实现一个神经网络对一幅图中的数字进行识别,有待识别的数字全集如下:其中每一个数字都是一个大小为2

Andrew Ng机器学习第四周的编程练习是用matlab实现一个神经网络对一幅图中的数字进行识别,有待识别的数字全集如下:

27ec67dea300c3ec54ce58cb0f503ca2.png

其中每一个数字都是一个大小为20*20像素的图像,如果把每个像素作为一个输入单元,那有400个输入。考虑到神经网络还需要增加一个额外输入单元表示偏差,一共有401个输入单元。题目中给的训练数据X是一个5000*400的向量。

题目中要求包含一个25个节点的隐藏层,隐藏层也存在表示偏差的额外输入,所以一共有26个输入。

最终的输出结果是一个10维的向量,分别表示该数字在0-9上面的概率值(由于没有0这个下标位,这里题目中把0标记为10,其余1-9还是对应1-9),找到其中概率最大的就是要识别的结果。

神经网络的结构如下:

e0a4127de02530ff296d668bf877af5f.png

从上图可以看到,神经网络中除了输入参数外,还包含Theta1和Theta2两个参数。

其中的Theta1就表示输入层到隐含层中每条边的权重,为25*401的向量。Theta2是隐含层到输出层每条边的权重,为10*26的向量。

为了把数据标准化减少误差,这里要对每一步的输出用sigmoid函数进行处理。

构造好神经网络后,首先是用训练数据进行训练,得出Theta1和Theta2的权重信息,然后就可以预测了。

主要的matlab代码如下:

%% Machine Learning Online Class - Exercise 3 | Part 2: Neural Networks

% Instructions

% ------------

%

% This file contains code that helps you get started on the

% linear exercise. You will need to complete the following functions

% in this exericse:

%

% lrCostFunction.m (logistic regression cost function)

% oneVsAll.m

% predictOneVsAll.m

% predict.m

%

% For this exercise, you will not need to change any code in this file,

% or any other files other than those mentioned above.

%

%% Initialization

clear ; close all; clc

%% Setup the parameters you will use for this exercise

input_layer_size = 400; % 20x20 Input Images of Digits

hidden_layer_size = 25; % 25 hidden units

num_labels = 10; % 10 labels, from 1 to 10

% (note that we have mapped "0" to label 10)

%% =========== Part 1: Loading and Visualizing Data =============

% We start the exercise by first loading and visualizing the dataset.

% You will be working with a dataset that contains handwritten digits.

%

% Load Training Data

fprintf('Loading and Visualizing Data ...\n')

load('ex3data1.mat');

m = size(X, 1);

% Randomly select 100 data points to display

sel = randperm(size(X, 1));

sel = sel(1:100);

displayData(X(sel, :));

fprintf('Program paused. Press enter to continue.\n');

pause;

%% ================ Part 2: Loading Pameters ================

% In this part of the exercise, we load some pre-initialized

% neural network parameters.

fprintf('\nLoading Saved Neural Network Parameters ...\n')

% Load the weights into variables Theta1 and Theta2

load('ex3weights.mat');

%% ================= Part 3: Implement Predict =================

% After training the neural network, we would like to use it to predict

% the labels. You will now implement the "predict" function to use the

% neural network to predict the labels of the training set. This lets

% you compute the training set accuracy.

pred = predict(Theta1, Theta2, X);

fprintf('\nTraining Set Accuracy: %f\n', mean(double(pred == y)) * 100);

fprintf('Program paused. Press enter to continue.\n');

pause;

% To give you an idea of the network's output, you can also run

% through the examples one at the a time to see what it is predicting.

% Randomly permute examples

rp = randperm(m);

for i = 1:m

% Display

fprintf('\nDisplaying Example Image\n');

displayData(X(rp(i), :));

pred = predict(Theta1, Theta2, X(rp(i),:));

fprintf('\nNeural Network Prediction: %d (digit %d)\n', pred, mod(pred, 10));

% Pause

fprintf('Program paused. Press enter to continue.\n');

pause;

end

预测函数如下:

function p = predict(Theta1, Theta2, X)

%PREDICT Predict the label of an input given a trained neural network

% p = PREDICT(Theta1, Theta2, X) outputs the predicted label of X given the

% trained weights of a neural network (Theta1, Theta2)

% Useful values

m = size(X, 1);

num_labels = size(Theta2, 1);

% You need to return the following variables correctly

p = zeros(size(X, 1), 1);

% ====================== YOUR CODE HERE ======================

% Instructions: Complete the following code to make predictions using

% your learned neural network. You should set p to a

% vector containing labels between 1 to num_labels.

%

% Hint: The max function might come in useful. In particular, the max

% function can also return the index of the max element, for more

% information see 'help max'. If your examples are in rows, then, you

% can use max(A, [], 2) to obtain the max for each row.

%

X = [ones(m, 1) X];

predictZ=X*Theta1';

predictZ=sigmoid(predictZ);

predictZ=[ones(m,1) predictZ];

predictZZ=predictZ*Theta2';

predictY=sigmoid(predictZZ);

[mp,imp]=max(predictY,[],2);

p=imp;

% =========================================================================

end

最终运算截图如下:

b8bc6d4d1114007648c4af3de4e4de51.png

最后与回归分析做个对比:

回归分析需要对每个数字训练一个分类器,这里需要训练10个,每一个分类器迭代50次,结果为:

959d2bf261bbf698fcf770d991840c6e.png

显然神经网络准确率要比回归分析高,同时也要显得简洁很多,代码行数也明显减少,这也正是神经网络的优势所在。



推荐阅读
  • 扫描线三巨头 hdu1928hdu 1255  hdu 1542 [POJ 1151]
    学习链接:http:blog.csdn.netlwt36articledetails48908031学习扫描线主要学习的是一种扫描的思想,后期可以求解很 ... [详细]
  • golang常用库:配置文件解析库/管理工具viper使用
    golang常用库:配置文件解析库管理工具-viper使用-一、viper简介viper配置管理解析库,是由大神SteveFrancia开发,他在google领导着golang的 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • UNP 第9章:主机名与地址转换
    本章探讨了用于在主机名和数值地址之间进行转换的函数,如gethostbyname和gethostbyaddr。此外,还介绍了getservbyname和getservbyport函数,用于在服务器名和端口号之间进行转换。 ... [详细]
  • ImmutableX Poised to Pioneer Web3 Gaming Revolution
    ImmutableX is set to spearhead the evolution of Web3 gaming, with its innovative technologies and strategic partnerships driving significant advancements in the industry. ... [详细]
  • 题目Link题目学习link1题目学习link2题目学习link3%%%受益匪浅!-----&# ... [详细]
  • 深入理解 SQL 视图、存储过程与事务
    本文详细介绍了SQL中的视图、存储过程和事务的概念及应用。视图为用户提供了一种灵活的数据查询方式,存储过程则封装了复杂的SQL逻辑,而事务确保了数据库操作的完整性和一致性。 ... [详细]
  • 本文介绍了如何使用JQuery实现省市二级联动和表单验证。首先,通过change事件监听用户选择的省份,并动态加载对应的城市列表。其次,详细讲解了使用Validation插件进行表单验证的方法,包括内置规则、自定义规则及实时验证功能。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文介绍了在Windows环境下使用pydoc工具的方法,并详细解释了如何通过命令行和浏览器查看Python内置函数的文档。此外,还提供了关于raw_input和open函数的具体用法和功能说明。 ... [详细]
  • 本文详细介绍了C语言中链表的两种动态创建方法——头插法和尾插法,包括具体的实现代码和运行示例。通过这些内容,读者可以更好地理解和掌握链表的基本操作。 ... [详细]
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
author-avatar
可怜小淖_135
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有