热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

技术分享:线性回归模型的双路径构建——基于sklearn库的实践探索

篇首语:本文由编程笔记#小编为大家整理,主要介绍了一线性回归的两种实现方式:sklearn实现相关的知识,希望对你有一定的参考价值。

篇首语:本文由编程笔记#小编为大家整理,主要介绍了一线性回归的两种实现方式:sklearn实现相关的知识,希望对你有一定的参考价值。






线性回归的sklearn实现

导入必要的模块

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

数据集

x = np.array([50, 30, 15, 40, 55, 20, 45, 10, 60, 25])
y = np.array([5.9, 4.6, 2.7, 4.8, 6.5, 3.6, 5.1, 2.0, 6.3, 3.8])

画出数据集的散点图

plt.scatter(x, y)
plt.grid(True)
plt.xlabel('area')
plt.ylabel('price')
plt.show()

在这里插入图片描述


数据划分

划分训练集和测试集

使用到的api:

数据划分sklearn.model_selection.train_test_split

用到的参数:


  • *arrays:输入数据集。

  • test_size:划分出来的测试集占总数据量的比例,取值0~1。

  • shuffle:是否在划分前打乱数据的顺序,默认True。

  • random_state:shuffle的随机种子,取值正整数。

返回:


  • splitting:列表包含划分后的训练集与测试集。

x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=0.3, shuffle=True, random_state=23)

查看训练集的散点图

plt.scatter(x_train,y_train)
plt.grid('True')
plt.xlabel('area')
plt.ylabel('price')
plt.show()

在这里插入图片描述

查看测试集的散点图

plt.scatter(x_test,y_test)
plt.grid('True')
plt.xlabel('area')
plt.ylabel('price')
plt.show()

在这里插入图片描述


模型搭建

使用到的api:

线性回归sklearn.linear_model.LinearRegression

model = LinearRegression()

模型训练

使用到的api:

线性回归模型训练sklearn.linear_model.LinearRegression.fit

用到的参数:


  • X:输入特征,如果输入是np.array格式,shape必须是(n_sample, n_feature)。

  • y:输入标签。

# x_train的shape由(7,)变为(7,1)
x_train = x_train.reshape(-1,1)
model.fit(X=x_train, y=y_train)

LinearRegression()

模型预测

对测试集做预测

使用到的api:

线性回归模型预测sklearn.linear_model.LinearRegression.predict

用到的参数:


  • X:输入特征,如果输入是np.array格式,shape必须是(n_sample, n_feature)。

返回:


  • C:预测结果。

# x_test的shape由(7,)变为(7, 1)
x_test = x_test.reshape(-1,1)
y_test_pred = model.predict(x_test)

画出数据集的散点图和预测直线

x_test = x_test.reshape(-1)
plt.scatter(x_test, y_test, color='g', label='test dataset')
plt.scatter(x_train, y_train, color='b',label='train dataset')
plt.plot(np.sort(x_test), y_test_pred[np.argsort(x_test)], color='r', label='linear regression')
plt.legend()
plt.show()

在这里插入图片描述


计算评价指标mse

使用到的api:

均方误差sklearn.metrics.mean_squared_error

用到的参数:


  • y_true:真实值(ground truth)。

  • y_pred:预测值。

返回:


  • loss:mse计算结果。

mse = mean_squared_error(y_true=y_test, y_pred=y_test_pred)
print('MSE: {}'.format(mse))

MSE: 0.15383086014546365

查看线性回归模型的系数w和截距b

使用到的api:

回归系数sklearn.linear_model.LinearRegression.coef_

截距项sklearn.linear_model.LinearRegression.intercept_

w, b = model.coef_[0], model.intercept_
print('Weight={0} bias={1}'.format(w, b))

Weight=0.09139423076923077 bias=1.3420673076923069





推荐阅读
  • 本文详细介绍了macOS系统的核心组件,包括如何管理其安全特性——系统完整性保护(SIP),并探讨了不同版本的更新亮点。对于使用macOS系统的用户来说,了解这些信息有助于更好地管理和优化系统性能。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 技术分享:从动态网站提取站点密钥的解决方案
    本文探讨了如何从动态网站中提取站点密钥,特别是针对验证码(reCAPTCHA)的处理方法。通过结合Selenium和requests库,提供了详细的代码示例和优化建议。 ... [详细]
  • 本文详细介绍了 MySQL 中 LAST_INSERT_ID() 函数的使用方法及其工作原理,包括如何获取最后一个插入记录的自增 ID、多行插入时的行为以及在不同客户端环境下的表现。 ... [详细]
  • 利用决策树预测NBA比赛胜负的Python数据挖掘实践
    本文通过使用2013-14赛季NBA赛程与结果数据集以及2013年NBA排名数据,结合《Python数据挖掘入门与实践》一书中的方法,展示如何应用决策树算法进行比赛胜负预测。我们将详细讲解数据预处理、特征工程及模型评估等关键步骤。 ... [详细]
  • ML学习笔记20210824分类算法模型选择与调优
    3.模型选择和调优3.1交叉验证定义目的为了让模型得精度更加可信3.2超参数搜索GridSearch对K值进行选择。k[1,2,3,4,5,6]循环遍历搜索。API参数1& ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 本文介绍如何使用 NSTimer 实现倒计时功能,详细讲解了初始化方法、参数配置以及具体实现步骤。通过示例代码展示如何创建和管理定时器,确保在指定时间间隔内执行特定任务。 ... [详细]
  • 本文介绍了如何通过 Maven 依赖引入 SQLiteJDBC 和 HikariCP 包,从而在 Java 应用中高效地连接和操作 SQLite 数据库。文章提供了详细的代码示例,并解释了每个步骤的实现细节。 ... [详细]
  • 使用Powershell Studio快速构建GUI应用程序
    本文介绍了如何利用Powershell Studio创建功能强大的可视化界面。相较于传统的开发工具,Powershell Studio提供了更为简便和高效的开发体验,尤其适合需要快速构建图形用户界面(GUI)的场景。 ... [详细]
  • 本文详细介绍了如何在Kendo UI for jQuery的数据管理组件中,将行标题字段呈现为锚点(即可点击链接),帮助开发人员更高效地实现这一功能。通过具体的代码示例和解释,即使是新手也能轻松掌握。 ... [详细]
  • 本文档介绍了如何在Visual Studio 2010环境下,利用C#语言连接SQL Server 2008数据库,并实现基本的数据操作,如增删改查等功能。通过构建一个面向对象的数据库工具类,简化了数据库操作流程。 ... [详细]
  • 本文介绍如何在C#中将GridView控件的内容保存为图片文件。通过代码示例,详细说明了创建位图、绘制图形并保存图像的步骤。 ... [详细]
  • Win10 UWP 开发技巧:利用 XamlTreeDump 获取 XAML 元素树
    本文介绍如何在 Win10 UWP 开发中使用 XamlTreeDump 库来获取和转换 XAML 元素树为 JSON 字符串,这对于 UI 单元测试非常有用。 ... [详细]
author-avatar
永不言败LM
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有