对于target为类id,其实做了one-hot 操作,例如对于总共有三类,其中id为2,则转换后的标签如下:[0,0,1]。 这些标签作为权重乘上input的值进行叠加。
参见代码,秒懂。 三个输出一致
if __name__ == "__main__":import torchimport torch.nn as nnnllloss = nn.NLLLoss()x = torch.tensor([[1.5,2.5,3.0],[1.2,2.0,2.9]])onehot_y = torch.tensor([[0,1.0,0],[0,0,1]])logsoft_out = nn.LogSoftmax()(x)y = torch.tensor([1,2])print(nllloss(logsoft_out,y))print(nn.CrossEntropyLoss()(x,onehot_y))print(nn.CrossEntropyLoss()(x,y))exit()
如果input和target都是相同维度,例如3x5。
其实做了一个这样的操作,
torch.matmul(input, target.T), 再对这个3x3的矩阵[0,0],[1,1],[2,2]的值累加做平均