热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

神经网络字符识别matlab程序,用matlab实现神经网络识别数字|学步园

AndrewNg机器学习第四周的编程练习是用matlab实现一个神经网络对一幅图中的数字进行识别,有待识别的数字全集如下:其中每一个数字都是一个大小为2

Andrew Ng机器学习第四周的编程练习是用matlab实现一个神经网络对一幅图中的数字进行识别,有待识别的数字全集如下:

27ec67dea300c3ec54ce58cb0f503ca2.png

其中每一个数字都是一个大小为20*20像素的图像,如果把每个像素作为一个输入单元,那有400个输入。考虑到神经网络还需要增加一个额外输入单元表示偏差,一共有401个输入单元。题目中给的训练数据X是一个5000*400的向量。

题目中要求包含一个25个节点的隐藏层,隐藏层也存在表示偏差的额外输入,所以一共有26个输入。

最终的输出结果是一个10维的向量,分别表示该数字在0-9上面的概率值(由于没有0这个下标位,这里题目中把0标记为10,其余1-9还是对应1-9),找到其中概率最大的就是要识别的结果。

神经网络的结构如下:

e0a4127de02530ff296d668bf877af5f.png

从上图可以看到,神经网络中除了输入参数外,还包含Theta1和Theta2两个参数。

其中的Theta1就表示输入层到隐含层中每条边的权重,为25*401的向量。Theta2是隐含层到输出层每条边的权重,为10*26的向量。

为了把数据标准化减少误差,这里要对每一步的输出用sigmoid函数进行处理。

构造好神经网络后,首先是用训练数据进行训练,得出Theta1和Theta2的权重信息,然后就可以预测了。

主要的matlab代码如下:

%% Machine Learning Online Class - Exercise 3 | Part 2: Neural Networks

% Instructions

% ------------

%

% This file contains code that helps you get started on the

% linear exercise. You will need to complete the following functions

% in this exericse:

%

% lrCostFunction.m (logistic regression cost function)

% oneVsAll.m

% predictOneVsAll.m

% predict.m

%

% For this exercise, you will not need to change any code in this file,

% or any other files other than those mentioned above.

%

%% Initialization

clear ; close all; clc

%% Setup the parameters you will use for this exercise

input_layer_size = 400; % 20x20 Input Images of Digits

hidden_layer_size = 25; % 25 hidden units

num_labels = 10; % 10 labels, from 1 to 10

% (note that we have mapped "0" to label 10)

%% =========== Part 1: Loading and Visualizing Data =============

% We start the exercise by first loading and visualizing the dataset.

% You will be working with a dataset that contains handwritten digits.

%

% Load Training Data

fprintf('Loading and Visualizing Data ...\n')

load('ex3data1.mat');

m = size(X, 1);

% Randomly select 100 data points to display

sel = randperm(size(X, 1));

sel = sel(1:100);

displayData(X(sel, :));

fprintf('Program paused. Press enter to continue.\n');

pause;

%% ================ Part 2: Loading Pameters ================

% In this part of the exercise, we load some pre-initialized

% neural network parameters.

fprintf('\nLoading Saved Neural Network Parameters ...\n')

% Load the weights into variables Theta1 and Theta2

load('ex3weights.mat');

%% ================= Part 3: Implement Predict =================

% After training the neural network, we would like to use it to predict

% the labels. You will now implement the "predict" function to use the

% neural network to predict the labels of the training set. This lets

% you compute the training set accuracy.

pred = predict(Theta1, Theta2, X);

fprintf('\nTraining Set Accuracy: %f\n', mean(double(pred == y)) * 100);

fprintf('Program paused. Press enter to continue.\n');

pause;

% To give you an idea of the network's output, you can also run

% through the examples one at the a time to see what it is predicting.

% Randomly permute examples

rp = randperm(m);

for i = 1:m

% Display

fprintf('\nDisplaying Example Image\n');

displayData(X(rp(i), :));

pred = predict(Theta1, Theta2, X(rp(i),:));

fprintf('\nNeural Network Prediction: %d (digit %d)\n', pred, mod(pred, 10));

% Pause

fprintf('Program paused. Press enter to continue.\n');

pause;

end

预测函数如下:

