热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用十折交叉验证评估回归模型性能

本文介绍了如何通过十折交叉验证方法评估回归模型的性能。我们将使用PyTorch框架,详细展示数据处理、模型定义、训练及评估的完整流程。

使用十折交叉验证评估回归模型性能

首先,我们导入所有必要的库和模块,确保环境准备就绪。

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from collections import OrderedDict
from torch.nn import init
import torch.utils.data as Data

接下来,定义一个函数用于获取每一折的数据,包括训练集和验证集。

def get_kfold_data(k, i, X, y):
fold_size = X.shape[0] // k
val_start = i * fold_size
if i != k - 1:
val_end = (i + 1) * fold_size
X_valid, y_valid = X[val_start:val_end], y[val_start:val_end]
X_train = torch.cat((X[0:val_start], X[val_end:]), dim=0)
y_train = torch.cat((y[0:val_start], y[val_end:]), dim=0)
else:
X_valid, y_valid = X[val_start:], y[val_start:]
X_train = X[0:val_start]
y_train = y[0:val_start]
return X_train, y_train, X_valid, y_valid

然后,实现一个执行多折交叉验证的函数,该函数将返回训练和验证的平均损失与准确率。

def k_fold(k, X, y):
train_loss_sum, valid_loss_sum = 0, 0
train_acc_sum, valid_acc_sum = 0, 0
data = []
train_loss_to_data, valid_loss_to_data = [], []
train_acc_to_data, valid_acc_to_data = [], []
for i in range(k):
print(f'第 {i + 1} 折验证结果')
X_train, y_train, X_valid, y_valid = get_kfold_data(k, i, X, y)
train_dataset = Data.TensorDataset(X_train, y_train)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
valid_dataset = Data.TensorDataset(X_valid, y_valid)
valid_loader = DataLoader(
dataset=valid_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
train_loss, valid_loss, train_acc, valid_acc = train(model, train_loader, valid_loader, loss, num_epochs, batch_size, lr)
train_loss_to_data.append(train_loss)
valid_loss_to_data.append(valid_loss)
train_acc_to_data.append(train_acc.detach().numpy())
valid_acc_to_data.append(valid_acc.detach().numpy())
train_loss_sum += train_loss
valid_loss_sum += valid_loss
train_acc_sum += train_acc
valid_acc_sum += valid_acc
print('\n', '最终k折交叉验证结果:')
print(f'average train loss: {train_loss_sum / k:.4f}, average train accuracy: {train_acc_sum / k * 100:.3f}%')
print(f'average valid loss: {valid_loss_sum / k:.4f}, average valid accuracy: {valid_acc_sum / k * 100:.3f}%')
data.extend([train_loss_to_data, valid_loss_to_data, train_acc_to_data, valid_acc_to_data])
return data

定义模型训练函数,该函数将完成模型的训练过程,并返回每个epoch的训练和验证损失及准确率。

def train(model, train_loader, valid_loader, loss, num_epochs, batch_size, lr):
train_losses, valid_losses = [], []
train_accuracies, valid_accuracies = [], []
for epoch in range(num_epochs):
train_loss_sum, valid_loss_sum = 0, 0
train_acc_sum, valid_acc_sum = 0, 0
n_train, n_valid = 0, 0
for X, y in train_loader:
y_pred = model(X)
l = loss(y_pred, y)
optimizer.zero_grad()
l.backward()
optimizer.step()
train_loss_sum += l.item()
acc = (1 - abs(y_pred - y) / y).mean()
train_acc_sum += acc
n_train += 1
with torch.no_grad():
for X, y in valid_loader:
y_pred = model(X)
l = loss(y_pred, y)
valid_loss_sum += l.item()
acc = (1 - abs(y_pred - y) / y).mean()
valid_acc_sum += acc
n_valid += 1
train_losses.append(train_loss_sum / n_train)
valid_losses.append(valid_loss_sum / n_valid)
train_accuracies.append(train_acc_sum / n_train)
valid_accuracies.append(valid_acc_sum / n_valid)
print(f'epoch {epoch + 1}, train_loss {train_losses[-1]:.6f}, train_acc {train_accuracies[-1] * 100:.3f}%, valid_loss {valid_losses[-1]:.6f}, valid_acc {valid_accuracies[-1] * 100:.3f}%')
return train_losses[-1], valid_losses[-1], train_accuracies[-1], valid_accuracies[-1]

生成模拟数据集,用于模型训练和验证。

num_features, num_samples = 500, 10000
true_weights = torch.ones(1, num_features) * 0.0056
true_bias = 0.028
x_data = torch.tensor(np.random.normal(0, 0.001, size=(num_samples, num_features)), dtype=torch.float32)
y = torch.mm(x_data, true_weights.t()) + true_bias
y += torch.normal(0, 0.001, y.shape)

构建回归模型,并初始化模型参数。

model = nn.Sequential(OrderedDict([
('linear1', nn.Linear(num_features, 256)),
('relu1', nn.ReLU()),
('linear2', nn.Linear(256, 128)),
('relu2', nn.ReLU()),
('linear3', nn.Linear(128, 1)),
]))
for param in model.parameters():
init.normal_(param, mean=0, std=0.001)

设置超参数并定义损失函数和优化器。

k_folds = 10
learning_rate = 0.001
batch_size = 50
epochs = 10
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

启动训练和验证过程,收集每折的结果。

results = k_fold(k_folds, x_data, y)

最后,使用Pandas将结果保存到CSV文件中,便于后续分析。

import pandas as pd

fold_names = [f'第{i + 1}折' for i in range(k_folds)]
data_frame = {
'Fold': fold_names,
'Train Loss': results[0],
'Valid Loss': results[1],
'Train Acc': results[2],
'Valid Acc': results[3],
}
df = pd.DataFrame(data_frame)
df.to_csv('./feedforward_neural_network_kfold_regression.csv', index=False)
df

推荐阅读
  • 本文探讨了K近邻(KNN)算法中K值的选择对模型复杂度的影响,通过实验分析不同K值下的模型表现,旨在为KNN算法的应用提供指导。 ... [详细]
  • 本文探讨了在Android平台下编写和读取.JSON文件的方法,解决读取文件时遇到的字符间异常空格问题。 ... [详细]
  • 本文介绍了如何利用高德地图API实现一个高效的地点选择组件,适用于需要用户选择具体位置的应用场景,如活动邀请函填写等。该组件支持从地图中选择地点,并自动将地点信息回填至表单中。 ... [详细]
  • 本文将指导你如何通过自定义配置,使 Windows Terminal 中的 PowerShell 7 更加高效且美观。我们将移除默认的广告和提示符,设置快捷键,并添加实用的别名和功能。 ... [详细]
  • 本文档详细介绍了Excel VBA编程中的基本语法,包括循环结构、条件判断、数据处理以及用户界面设计等内容,旨在帮助初学者快速掌握VBA编程技巧。 ... [详细]
  • 本文详细介绍了如何将After Effects中的动画相机数据导入到Vizrt系统中,提供了一种有效的解决方案,适用于需要在广播级图形制作中使用AE动画的专业人士。 ... [详细]
  • 本文章介绍了如何将阿拉伯数字形式的金额转换为中国传统的大写形式,适用于财务报告和正式文件中的金额表示。 ... [详细]
  • 探讨了一个关于Windows C++开发中遇到的乱码问题,特别是在处理宽字符时出现的情况。本文通过一个具体的示例——一个简单的窗口应用程序,展示了如何正确地使用宽字符以避免乱码。 ... [详细]
  • 本文探讨了一起由物化视图统计信息不当引起的查询性能下降问题,并详细介绍了问题的诊断与解决方法。通过调整统计信息收集策略,最终显著提升了查询效率。 ... [详细]
  • 本文详细介绍了MySQL 5.5及以上版本中事务管理的全过程,包括事务的启动、设置、锁机制以及解锁方法,旨在为开发者提供一个清晰、全面的操作指南,避免因网络资料分散而导致的学习障碍。 ... [详细]
  • 从零开始学重构——重构的流程及基础重构手法
    重构的流程重构手法  正如上一次所讲的那样,重构有两个基本条件,一是要保持代码在重构前后的行为基本不变,二是整个过程是受控且尽可能少地产生错误。尤其是对于第二点,产生了一系列的重构手 ... [详细]
  • KNN算法在海伦约会预测中的应用
    本文介绍如何使用KNN算法进行海伦约会的预测。我们将从数据导入、数据预处理、数据可视化到最终的模型训练和测试进行全面解析。 ... [详细]
  • 大数据SQL优化:全面解析数据倾斜解决方案
    本文深入探讨了大数据SQL优化中的数据倾斜问题,提供了多种解决策略和实际案例,旨在帮助读者理解和应对这一常见挑战。 ... [详细]
  • 本文详细介绍了 Go 语言的关键特性和编程理念,包括其强大的并发处理能力、简洁的语法设计以及高效的开发效率。 ... [详细]
  • 搜索引擎架构设计
    本文详细介绍了搜索引擎的主要组成部分,包括爬虫模块、索引模块和搜索模块。其中,索引模块采用了高效的二元分词技术进行数据存储,而搜索模块则基于ASP.NET框架实现了一个用户友好的界面和高效的搜索算法。 ... [详细]
author-avatar
mobiledu2502887833
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有