热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

深入解析PyTorch中的交叉熵损失函数(CrossEntropyLoss)

在PyTorch的`CrossEntropyLoss`函数中,当目标标签`target`为类别ID时,实际上会进行one-hot编码处理。例如,假设总共有三个类别,其中一个类别的ID为2,则该标签会被转换为`[0,0,1]`。这一过程简化了多分类任务中的损失计算,使得模型能够更高效地进行训练和评估。此外,`CrossEntropyLoss`还结合了softmax激活函数和负对数似然损失,进一步提高了模型的性能和稳定性。

对于target为类id,其实做了one-hot 操作,例如对于总共有三类,其中id为2,则转换后的标签如下:[0,0,1]。 这些标签作为权重乘上input的值进行叠加。

参见代码,秒懂。 三个输出一致

if __name__ == "__main__":import torchimport torch.nn as nnnllloss = nn.NLLLoss()x = torch.tensor([[1.5,2.5,3.0],[1.2,2.0,2.9]])onehot_y = torch.tensor([[0,1.0,0],[0,0,1]])logsoft_out = nn.LogSoftmax()(x)y = torch.tensor([1,2])print(nllloss(logsoft_out,y))print(nn.CrossEntropyLoss()(x,onehot_y))print(nn.CrossEntropyLoss()(x,y))exit()

如果input和target都是相同维度,例如3x5。

其实做了一个这样的操作,

torch.matmul(input, target.T), 再对这个3x3的矩阵[0,0],[1,1],[2,2]的值累加做平均

 

 


推荐阅读
author-avatar
快乐天使小可爱66
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有