热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch中nn.Module中self.register_buffer的解释

Pytorch中nn.Module中的self.register_b

self.register_buffer作用解释

今天遇到了这样一种用法,self.register_buffer(‘name’,Tensor),该方法的作用在于定义一组参数。该组参数在模型训练时不会更新(即调用optimizer.step()后该组参数不会变化,只可人为地改变它们的值),但是该组参数又作为模型参数不可或缺的一部分。


实验

四种方式初始化模型中的参数



  1. 定义常见模型时的操作

  2. 使用register_buffer()定义一组参数

  3. 使用register_parameter()定义一组参数

  4. 使用python类的属性方式定义一组变量

import torch
import torch.nn as nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()

#(1)定义常见模型时的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv',nn.Conv2d(1,1,3,bias=False)),
('fc',nn.Linear(1,2,bias=False))
]))

#(2)使用register_buffer()定义一组参数
self.register_buffer('reg_buf',torch.randn(1,2))
#(3)使用register_parameter()定义一组参数
self.register_parameter('reg_param',nn.Parameter(torch.randn(1,2)))
#(4)使用python类的属性方式定义一组变量
self.param_attr = torch.randn(1,2)
net = Model()


问题1:哪些参数会在模型训练时被更新?

因为定义优化器时会传入一个参数net.parameters,所以在模型训练时更新的参数可以通过list(net.named_parameters())查看

结果说明,只有方式(1)和方式(3)定义的参数可以被更新


问题2:模型中的参数到底有哪些?

模型中的所有参数都装在state_dict()中,所以可以通过net.state_dict()方式查看

结果说明,只有方式(4)的参数不在模型的参数列表,没有被模型训练时更新的参数reg_buf,依然在模型的参数列表里


self.register_buffer()的使用方法

  1. 传入参数:第一个参数传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数

  2. 在模型定义中调用:使用self.name方法,本例中就是self. reg_buf

  3. 在实例化模型后调用:使用net.buffers()方法。


其他知识

实际上,Pytorch定义的模型用OrderedDict()方式记录这三种类型,分别保存在self._modules, self._parameters 和self.buffer三个私有属性中

在模型实例化后可以用以下方法看三个私有属性中的变量
net.modules()
net.parameters()
net.buffers()

self._parameters 和net.parameters() 的返回值并不相同,self._parameters只记录了使用self.register_parameter()定义的参数,而net.parameters()返回所有可学习参数。

参考:
[1]Pytorchnn.Module中的self.register_buffer()解析



推荐阅读
author-avatar
ai琳伟_261
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有