热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch实现straightthroughestimator(STE)

现在深度学习中一般我们学习的参数都是连续的,因为这样在反向传播的时候才可以对梯度进行更新。但是有的时候我们也会遇到参数是离散的情况,这样就没有办法进行

现在深度学习中一般我们学习的参数都是连续的,因为这样在反向传播的时候才可以对梯度进行更新。但是有的时候我们也会遇到参数是离>散的情况,这样就没有办法进行反向传播了,比如二值神经网络。本文中讲解了如何用pytorch对二值化的参数进行梯度更新的straight-through estimator算法。
Question:
STE核心的思想就是我们的参数初始化的时候就是float这样的连续值,当我们forward的时候就将原来的连续的参数映射到{-1, 1}带入到网络进行计算,这样就可以计算网络的输出。然后backward的时候直接对原来float的参数进行更新,而不是对二值化的参数更新。这样可以完成对整个网络的更新了。
首先我们对上面问题进行一下数学的讲解。

 

Example:
首先我们验证一下使用torch.sign会是参数的梯度基本上都是0:

>>> input = torch.randn(4, requires_grad = True)
>>> output = torch.sign(input)
>>> loss = output.mean()
>>> loss.backward()
>>> input
tensor([-0.8673, -0.0299, -1.1434, -0.6172], requires_grad=True)
>>> input.grad
tensor([0., 0., 0., 0.])

 我们需要重写sign这个函数,就好像写一个激活函数一样。

import torchclass LBSign(torch.autograd.Function):@staticmethoddef forward(ctx, input):return torch.sign(input)@staticmethoddef backward(ctx, grad_output):return grad_output.clamp_(-1, 1)

import torch
from LBSign import LBSignif __name__ == '__main__':sign = LBSign.applyparams = torch.randn(4, requires_grad = True) output = sign(params)loss = output.mean()loss.backward()

测试梯度:

>>> params
tensor([-0.9143, 0.8993, -1.1235, -0.7928], requires_grad=True)
>>> params.grad
tensor([0.2500, 0.2500, 0.2500, 0.2500])

 

 


推荐阅读
  • HSE8MHz。配置前将所有RCC重置为初始值RCC_DeInit();*这里选择外部晶振(HSE)作为时钟源,因此首先打开外部晶振*RC ... [详细]
  • 上次我们总结了React代码构建后的webpack模块组织关系,今天来介绍一下Babel编译JSX生成目标代码的一些规则,并且写一个简单的解析器,模拟整个生成的过程。我们还是拿最简 ... [详细]
  • DNNBrain:北师大团队出品,国内首款用于映射深层神经网络到大脑的统一工具箱...
    导读深度神经网络(DNN)通过端到端的深度学习策略在许多具有挑战性的任务上达到了人类水平的性能。深度学习产生了具有多层抽象层次的数据表示;然而,它没有明确地提供任何关 ... [详细]
  • 深度强化学习Policy Gradient基本实现
    全文共2543个字,2张图,预计阅读时间15分钟。基于值的强化学习算法的基本思想是根据当前的状态,计算采取每个动作的价值,然 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • Icurrentlyhaveafunction[C#]whichtakesabyte[]andanalignmenttosetitto,butduringencr ... [详细]
  • [二分图]JZOJ 4612 游戏
    DescriptionInputOutputSampleInput44#****#****#*xxx#SampleOutput5DataConstraint分析非常眼熟࿰ ... [详细]
  • 微信小程序官方组件展示之表单组件input源码
    以下将展示微信小程序之表单组件input源码官方组件能力,组件样式仅供参考,开发者可根据自身需求定义组件样式,具体属性参数详见小程序开发文档。功能描述:输入框。该组件是原生组件, ... [详细]
  • 使用nodejs爬取b站番剧数据,计算最佳追番推荐
    本文介绍了如何使用nodejs爬取b站番剧数据,并通过计算得出最佳追番推荐。通过调用相关接口获取番剧数据和评分数据,以及使用相应的算法进行计算。该方法可以帮助用户找到适合自己的番剧进行观看。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 建立分类感知器二元模型对样本数据进行分类
    本文介绍了建立分类感知器二元模型对样本数据进行分类的方法。通过建立线性模型,使用最小二乘、Logistic回归等方法进行建模,考虑到可能性的大小等因素。通过极大似然估计求得分类器的参数,使用牛顿-拉菲森迭代方法求解方程组。同时介绍了梯度上升算法和牛顿迭代的收敛速度比较。最后给出了公式法和logistic regression的实现示例。 ... [详细]
author-avatar
文人博客
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有