热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用十折交叉验证评估回归模型性能

本文介绍了如何通过十折交叉验证方法评估回归模型的性能。我们将使用PyTorch框架,详细展示数据处理、模型定义、训练及评估的完整流程。

使用十折交叉验证评估回归模型性能

首先,我们导入所有必要的库和模块,确保环境准备就绪。

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from collections import OrderedDict
from torch.nn import init
import torch.utils.data as Data

接下来,定义一个函数用于获取每一折的数据,包括训练集和验证集。

def get_kfold_data(k, i, X, y):
fold_size = X.shape[0] // k
val_start = i * fold_size
if i != k - 1:
val_end = (i + 1) * fold_size
X_valid, y_valid = X[val_start:val_end], y[val_start:val_end]
X_train = torch.cat((X[0:val_start], X[val_end:]), dim=0)
y_train = torch.cat((y[0:val_start], y[val_end:]), dim=0)
else:
X_valid, y_valid = X[val_start:], y[val_start:]
X_train = X[0:val_start]
y_train = y[0:val_start]
return X_train, y_train, X_valid, y_valid

然后,实现一个执行多折交叉验证的函数,该函数将返回训练和验证的平均损失与准确率。

def k_fold(k, X, y):
train_loss_sum, valid_loss_sum = 0, 0
train_acc_sum, valid_acc_sum = 0, 0
data = []
train_loss_to_data, valid_loss_to_data = [], []
train_acc_to_data, valid_acc_to_data = [], []
for i in range(k):
print(f'第 {i + 1} 折验证结果')
X_train, y_train, X_valid, y_valid = get_kfold_data(k, i, X, y)
train_dataset = Data.TensorDataset(X_train, y_train)
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
valid_dataset = Data.TensorDataset(X_valid, y_valid)
valid_loader = DataLoader(
dataset=valid_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
train_loss, valid_loss, train_acc, valid_acc = train(model, train_loader, valid_loader, loss, num_epochs, batch_size, lr)
train_loss_to_data.append(train_loss)
valid_loss_to_data.append(valid_loss)
train_acc_to_data.append(train_acc.detach().numpy())
valid_acc_to_data.append(valid_acc.detach().numpy())
train_loss_sum += train_loss
valid_loss_sum += valid_loss
train_acc_sum += train_acc
valid_acc_sum += valid_acc
print('\n', '最终k折交叉验证结果:')
print(f'average train loss: {train_loss_sum / k:.4f}, average train accuracy: {train_acc_sum / k * 100:.3f}%')
print(f'average valid loss: {valid_loss_sum / k:.4f}, average valid accuracy: {valid_acc_sum / k * 100:.3f}%')
data.extend([train_loss_to_data, valid_loss_to_data, train_acc_to_data, valid_acc_to_data])
return data

定义模型训练函数,该函数将完成模型的训练过程,并返回每个epoch的训练和验证损失及准确率。

def train(model, train_loader, valid_loader, loss, num_epochs, batch_size, lr):
train_losses, valid_losses = [], []
train_accuracies, valid_accuracies = [], []
for epoch in range(num_epochs):
train_loss_sum, valid_loss_sum = 0, 0
train_acc_sum, valid_acc_sum = 0, 0
n_train, n_valid = 0, 0
for X, y in train_loader:
y_pred = model(X)
l = loss(y_pred, y)
optimizer.zero_grad()
l.backward()
optimizer.step()
train_loss_sum += l.item()
acc = (1 - abs(y_pred - y) / y).mean()
train_acc_sum += acc
n_train += 1
with torch.no_grad():
for X, y in valid_loader:
y_pred = model(X)
l = loss(y_pred, y)
valid_loss_sum += l.item()
acc = (1 - abs(y_pred - y) / y).mean()
valid_acc_sum += acc
n_valid += 1
train_losses.append(train_loss_sum / n_train)
valid_losses.append(valid_loss_sum / n_valid)
train_accuracies.append(train_acc_sum / n_train)
valid_accuracies.append(valid_acc_sum / n_valid)
print(f'epoch {epoch + 1}, train_loss {train_losses[-1]:.6f}, train_acc {train_accuracies[-1] * 100:.3f}%, valid_loss {valid_losses[-1]:.6f}, valid_acc {valid_accuracies[-1] * 100:.3f}%')
return train_losses[-1], valid_losses[-1], train_accuracies[-1], valid_accuracies[-1]

生成模拟数据集,用于模型训练和验证。

num_features, num_samples = 500, 10000
true_weights = torch.ones(1, num_features) * 0.0056
true_bias = 0.028
x_data = torch.tensor(np.random.normal(0, 0.001, size=(num_samples, num_features)), dtype=torch.float32)
y = torch.mm(x_data, true_weights.t()) + true_bias
y += torch.normal(0, 0.001, y.shape)

构建回归模型,并初始化模型参数。

model = nn.Sequential(OrderedDict([
('linear1', nn.Linear(num_features, 256)),
('relu1', nn.ReLU()),
('linear2', nn.Linear(256, 128)),
('relu2', nn.ReLU()),
('linear3', nn.Linear(128, 1)),
]))
for param in model.parameters():
init.normal_(param, mean=0, std=0.001)

设置超参数并定义损失函数和优化器。

k_folds = 10
learning_rate = 0.001
batch_size = 50
epochs = 10
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

启动训练和验证过程,收集每折的结果。

results = k_fold(k_folds, x_data, y)

最后,使用Pandas将结果保存到CSV文件中,便于后续分析。

import pandas as pd

fold_names = [f'第{i + 1}折' for i in range(k_folds)]
data_frame = {
'Fold': fold_names,
'Train Loss': results[0],
'Valid Loss': results[1],
'Train Acc': results[2],
'Valid Acc': results[3],
}
df = pd.DataFrame(data_frame)
df.to_csv('./feedforward_neural_network_kfold_regression.csv', index=False)
df

推荐阅读
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 本文介绍如何使用 NSTimer 实现倒计时功能,详细讲解了初始化方法、参数配置以及具体实现步骤。通过示例代码展示如何创建和管理定时器,确保在指定时间间隔内执行特定任务。 ... [详细]
  • 基因组浏览器中的Wig格式解析
    本文详细介绍了Wiggle(Wig)格式及其在基因组浏览器中的应用,涵盖variableStep和fixedStep两种主要格式的特点、适用场景及具体使用方法。同时,还提供了关于数据值和自定义参数的补充信息。 ... [详细]
  • 本文详细介绍了IBM DB2数据库在大型应用系统中的应用,强调其卓越的可扩展性和多环境支持能力。文章深入分析了DB2在数据利用性、完整性、安全性和恢复性方面的优势,并提供了优化建议以提升其在不同规模应用程序中的表现。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • DNN Community 和 Professional 版本的主要差异
    本文详细解析了 DotNetNuke (DNN) 的两种主要版本:Community 和 Professional。通过对比两者的功能和附加组件,帮助用户选择最适合其需求的版本。 ... [详细]
  • 本章将深入探讨移动 UI 设计的核心原则,帮助开发者构建简洁、高效且用户友好的界面。通过学习设计规则和用户体验优化技巧,您将能够创建出既美观又实用的移动应用。 ... [详细]
  • 本文介绍了在Windows环境下使用pydoc工具的方法,并详细解释了如何通过命令行和浏览器查看Python内置函数的文档。此外,还提供了关于raw_input和open函数的具体用法和功能说明。 ... [详细]
  • 本文介绍如何使用阿里云的fastjson库解析包含时间戳、IP地址和参数等信息的JSON格式文本,并进行数据处理和保存。 ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文详细介绍如何使用Python进行配置文件的读写操作,涵盖常见的配置文件格式(如INI、JSON、TOML和YAML),并提供具体的代码示例。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • 本文探讨了如何在给定整数N的情况下,找到两个不同的整数a和b,使得它们的和最大,并且满足特定的数学条件。 ... [详细]
  • 本文详细介绍了Java中org.w3c.dom.Text类的splitText()方法,通过多个代码示例展示了其实际应用。该方法用于将文本节点在指定位置拆分为两个节点,并保持在文档树中。 ... [详细]
  • 本文介绍如何通过创建替代插入触发器,使对视图的插入操作能够正确更新相关的基本表。涉及的表包括:飞机(Aircraft)、员工(Employee)和认证(Certification)。 ... [详细]
author-avatar
mobiledu2502887833
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有