tf.nn.softmax_cross_entropy_with_logits(labels,logits)这个方法是最大似然函数,也是损失函数;
这个函数有两个主要的参数,一个是标签,一个是最后全连接层输出的结果;注意这两个参数的维度必须一致!
标签一般都是用ont-hot热码,什么是ont-hot,自己百度,很简单的
如果是3类,那么就是 [0,0,1] [0,1,0] [1,0,0] ont-hot只有0和1 而且只有用ont-hot表示只有一个1
一个音频文件或者一张图片,经过卷积层或者全连接层以后会得到一个logits
求损失函数是将这个文件的真实标签与最后输出的logits进行比较的,怎么比较呢?
比如总共三类,而这个音频文件对应标签是第三类[1,0,0],那么labels就是[1,0,0]
经过深度网络以后得到的logits=[5.0,7.0,8.0]
标签 [1,0,0]
输出数据 [5.0,7.0,8.0]
输出数据中哪个位置的数据最大就属于哪一类,我们是8最大,应该属于第一类,但是现在是反的,实际是第三类,对应的是5!这个误差很大,所以经过tf.nn.softmax_cross_entropy_with_logits(labels,logits)这个方法得到的值很大!真实标签与预测标签的值大相径庭
我们得到的损失值3.34
如果Logit是[8,5,7]呢? 最大的值是8,对应的位置正好是1,说明预测文件类型与真实标签类型一致,误差值应该很小,我们运行程序看看
得到的损失值为0.349
如果logits为[7,5,8]呢?最大的值8对应的是0,显然预测错误,但是第二大的是7,7对应的是1,是真实的标签,误差肯定比3.34小
误差值为1.34
我们要理解这个函数是干嘛的?
labels是对应文件真实标签的热码形式
logits是我们的神经网络预测的结果,哪个位置的值最大,就代表哪一类(当然是我们神经网络预测的结果)
tf.nn.softmax_cross_entropy_with_logits(labels,logits)就是将真实标签与预测标签进行对比,如果一致那么误差越小,如果不一致,那么在logits中真实标签对应位置的值在所有值中的大小,如果是第二大的话,那么误差也很小,如果是最小的,误差最大!