热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

遇到问题EGNet(6):

defforward里面的内容不显示吗?当调用vgg的是后会显示,默认是resnetclassvgg16(nn.Module):def__init__(

def forward里面的内容不显示吗?当调用vgg的是后会显示,默认是resnet


class vgg16(nn.Module):def __init__(self):super(vgg16, self).__init__()self.cfg = {'tun': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'tun_ex': [512, 512, 512]}self.extract = [8, 15, 22, 29] # feature map in 'tun' -> c(2), c(3), c(4), c(5) # [3, 8, 15, 22, 29]self.extract_ex = [5]self.base = nn.ModuleList(vgg(self.cfg['tun'], 3))self.base_ex = vgg_ex(self.cfg['tun_ex'], 512)# init paramiterfor m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, 0.01)elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()def load_pretrained_model(self, model):self.base.load_state_dict(model)def forward(self, x, multi=0):tmp_x = []# through the 'tun' layer by layerfor k in range(len(self.base)): # 'tun' 37 layer -> 64 64, 128 128, 256 256 256, 512 512 512, 512 512 512print('=>len(self.base)', len(self.base)) # not show ???x = self.base[k](x) # get new x through every layer in 'tun'if k in self.extract: # feature map in 'tun' -> c(2), c(3), c(4), c(5) -> self.extract = [8, 15, 22, 29]tmp_x.append(x)x = self.base_ex(x) # 'tun_ex' layer -> 512 512 512tmp_x.append(x) # # feature map in 'tun_ex' -> c(6)if multi == 1:tmp_y = []tmp_y.append(tmp_x[0]) # feature map in 'tun' -> c(2)return tmp_yelse:return tmp_x

 

 

 

 

 


推荐阅读
author-avatar
颂歌万岁_119
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有