热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

[机器学习]分类问题常用评价指标、混淆矩阵及ROC曲线绘制方法

分类问题分类问题是人工智能领域中最常见的一类问题之一,掌握合适的评价指标,对模型进行恰当的评价,是至关重要的。同样地,分割问题是像素级别的分类,除了mAcc、mIoU之外,也可以采用分类问题的一些指标来评价。本文对分类问题的常见评价指标进行介绍,并附上利用sklearn库

分类问题

分类问题是人工智能领域中最常见的一类问题之一,掌握合适的评价指标,对模型进行恰当的评价,是至关重要的。

同样地,分割问题是像素级别的分类,除了mAcc、mIoU之外,也可以采用分类问题的一些指标来评价。

本文对分类问题的常见评价指标进行介绍,并附上利用sklearn库的python实现。

将从以下三个方面分别介绍:

  1. 常用评价指标
  2. 混淆矩阵绘制及评价指标计算
  3. ROC曲线绘制及AUC计算

1. 常用评价指标

混淆矩阵(confusion matrix)

一般用来描述一个分类器分类的准确程度。
根据分类器在测试数据集上的预测是否正确可以分为四种情况:

  • TP(True Positive)——将正类预测为正类数;
  • FN(False Negative)——将正类预测为负类数;
  • FP(False Positive)——将负类预测为正类数;
  • TN(True Negative)——将负类预测为负类数。
    构成一个二分类的混淆矩阵如图:
    image

均交并比(Mean Intersection over Union,mIoU):

语义分割的标准度量。其计算两个集合的交并比,在语义分割的问题中,这两个集合为真实值(ground truth)和预测值(predicted segmentation)。
image

分类问题评价指标

二分类问题经混淆矩阵的处理后,针对不同问题,可以选用不同的指标来评价系统。

  1. Accuracy:表示预测结果的精确度,预测正确的样本数除以总样本数;
  2. Precision:准确率,表示预测结果中,预测为正样本的样本中,正确预测为正样本的概率;
  3. Sensitivity:灵敏度,表示在原始样本的正样本中,最后被正确预测为正样本的概率;
  4. Specificity:常常称作特异性,它研究的样本集是原始样本中的负样本,表示的是在这些负样本中最后被正确预测为负样本的概率;
  5. F1-score:表示的是precision和recall的调和平均评估指标。
    image

受试者工作特征(Receiver Operating Characteristic,ROC)曲线

ROC曲线是以真阳性率(TPR)为Y轴,以假阳性率(FPR)为X轴做的图。同样用来综合评价模型分类情况。是反映敏感性和特异性连续变量的综合指标。
image

AUC(Area Under Curve)

AUC的值为ROC曲线下与x轴围成的面积,分类器的性能越接近完美,AUC的值越接近。当0.5>AUC>1时,效果优于“随机猜测”。一般情况下,模型的AUC值应当在此范围内。

2. 混淆矩阵绘制及评价指标计算

首先是分类器的训练,以sklearn库中的基础分类器为例

from sklearn.svm import SVC, LinearSVC
clf = LinearSVC()
clf.fit(train_features, train_target)
predict = clf.predict(test_features)

# 绘制混淆矩阵和评价指标计算
cal(test_target, pred)

# 获取分类score
score = clf.decision_function(test_features)

# 绘制ROC曲线和计算AUC
paint_ROC(test_target, test_score)

混淆矩阵的绘制和评价指标计算可以写在一起,在绘制混淆矩阵时,已经可以算出TP\TN\FP\FN的数值。

# 这是一个多分类问题,y_true是target,y_pred是模型预测结果,数据格式为numpy

def cal(y_true, y_pred):

    # confusion matrix row means GT, column means predication
    name = 'save_name'
    '''画混淆矩阵'''
    mat = confusion_matrix(y_true, y_pred)
    da = pd.DataFrame(mat, index = ['0', '1', '2'])
    sns.heatmap(da, annot =True, cbar = None, cmap = 'Blues')
    plt.title(name)
    # plt.tight_layout()yt
    plt.ylabel('True Label')
    plt.xlabel('Predict Label')
    plt.show()
    plt.savefig('{}/{}.png'.format('save_path', name)) # 将混淆矩阵图片保存下来
    plt.close()
    
    '''计算指标'''
    tp = np.diagonal(mat) # 每类的tp
    gt_num = np.sum(mat, axis=1) # axis = 1 指每行 ,每类的总数
    pre_num = np.sum(mat, axis=0)
    fp = pre_num - tp
    fn = gt_num - tp
    num = np.sum(gt_num)
    num = np.repeat(num, gt_num.shape[0])
    gt_num0 = num - gt_num
    tn = gt_num0 -fp
	
    recall = tp.astype(np.float32) / gt_num
    specificity = tn.astype(np.float32) / gt_num0
    precision = tp.astype(np.float32) / pre_num
    F1 = 2 * (precision * recall) / (precision + recall)
    acc = (tp + tn).astype(np.float32) / num

    print('recall:', recall, '\nmean recall:{:.4f}'.format(np.mean(recall)) )
    print('specificity:', specificity, '\nmean specificity:{:.4f}'.format(np.mean(specificity)))
    print('precision:', precision, '\nmean precision:{:.4f}'.format(np.mean(precision)))
    print('F1:', F1 , '\nmean F1:{:.4f}'.format(np.mean(F1)))
    print('acc:', acc , '\nmean acc:{:.4f}'.format(np.mean(acc)))

