import torch.nn.functional as F
常用的函数有: log_softmax,nll_loss, cross_entropy
1.log_softmax, 就是log和softmax合并在一起执行
2. nll_loss ,函数全称是negative log likelihood loss, 函数表达式为
![f(x,class)=-x[class]](https://img7.php1.cn/3cdc5/cd80/b64/ea933904f0743e29.gif)
例如:假设x=[1,2,3], class=2, 则
![f(x,class)=-x[2]=-3](https://img7.php1.cn/3cdc5/cd80/b64/7580ea0ad61a33c9.gif)
3. cross_entropy
交叉熵的计算公式为
![cross_{entropy}=-\frac{1}{N}\sum _{k=1}^N(p_k*log q_k)](https://img7.php1.cn/3cdc5/cd80/b64/284482fb9f7ac487.gif)
其中p表示真实值,在这个公式中是one-hot形式,q是预测值,在这里假设是经过softmax后的结果了。
仔细观察可知,因为p的元素不是0就是1,而且又是乘法,所以很自然地,如果我们知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示,所以交叉熵的公式(m表示真实类别)可变形为
仔细看看,就是等同于log_softmax和nll_loss两个步骤
所以pytorch中的F.cross_entropy会自动调用上面介绍的log_softmax和nll_loss来计算交叉熵,其计算方式如下:
![loss(x,class)=-log(\frac{e^{x[class]}}{\sum _j e^{x[j]}})](https://img7.php1.cn/3cdc5/cd80/b64/770b640338147e34.gif)
代码实现
import torch
import torch.nn.functional as F
import torchsnooper
@torchsnooper.snoop() #用于调试,观察哪里会出错误
def main(): input=torch.from_numpy(np.array([[-0.0979, -0.3264, -0.1561, 0.2403, -0.4000], [-1.6561, 0.9922, -0.2922, 0.4640, -0.6327], [-0.3424, 1.3673, 0.8343, -0.1376, 1.7234]])) target=torch.from_numpy(np.array([2, 2, 4])) #方法一, 直接调用cross_entropy 计算损失函数 loss = F.cross_entropy(input, target)#直接调用函数进行计算tensor(1.5196, dtype=torch.float64) print(loss) #方法二, 分步骤计算损失函数 #step1: 先计算softmax() probability=F.softmax(input,dim=1)#shape [num_samples,num_classes] #step2 计算预测值的对数 log_P=torch.log(probability) '''对输入的target标签进行 one-hot编码,使用_scatter方法''' one_hot=torch.zeros(probability.shape,dtype=float).scatter_(1,torch.unsqueeze(target,dim=1),1) #根据交叉熵公式loss=-target*log(predict) loss3=-one_hot*log_P loss3=loss3.sum() loss3/=probability.shape[0] print('loss3',loss3) # 1.5196与第一种方法结果是一样的 #方法三, 借助于log_softmax logsoftmax=F.log_softmax(input) loss4=-one_hot*logsoftmax loss4=loss4.sum() loss4/=probability.shape[0] print('loss4',loss4) #1.5196 #方法四, 借助于nllloss nllloss=F.nll_loss(logsoftmax,target) print('nllloss',nllloss)
if __name__ == '__main__': main()