热门标签 | HotTags
当前位置:  开发笔记 > 前端 > 正文

tensorflow教程之tf.nn.softmax_cross_entropy_with_logits()方法解析

tf.nn.softmax_cross_entropy_with_logits(labels,logits)这个方法是最大似然函数,也是损失函数;

这个函数有两个主要的参数,一个是标签,一个是最后全连接层输出的结果;注意这两个参数的维度必须一致!

标签一般都是用ont-hot热码,什么是ont-hot,自己百度,很简单的

如果是3类,那么就是  [0,0,1] [0,1,0] [1,0,0]  ont-hot只有0和1 而且只有用ont-hot表示只有一个1

一个音频文件或者一张图片,经过卷积层或者全连接层以后会得到一个logits

求损失函数是将这个文件的真实标签与最后输出的logits进行比较的,怎么比较呢?


比如总共三类,而这个音频文件对应标签是第三类[1,0,0],那么labels就是[1,0,0]

经过深度网络以后得到的logits=[5.0,7.0,8.0] 

标签   [1,0,0]

输出数据  [5.0,7.0,8.0]

输出数据中哪个位置的数据最大就属于哪一类,我们是8最大,应该属于第一类,但是现在是反的,实际是第三类,对应的是5!这个误差很大,所以经过tf.nn.softmax_cross_entropy_with_logits(labels,logits)这个方法得到的值很大!真实标签与预测标签的值大相径庭

tensorflow教程之tf.nn.softmax_cross_entropy_with_logits()方法解析

我们得到的损失值3.34

 

如果Logit是[8,5,7]呢? 最大的值是8,对应的位置正好是1,说明预测文件类型与真实标签类型一致,误差值应该很小,我们运行程序看看

tensorflow教程之tf.nn.softmax_cross_entropy_with_logits()方法解析

得到的损失值为0.349

 

如果logits为[7,5,8]呢?最大的值8对应的是0,显然预测错误,但是第二大的是7,7对应的是1,是真实的标签,误差肯定比3.34小

tensorflow教程之tf.nn.softmax_cross_entropy_with_logits()方法解析

误差值为1.34

 

 


我们要理解这个函数是干嘛的?

labels是对应文件真实标签的热码形式

logits是我们的神经网络预测的结果,哪个位置的值最大,就代表哪一类(当然是我们神经网络预测的结果)

 tf.nn.softmax_cross_entropy_with_logits(labels,logits)就是将真实标签与预测标签进行对比,如果一致那么误差越小,如果不一致,那么在logits中真实标签对应位置的值在所有值中的大小,如果是第二大的话,那么误差也很小,如果是最小的,误差最大!


 


推荐阅读
author-avatar
李妙妙_minioniu_173
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有