热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch里的log_softmax,nll_loss,cross_entropy

importtorch.nn.functionalasF常用的函数有:log_softmax,nll_loss,cross_entropy1.log_softma

import torch.nn.functional as F

常用的函数有: log_softmax,nll_loss, cross_entropy

1.log_softmax, 就是log和softmax合并在一起执行

2. nll_loss ,函数全称是negative log likelihood loss, 函数表达式为

   f(x,class)=-x[class]

例如:假设x=[1,2,3], class=2, 则

f(x,class)=-x[2]=-3

3. cross_entropy

交叉熵的计算公式为

cross_{entropy}=-\frac{1}{N}\sum _{k=1}^N(p_k*log q_k)

其中p表示真实值,在这个公式中是one-hot形式,q是预测值,在这里假设是经过softmax后的结果了。

仔细观察可知,因为p的元素不是0就是1,而且又是乘法,所以很自然地,如果我们知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示,所以交叉熵的公式(m表示真实类别)可变形为

cross_{entropy}=-\sum _{k=1}^N(p_k*log q_k)=-log q_m    

仔细看看,就是等同于log_softmax和nll_loss两个步骤

所以pytorch中的F.cross_entropy会自动调用上面介绍的log_softmax和nll_loss来计算交叉熵,其计算方式如下:

loss(x,class)=-log(\frac{e^{x[class]}}{\sum _j e^{x[j]}})

代码实现

import torch
import torch.nn.functional as F
import torchsnooper
@torchsnooper.snoop() #用于调试,观察哪里会出错误
def main(): input=torch.from_numpy(np.array([[-0.0979, -0.3264, -0.1561, 0.2403, -0.4000], [-1.6561, 0.9922, -0.2922, 0.4640, -0.6327], [-0.3424, 1.3673, 0.8343, -0.1376, 1.7234]])) target=torch.from_numpy(np.array([2, 2, 4])) #方法一, 直接调用cross_entropy 计算损失函数 loss = F.cross_entropy(input, target)#直接调用函数进行计算tensor(1.5196, dtype=torch.float64) print(loss) #方法二, 分步骤计算损失函数 #step1: 先计算softmax() probability=F.softmax(input,dim=1)#shape [num_samples,num_classes] #step2 计算预测值的对数 log_P=torch.log(probability) '''对输入的target标签进行 one-hot编码,使用_scatter方法''' one_hot=torch.zeros(probability.shape,dtype=float).scatter_(1,torch.unsqueeze(target,dim=1),1) #根据交叉熵公式loss=-target*log(predict) loss3=-one_hot*log_P loss3=loss3.sum() loss3/=probability.shape[0] print('loss3',loss3) # 1.5196与第一种方法结果是一样的 #方法三, 借助于log_softmax logsoftmax=F.log_softmax(input) loss4=-one_hot*logsoftmax loss4=loss4.sum() loss4/=probability.shape[0] print('loss4',loss4) #1.5196 #方法四, 借助于nllloss nllloss=F.nll_loss(logsoftmax,target) print('nllloss',nllloss)
if __name__ == '__main__': main()

 


推荐阅读
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • Commit1ced2a7433ea8937a1b260ea65d708f32ca7c95eintroduceda+Clonetraitboundtom ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 本文讨论了如何优化解决hdu 1003 java题目的动态规划方法,通过分析加法规则和最大和的性质,提出了一种优化的思路。具体方法是,当从1加到n为负时,即sum(1,n)sum(n,s),可以继续加法计算。同时,还考虑了两种特殊情况:都是负数的情况和有0的情况。最后,通过使用Scanner类来获取输入数据。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 关键词:Golang, Cookie, 跟踪位置, net/http/cookiejar, package main, golang.org/x/net/publicsuffix, io/ioutil, log, net/http, net/http/cookiejar ... [详细]
  • 本文介绍了PE文件结构中的导出表的解析方法,包括获取区段头表、遍历查找所在的区段等步骤。通过该方法可以准确地解析PE文件中的导出表信息。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文讨论了一个数列求和问题,该数列按照一定规律生成。通过观察数列的规律,我们可以得出求解该问题的算法。具体算法为计算前n项i*f[i]的和,其中f[i]表示数列中有i个数字。根据参考的思路,我们可以将算法的时间复杂度控制在O(n),即计算到5e5即可满足1e9的要求。 ... [详细]
  • 本文讨论了编写可保护的代码的重要性,包括提高代码的可读性、可调试性和直观性。同时介绍了优化代码的方法,如代码格式化、解释函数和提炼函数等。还提到了一些常见的坏代码味道,如不规范的命名、重复代码、过长的函数和参数列表等。最后,介绍了如何处理数据泥团和进行函数重构,以提高代码质量和可维护性。 ... [详细]
  • pytorch Dropout过拟合的操作
    这篇文章主要介绍了pytorchDropout过拟合的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完 ... [详细]
author-avatar
半夏✔
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有