function p = predict(Theta1, Theta2, X)

%PREDICT Predict the label of an input given a trained neural network

% p = PREDICT(Theta1, Theta2, X) outputs the predicted label of X given the

% trained weights of a neural network (Theta1, Theta2)

% Useful values

m = size(X, 1);

num_labels = size(Theta2, 1);

% You need to return the following variables correctly

p = zeros(size(X, 1), 1);

% ====================== YOUR CODE HERE ======================

% Instructions: Complete the following code to make predictions using

% your learned neural network. You should set p to a

% vector containing labels between 1 to num_labels.

%

% Hint: The max function might come in useful. In particular, the max

% function can also return the index of the max element, for more

% information see 'help max'. If your examples are in rows, then, you

% can use max(A, [], 2) to obtain the max for each row.

%

X = [ones(m, 1) X];

predictZ=X*Theta1';

predictZ=sigmoid(predictZ);

predictZ=[ones(m,1) predictZ];

predictZZ=predictZ*Theta2';

predictY=sigmoid(predictZZ);

[mp,imp]=max(predictY,[],2);

p=imp;

% =========================================================================

end

最终运算截图如下:

b8bc6d4d1114007648c4af3de4e4de51.png

最后与回归分析做个对比:

回归分析需要对每个数字训练一个分类器,这里需要训练10个,每一个分类器迭代50次,结果为:

959d2bf261bbf698fcf770d991840c6e.png

显然神经网络准确率要比回归分析高,同时也要显得简洁很多,代码行数也明显减少,这也正是神经网络的优势所在。



推荐阅读
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • ImmutableX Poised to Pioneer Web3 Gaming Revolution
    ImmutableX is set to spearhead the evolution of Web3 gaming, with its innovative technologies and strategic partnerships driving significant advancements in the industry. ... [详细]
  • 本章将深入探讨移动 UI 设计的核心原则,帮助开发者构建简洁、高效且用户友好的界面。通过学习设计规则和用户体验优化技巧,您将能够创建出既美观又实用的移动应用。 ... [详细]
  • 扫描线三巨头 hdu1928hdu 1255  hdu 1542 [POJ 1151]
    学习链接:http:blog.csdn.netlwt36articledetails48908031学习扫描线主要学习的是一种扫描的思想,后期可以求解很 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文探讨了Hive中内部表和外部表的区别及其在HDFS上的路径映射,详细解释了两者的创建、加载及删除操作,并提供了查看表详细信息的方法。通过对比这两种表类型,帮助读者理解如何更好地管理和保护数据。 ... [详细]
  • XNA 3.0 游戏编程:从 XML 文件加载数据
    本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ... [详细]
  • UNP 第9章:主机名与地址转换
    本章探讨了用于在主机名和数值地址之间进行转换的函数,如gethostbyname和gethostbyaddr。此外,还介绍了getservbyname和getservbyport函数,用于在服务器名和端口号之间进行转换。 ... [详细]
  • 机器学习中的相似度度量与模型优化
    本文探讨了机器学习中常见的相似度度量方法,包括余弦相似度、欧氏距离和马氏距离,并详细介绍了如何通过选择合适的模型复杂度和正则化来提高模型的泛化能力。此外,文章还涵盖了模型评估的各种方法和指标,以及不同分类器的工作原理和应用场景。 ... [详细]
  • 本文详细介绍了C语言中链表的两种动态创建方法——头插法和尾插法,包括具体的实现代码和运行示例。通过这些内容,读者可以更好地理解和掌握链表的基本操作。 ... [详细]
  • 深入探讨CPU虚拟化与KVM内存管理
    本文详细介绍了现代服务器架构中的CPU虚拟化技术,包括SMP、NUMA和MPP三种多处理器结构,并深入探讨了KVM的内存虚拟化机制。通过对比不同架构的特点和应用场景,帮助读者理解如何选择最适合的架构以优化性能。 ... [详细]
  • 本题探讨如何通过最大流算法解决农场排水系统的设计问题。题目要求计算从水源点到汇合点的最大水流速率,使用经典的EK(Edmonds-Karp)和Dinic算法进行求解。 ... [详细]
author-avatar
可怜小淖_135
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有