3. ROC曲线绘制及AUC计算

# 这是一个多分类问题(三分类),可以在一张图上绘制多条ROC曲线

def paint_ROC(y_test, y_score):

    '''画ROC曲线'''
    plt.figure()
    # 修改颜色
    colors = ['','darkred', 'darkorange', 'cornflowerblue']

    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    # print('label',y_test)
    # print('score', y_score)

    label = np.zeros((len(y_test), 3),  dtype="uint8")
    for i in range(len(y_test)):
        label[i][int(y_test[i])-1] = 1
    # print('label',label)

    for i in range(1,4):
        fpr[i], tpr[i], _ = metrics.roc_curve(label[:,i-1], y_score[:, i-1])
        roc_auc[i] = metrics.auc(fpr[i], tpr[i])

    fpr["mean"], tpr["mean"], _ = metrics.roc_curve(label.ravel(), y_score.ravel())
    roc_auc["mean"] = metrics.auc(fpr["mean"], tpr["mean"])

    lw = 2
    plt.plot(fpr["mean"], tpr["mean"],
         label='average, ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["mean"]),
         color='k', linewidth=lw)

    for i in range(1,4):
        auc = roc_auc[i]
        # 输出不同类别的FPR\TPR\AUC
        print('label: {}, fpr: {}, tpr: {}, auc: {}'.format(i, np.mean(fpr[i]), np.mean(tpr[i]), auc))
        plt.plot(fpr[i], tpr[i], color=colors[i],line,lw = lw, label='Label = {0}, ROC curve (area = {1:0.2f})'.format(i, auc))

    plt.plot([0, 1], [0, 1], color='navy', line)
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    # plt.title('Receiver operating characteristic example')
    plt.grid(line)  
    plt.grid(True)
    plt.legend(loc="lower right")
    plt.show()
    # 保存绘制好的ROC曲线
    plt.savefig('{}/{}.png'.format('save_path', 'save_name'))
    plt.close()


推荐阅读
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文介绍了贝叶斯垃圾邮件分类的机器学习代码,代码来源于https://www.cnblogs.com/huangyc/p/10327209.html,并对代码进行了简介。朴素贝叶斯分类器训练函数包括求p(Ci)和基于词汇表的p(w|Ci)。 ... [详细]
  • keras归一化激活函数dropout
    激活函数:1.softmax函数在多分类中常用的激活函数,是基于逻辑回归的,常用在输出一层,将输出压缩在0~1之间,且保证所有元素和为1,表示输入值属于每个输出值的概率大小2、Si ... [详细]
  • python机器学习之数据探索
    🐱今天我们来讲解数据建模之前需要处理的工作,也就是数据探索的过程,很多同学会说,不就是处理缺失值,异常值&# ... [详细]
  • 上一篇《手把手教你用深度学习做物体检测(三):模型训练》中介绍了如何使用yolov3训练我们自己的物体检测模型,本篇文章将重点介绍如何使用我们训练好的模型来检测图片或视频中的物体 ... [详细]
  • python seaborn_大白话Python绘图系列Seaborn篇
    1.目的了解python第三方绘图包seaborn,从常用绘图实例开始,快速体验seaborn绘图。建议用时:10分钟绘图例子:12个每个例子代码量:1 ... [详细]
  • K-Means算法原理
    原理给定样本集,k-means算法得到聚类,使得下面平方误差最小其中表示聚类的中心点。实现上式最小化是一个NP难问题,实际上采用EM算法可以求得近似解。算法伪代码如下输入:,聚 ... [详细]
  • 分类与聚类
    一:分类1:定义分类其实是从特定的数据中挖掘模式,做出判断的过程。分类是在一群已经知道类别标号的样本中,训练一种分类器 ... [详细]
  • IamtryingtogetasparsematrixintoH2OandIwaswonderingwhetherthatwaspossible.Supposew ... [详细]
  • 海量数据分类 liblinear使用总结
    liblinear使用总结liblinear是libsvm的线性核的改进版本,专门适用于百万数据量的分类。正好适用于我这次数据挖掘的实验。liblinear用法和li ... [详细]
  • 注意力汇聚:NadarayaWatson 核回归
    Nadaraya-Watson核回归是具有注意力机制的机器学习范例。Nadaraya-Watson核回归的注意力汇聚是对训练数据中输出的加权平均。从注意力的角度来看, ... [详细]
  • 使用Flutternewintegration_test进行示例集成测试?回答首先在dev下的p ... [详细]
  • 初识顶部导航栏【flutter20个实例之一】
    初识顶部导航栏【flutter20个实例之一】-一、老套路,先看样式二图是我的实际开发中业务界面,用作展示而已二、讲解(后附源码)1.这里主要是用户AppBar组件** ... [详细]
  • Unity Graphic功能,实现UGUI上三角形,四边形,圆环的绘制
    前言这篇简单的纪录下利用Graphic类,实现UGUI圆环的绘制。效果图如下:github目录:https:github.comluck ... [详细]
  • 搜索:eclipse:ctrlhidea:ctrlshiftf(如果失效,两种方法,搜狗拼音 ... [详细]
author-avatar
惜靜吾_919
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有