热门标签 | HotTags
当前位置:  开发笔记 > 人工智能 > 正文

在pytorch中计算精度、回归率、F1score等指标的实例

今天小编就为大家分享一篇在pytorch中计算精度、回归率、F1score等指标的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

pytorch中训练完网络后,需要对学习的结果进行测试。官网上例程用的方法统统都是正确率,使用的是torch.eq()这个函数。

但是为了更精细的评价结果,我们还需要计算其他各个指标。在把官网API翻了一遍之后发现并没有用于计算TP,TN,FP,FN的函数。。。

在动了无数歪脑筋之后,心想pytorch完全支持numpy,那能不能直接进行判断,试了一下果然可以,上代码:

# TP predict 和 label 同时为1
TP += ((pred_choice == 1) & (target.data == 1)).cpu().sum()
# TN predict 和 label 同时为0
TN += ((pred_choice == 0) & (target.data == 0)).cpu().sum()
# FN predict 0 label 1
FN += ((pred_choice == 0) & (target.data == 1)).cpu().sum()
# FP predict 1 label 0
FP += ((pred_choice == 1) & (target.data == 0)).cpu().sum()

p = TP / (TP + FP)
r = TP / (TP + FN)
F1 = 2 * r * p / (r + p)
acc = (TP + TN) / (TP + TN + FP + FN

这样就能看到各个指标了。

因为target是Variable所以需要用target.data取到对应的tensor,又因为是在gpu上算的,需要用 .cpu() 移到cpu上。

因为这是一个batch的统计,所以需要用+=累计出整个epoch的统计。当然,在epoch开始之前需要清零

以上这篇在pytorch 中计算精度、回归率、F1 score等指标的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


推荐阅读
  • 在Conda环境中高效配置并安装PyTorch和TensorFlow GPU版的方法如下:首先,创建一个新的Conda环境以避免与基础环境发生冲突,例如使用 `conda create -n pytorch_gpu python=3.7` 命令。接着,激活该环境,确保所有依赖项都正确安装。此外,建议在安装过程中指定CUDA版本,以确保与GPU兼容性。通过这些步骤,可以确保PyTorch和TensorFlow GPU版的顺利安装和运行。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 2019年斯坦福大学CS224n课程笔记:深度学习在自然语言处理中的应用——Word2Vec与GloVe模型解析
    本文详细解析了2019年斯坦福大学CS224n课程中关于深度学习在自然语言处理(NLP)领域的应用,重点探讨了Word2Vec和GloVe两种词嵌入模型的原理与实现方法。通过具体案例分析,深入阐述了这两种模型在提升NLP任务性能方面的优势与应用场景。 ... [详细]
  • 不用蘑菇,不拾金币,我通过强化学习成功通关29关马里奥,创造全新纪录
    《超级马里奥兄弟》由任天堂于1985年首次发布,是一款经典的横版过关游戏,至今已在多个平台上售出超过5亿套。该游戏不仅勾起了许多玩家的童年回忆,也成为强化学习领域的热门研究对象。近日,通过先进的强化学习技术,研究人员成功让AI通关了29关,创造了新的纪录。这一成就不仅展示了强化学习在游戏领域的潜力,也为未来的人工智能应用提供了宝贵的经验。 ... [详细]
  • 本文提供了PyTorch框架中常用的预训练模型的下载链接及详细使用指南,涵盖ResNet、Inception、DenseNet、AlexNet、VGGNet等六大分类模型。每种模型的预训练参数均经过精心调优,适用于多种计算机视觉任务。文章不仅介绍了模型的下载方式,还详细说明了如何在实际项目中高效地加载和使用这些模型,为开发者提供全面的技术支持。 ... [详细]
  • 利用 PyTorch 实现 Python 中的高效矩阵运算 ... [详细]
  • 本文介绍了一款高效的开源OCR文本识别模型,结合了TextBoxes++和RetinaNet的优势。该模型在文本检测方面表现出色,适用于多种场景。项目代码已托管至GitHub,方便研究人员和开发者使用和改进。 ... [详细]
  • 在上一节中,我们完成了网络的前向传播实现。本节将重点探讨如何为检测输出设定目标置信度阈值,并应用非极大值抑制技术以提高检测精度。为了更好地理解和实践这些内容,建议读者已经完成本系列教程的前三部分,并具备一定的PyTorch基础知识。此外,我们将详细介绍这些技术的原理及其在实际应用中的重要性,帮助读者深入理解目标检测算法的核心机制。 ... [详细]
  • 在 PyTorch 中,`pin_memory` 技术用于锁定页面内存。当在创建 `DataLoader` 时将 `pin_memory` 参数设置为 `True`,这意味着生成的 Tensor 数据最初会被存储在锁定的内存中。这一技术能够显著提高数据从 CPU 到 GPU 的传输效率,从而加快训练速度。通过合理利用 `pin_memory`,可以有效减少数据加载的瓶颈,提升整体性能。 ... [详细]
  • 谷歌工程师:TensorFlow已重获新生;网友:我还是用PyTorch
    乾明发自凹非寺量子位报道|公众号QbitAI道友留步!TensorFlow已重获新生。在“PyTorch真香”的潮流中,有人站出来为TensorFlow说话了。这次来自谷歌的工程师 ... [详细]
  • 1.如何进行迁移 使用Pytorch写的模型: 对模型和相应的数据使用.cuda()处理。通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去。 ... [详细]
  • 5.Numpy 索引(一维索引/二维索引)
    本文内容是根据莫烦Python网站的视频整理的笔记,笔记中对代码的注释更加清晰明了,同时根据所有笔记还整理了精简版的思维导图,可在此专栏查看,想观看视频可直接去他的网 ... [详细]
  • python教程分享Pytorchmlu 实现添加逐层算子方法详解
    目录1、注册算子2、算子分发3、修改opmethods基类4、下发算子5、添加wrapper6、添加wrapper7、算子测试本教程分享了在寒武纪设备上pytorch-mlu中添加 ... [详细]
  • [TensorFlow系列3]:初学者是选择Tensorflow2.x还是1.x? 2.x与1.x的主要区别?
    作者主页(文火冰糖的硅基工坊):https:blog.csdn.netHiWangWenBing本文网址:https:blog.csdn.netHiW ... [详细]
  • pytorch(网络模型训练)
    上一篇目录标题网络模型训练小插曲训练模型数据训练GPU训练第一种方式方式二:查看GPU信息完整模型验证网络模型训练小插曲区别importtorchatorch ... [详细]
author-avatar
郑谊099_448
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有