热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

MATLAB实现单变量线性回归

本文介绍了如何在MATLAB中实现单变量线性回归,这是基于Coursera上AndrewNg教授的机器学习课程中的一个实践项目。文章详细讲解了从数据可视化到模型训练的每一个步骤。
### MATLAB实现单变量线性回归

#### 1. 数据可视化

此部分是Coursera上Andrew Ng教授的《机器学习》课程中的一个作业,提供了一个完整的框架,只需实现几个关键函数即可。首先,我们需要绘制原始数据图。

```matlab
fprintf('绘制数据...
');
data = load('ex1data1.txt');
X = data(:, 1); y = data(:, 2);
m = length(y); % 训练样本数量

% 绘制数据点
% 注意:需要在plotData.m文件中完成绘图代码
plotData(X, y);

fprintf('程序暂停。按回车继续。
');
pause;
```

在这个步骤中,我们使用`load`函数加载数据,并调用自定义的`plotData`函数来绘制数据点。

```matlab
function plotData(x, y)
% PLOTDATA 绘制数据点x和y
% PLOTDATA(x, y) 绘制数据点,并设置坐标轴标签为人口和利润

figure; % 打开一个新的图形窗口
plot(x, y, 'rx', 'MarkerSize', 10);
ylabel('利润($10,000)'); % 设置y轴标签
xlabel('城市人口(10,000人)'); % 设置x轴标签
end
```

#### 2. 损失函数与梯度下降

绘制完原始数据图后,接下来是线性回归的核心部分——梯度下降算法。

```matlab
% =================== Part 2: 成本函数和梯度下降 ===================

X = [ones(m, 1), data(:, 1)]; % 在X前添加一列1
theta = zeros(2, 1); % 初始化拟合参数

% 梯度下降设置
iteratiOns= 1500;
alpha = 0.01;

fprintf('\n测试成本函数...
');

% 计算并显示初始成本
J = computeCost(X, y, theta);
fprintf('当theta = [0 ; 0]\n计算的成本 = %f\n', J);
fprintf('预期成本值(近似) 32.07\n');

% 进一步测试成本函数
J = computeCost(X, y, [-1 ; 2]);
fprintf('\n当theta = [-1 ; 2]\n计算的成本 = %f\n', J);
fprintf('预期成本值(近似) 54.24\n');

fprintf('程序暂停。按回车继续。
');
pause;

fprintf('\n运行梯度下降...
');

% 运行梯度下降
theta = gradientDescent(X, y, theta, alpha, iterations);

% 显示theta
fprintf('通过梯度下降找到的theta:\n');
fprintf('%f\n', theta);
fprintf('预期的theta值(近似)\n');
fprintf(' -3.6303\n 1.1664\n\n');

% 绘制线性拟合曲线
hold on; % 保持之前的图形可见
plot(X(:, 2), X * theta, '-');
legend('训练数据', '线性回归');
hold off; % 不再覆盖当前图形

% 预测人口规模为35,000和70,000时的利润
predict1 = [1, 3.5] * theta;
fprintf('对于人口 = 35,000,预测利润为 %f\n', predict1 * 10000);
predict2 = [1, 7] * theta;
fprintf('对于人口 = 70,000,预测利润为 %f\n', predict2 * 10000);

fprintf('程序暂停。按回车继续。
');
pause;
```

这段代码首先对输入的X进行了处理,添加了一列1,以便将常数项theta0与之相乘。然后,我们测试了损失函数`computeCost`的正确性。

```matlab
function J = computeCost(X, y, theta)
% COMPUTECOST 计算线性回归的成本
% J = COMPUTECOST(X, y, theta) 计算使用theta作为参数进行线性回归的成本

m = length(y); % 训练样本数量
J = 0;

for i = 1:m
J = J + ((X(i, :) * theta - y(i))^2) / (2 * m);
end

end
```

接下来,我们使用梯度下降函数`gradientDescent`来更新theta,并记录每次迭代的成本变化。

