热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

使用TensorFlow实现Top-K准确率计算的神经网络案例

本文通过一个具体的实例,介绍如何利用TensorFlow框架来计算神经网络模型在多分类任务中的Top-K准确率。代码中包含了随机种子设置、模拟预测结果生成、真实标签生成以及准确率计算等步骤。

首先,我们导入TensorFlow库,并设置随机种子以确保实验的可重复性:

import tensorflow as tf

tf.random.set_seed(2467) # 设置随机种子,确保每次运行时生成的数据一致

接着,生成模拟的预测结果和真实标签:

output = tf.random.normal([10, 6])  # 生成10个样本,每个样本属于6个可能类别的预测分数
output = tf.math.softmax(output, axis=1) # 应用Softmax函数,将预测分数转换为概率分布

target = tf.random.uniform([10], maxval=6, dtype=tf.int32) # 生成10个样本的真实标签,范围从0到5

打印原始数据、预测类别及实际类别以供检查:

print('原始数据:', output.numpy())
pred = tf.argmax(output, axis=1) # 获取每个样本预测的最大概率对应的类别
print('预测类别:', pred.numpy())
print('实际类别:', target.numpy())

定义一个函数来计算Top-K准确率:

def calculate_accuracy(predictions, labels, topk=(1, 2, 3, 4, 5, 6)):
maxk = max(topk)
batch_size = labels.shape[0]
topk_predictiOns= tf.math.top_k(predictions, maxk).indices # 获取每个样本最可能的前K个类别的索引
topk_predictiOns= tf.transpose(topk_predictions, perm=[1, 0]) # 转置预测矩阵以便于后续操作
labels_expanded = tf.broadcast_to(labels, topk_predictions.shape) # 广播真实标签至与预测矩阵相同的形状
matches = tf.equal(topk_predictions, labels_expanded) # 比较预测与真实标签是否匹配
accuracies = []
for k in topk:
correct_predictiOns= tf.reduce_sum(tf.cast(tf.reshape(matches[:k], [-1]), dtype=tf.int32))
accuracy = float(correct_predictions / batch_size)
accuracies.append(accuracy)
return accuracies

最后,调用上述函数并输出不同K值下的准确率:

accuracies = calculate_accuracy(output, target, topk=(1, 2, 3, 4, 5, 6))
print('Top1-6的准确率分别是:', accuracies)

推荐阅读
author-avatar
Xlady贩卖__铺
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有