热门标签 | HotTags
当前位置:  开发笔记 > 人工智能 > 正文

可视化pytorch模型中不同BN层的runningmean曲线实例

这篇文章主要介绍了可视化pytorch模型中不同BN层的runningmean曲线实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

加载模型字典

逐一判断每一层,如果该层是bn 的 running mean,就取出参数并取平均作为该层的代表

对保存的每个BN层的数值进行曲线可视化

from functools import partial
import pickle
import torch
import matplotlib.pyplot as plt

pth_path = 'checkpoint.pth'

pickle.load = partial(pickle.load, encoding="latin1")
pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
pretrained_dict = torch.load(pth_path, map_location=lambda storage, loc: storage, pickle_module=pickle)
pretrained_dict = pretrained_dict['state_dict']

means = []
for name, param in pretrained_dict.items():
 print(name)
 if 'running_mean' in name:
  means.append(mean.numpy())

layers = [i for i in range(len(means))]

plt.plot(layers, means, color='blue')
plt.legend()
plt.xticks(layers)
plt.xlabel('layers')
plt.show()

补充知识:关于pytorch中BN层(具体实现)的一些小细节

最近在做目标检测,需要把训好的模型放到嵌入式设备上跑前向,因此得把各种层的实现都用C手撸一遍,,,此为背景。

其他层没什么好说的,但是BN层这有个小坑。pytorch在打印网络参数的时候,只打出weight和bias这两个参数。咦,说好的BN层有四个参数running_mean、running_var 、gamma 、beta的呢?一开始我以为是pytorch把BN层的计算简化成weight * X + bias,但马上反应过来应该没这么简单,因为pytorch中只有可学习的参数才称为parameter。上网找了一些资料但都没有说到这么细的,毕竟大部分用户使用时只要模型能跑起来就行了,,,于是开始看BN层有哪些属性,果然发现了熟悉的running_mean和running_var,原来pytorch的BN层实现并没有不同。这里吐个槽:为啥要把gamma和beta改叫weight、bias啊,很有迷惑性的好不好,,,

扯了这么多,干脆捋一遍pytorch里BN层的具体实现过程,帮自己理清思路,也可以给大家提供参考。再吐槽一下,在网上搜“pytorch bn层”出来的全是关于这一层怎么用的、初始化时要输入哪些参数,没找到一个pytorch中BN层是怎么实现的,,,

众所周知,BN层的输出Y与输入X之间的关系是:Y = (X - running_mean) / sqrt(running_var + eps) * gamma + beta,此不赘言。其中gamma、beta为可学习参数(在pytorch中分别改叫weight和bias),训练时通过反向传播更新;而running_mean、running_var则是在前向时先由X计算出mean和var,再由mean和var以动量momentum来更新running_mean和running_var。所以在训练阶段,running_mean和running_var在每次前向时更新一次;在测试阶段,则通过net.eval()固定该BN层的running_mean和running_var,此时这两个值即为训练阶段最后一次前向时确定的值,并在整个测试阶段保持不变。

以上这篇可视化pytorch 模型中不同BN层的running mean曲线实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


推荐阅读
  • 如何用GPU服务器运行Python
    如何用GPU服务器运行Python-目录前言一、服务器登录1.1下载安装putty1.2putty远程登录 1.3查看GPU、显卡常用命令1.4Linux常用命令二、 ... [详细]
  • 本文详细介绍了如何在Python和PyTorch环境中实现Tensor与NumPy数组之间的转换,以及PIL图像对象与NumPy数组之间的相互转换。内容包括具体的转换函数及其使用示例。 ... [详细]
  • 本文介绍了一种方法,通过使用Python的ctypes库来调用C++代码。具体实例为实现一个简单的加法器,并详细说明了从编写C++代码到编译及最终在Python中调用的全过程。 ... [详细]
  • 解决Jupyter Notebook 中无法找到 TensorFlow 的问题
    本文记录了解决 Jupyter Notebook 在特定环境中无法识别已安装的 TensorFlow 的方法。主要原因是 Jupyter 默认在 base 环境中运行,而 TensorFlow 可能在其他环境中。通过配置 Jupyter 使其能够访问目标环境中的 TensorFlow。 ... [详细]
  • 本文介绍了如何使用 Google Colab 的免费 GPU 资源进行深度学习应用开发。Google Colab 是一个无需配置即可使用的云端 Jupyter 笔记本环境,支持多种深度学习框架,并且提供免费的 GPU 计算资源。 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • Python 中变量类型的确定与默认类型解析
    本文详细探讨了 Python 中变量类型的确定方式及其默认类型,帮助初学者更好地理解变量类型的概念。 ... [详细]
  • Vision Transformer (ViT) 和 DETR 深度解析
    本文详细介绍了 Vision Transformer (ViT) 和 DETR 的工作原理,并提供了相关的代码实现和参考资料。通过观看教学视频和阅读博客,对 ViT 的全流程进行了详细的笔记整理,包括代码详解和关键概念的解释。 ... [详细]
  • PyTorch实用技巧汇总(持续更新中)
    空洞卷积(Dilated Convolutions)在卷积操作中通过在卷积核元素之间插入空格来扩大感受野,这一过程由超参数 dilation rate 控制。这种技术在保持参数数量不变的情况下,能够有效地捕捉更大范围的上下文信息,适用于多种视觉任务,如图像分割和目标检测。本文将详细介绍空洞卷积的计算原理及其应用场景。 ... [详细]
  • 在Conda环境中高效配置并安装PyTorch和TensorFlow GPU版的方法如下:首先,创建一个新的Conda环境以避免与基础环境发生冲突,例如使用 `conda create -n pytorch_gpu python=3.7` 命令。接着,激活该环境,确保所有依赖项都正确安装。此外,建议在安装过程中指定CUDA版本,以确保与GPU兼容性。通过这些步骤,可以确保PyTorch和TensorFlow GPU版的顺利安装和运行。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 在 PyTorch 的 `CrossEntropyLoss` 函数中,当目标标签 `target` 为类别 ID 时,实际上会进行 one-hot 编码处理。例如,假设总共有三个类别,其中一个类别的 ID 为 2,则该标签会被转换为 `[0, 0, 1]`。这一过程简化了多分类任务中的损失计算,使得模型能够更高效地进行训练和评估。此外,`CrossEntropyLoss` 还结合了 softmax 激活函数和负对数似然损失,进一步提高了模型的性能和稳定性。 ... [详细]
  • 本文探讨了BERT模型在自然语言处理领域的应用与实践。详细介绍了Transformers库(曾用名pytorch-transformers和pytorch-pretrained-bert)的使用方法,涵盖了从模型加载到微调的各个环节。此外,还分析了BERT在文本分类、情感分析和命名实体识别等任务中的性能表现,并讨论了其在实际项目中的优势和局限性。 ... [详细]
  • 在Windows环境下离线安装PyTorch GPU版时,首先需确认系统配置,例如本文作者使用的是Win8、CUDA 8.0和Python 3.6.5。用户应根据自身Python和CUDA版本,在PyTorch官网查找并下载相应的.whl文件。此外,建议检查系统环境变量设置,确保CUDA路径正确配置,以避免安装过程中可能出现的兼容性问题。 ... [详细]
author-avatar
林伯爵
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有