```matlab
function [theta, J_history] = gradientDescent(X, y, theta, alpha, num_iters)
% GRADIENTDESCENT 执行梯度下降以学习theta
% theta = GRADIENTDESCENT(X, y, theta, alpha, num_iters) 通过梯度下降更新theta

m = length(y); % 训练样本数量
J_history = zeros(num_iters, 1);

for iter = 1:num_iters
% 执行一次梯度下降步骤
temp0 = theta(1) - alpha * sum(X * theta - y) / m;
temp1 = theta(2) - alpha * (X(:, 2)' * (X * theta - y)) / m;
theta(1) = temp0;
theta(2) = temp1;

% 保存每次迭代的成本
J_history(iter) = computeCost(X, y, theta);
end

end
```

#### 3. 损失函数的可视化

最后,我们绘制损失函数和theta的关系图,并画出等高线,标记出最终得到的theta。

```matlab
% ============= Part 3: 可视化J(theta_0, theta_1) =============
fprintf('可视化J(theta_0, theta_1)...
');

% 创建网格以计算J
theta0_vals = linspace(-10, 10, 100);
theta1_vals = linspace(-1, 4, 100);

% 初始化J_vals为0矩阵
J_vals = zeros(length(theta0_vals), length(theta1_vals));

% 填充J_vals
for i = 1:length(theta0_vals)
for j = 1:length(theta1_vals)
t = [theta0_vals(i); theta1_vals(j)];
J_vals(i, j) = computeCost(X, y, t);
end
end

% 因为surf命令的工作方式,需要转置J_vals
J_vals = J_vals';

% 绘制表面图
figure;
surf(theta0_vals, theta1_vals, J_vals);
xlabel('\theta_0'); ylabel('\theta_1');

% 绘制等高线图
figure;
contour(theta0_vals, theta1_vals, J_vals, logspace(-2, 3, 20));
xlabel('\theta_0'); ylabel('\theta_1');
hold on;
plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);
```

通过这些步骤,我们完成了单变量线性回归的实现。整个过程相对简单,主要得益于提供的框架和详细的指导。
推荐阅读
  • 基因组浏览器中的Wig格式解析
    本文详细介绍了Wiggle(Wig)格式及其在基因组浏览器中的应用,涵盖variableStep和fixedStep两种主要格式的特点、适用场景及具体使用方法。同时,还提供了关于数据值和自定义参数的补充信息。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 本文介绍了如何使用JQuery实现省市二级联动和表单验证。首先,通过change事件监听用户选择的省份,并动态加载对应的城市列表。其次,详细讲解了使用Validation插件进行表单验证的方法,包括内置规则、自定义规则及实时验证功能。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文介绍了在Windows环境下使用pydoc工具的方法,并详细解释了如何通过命令行和浏览器查看Python内置函数的文档。此外,还提供了关于raw_input和open函数的具体用法和功能说明。 ... [详细]
  • 本文详细介绍了中央电视台电影频道的节目预告,并通过专业工具分析了其加载方式,确保用户能够获取最准确的电视节目信息。 ... [详细]
  • 本文详细介绍如何在VSCode中配置自定义代码片段,使其具备与IDEA相似的代码生成快捷键功能。通过具体的Java和HTML代码片段示例,展示配置步骤及效果。 ... [详细]
  • 本文详细介绍了如何使用 Yii2 的 GridView 组件在列表页面实现数据的直接编辑功能。通过具体的代码示例和步骤,帮助开发者快速掌握这一实用技巧。 ... [详细]
  • C++: 实现基于类的四面体体积计算
    本文介绍如何使用C++编程语言,通过定义类和方法来计算由四个三维坐标点构成的四面体体积。文中详细解释了四面体体积的数学公式,并提供了两种不同的实现方式。 ... [详细]
  • 本章将深入探讨移动 UI 设计的核心原则,帮助开发者构建简洁、高效且用户友好的界面。通过学习设计规则和用户体验优化技巧,您将能够创建出既美观又实用的移动应用。 ... [详细]
author-avatar
游你精彩_980_469
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有