热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Caffe实战(三):从零构建ResNet之网络架构详解

本文将详细介绍如何使用Caffe构建ResNet网络架构,重点在于网络各层的定义及其参数配置。我们将通过具体的代码示例,帮助读者理解ResNet的核心组件及其实现细节。

在前文中,我们已经完成了两个关键函数的定义,接下来我们将基于这些函数构建完整的 ResNet 网络。为了更好地理解网络的构建过程,我们再次回顾 ResNet 的结构表:



































层名

输出尺寸

20层 ResNet

Conv1

32 X 32

Kernel_size=3 X 3
Num_output = 16
Stride = 1
Pad = 1

Conv2_x

32 X 32

{3X3,16; 3X3,16} X 3

Conv3_x

16 X 16

{3X3,32; 3X3,32} X 3

Conv4_x

8 X 8

{3X3,64; 3X3,64} X 3

InnerProduct

1 X 1

Average pooling
10-d fc


在 Conv1 层中,我们对输入图像进行一次卷积处理。从 Conv2_x 到 Conv4_x,每个阶段包含 3 个残差块,每个块内的卷积核数量依次翻倍(16, 32, 64),同时特征图的尺寸逐渐减半(32, 16, 8)。尽管输入图像的实际尺寸可能是 28 X 28,但为了简化说明,我们假设输入图像的尺寸为 32 X 32。


大多数卷积层使用 3 X 3 的卷积核,且通常设置 padding 为 1,以保持输出尺寸与输入尺寸一致。例如,从 Conv1 到 Conv2_x,由于通道数相同,可以直接将输入和输出相加。然而,从 Conv2_x 到 Conv3_x,以及从 Conv3_x 到 Conv4_x,由于通道数不同且输出尺寸不同,我们采用了论文中提出的 B 方法,即使用 1 X 1 的卷积核来调整输入的维度,使其与输出匹配。这种设计确保了残差连接的有效性。


在 ResNet 的实现中,当投影步长(projection_stride)为 1 时,表示输入和输出的维度相同,可以直接相加;当投影步长为 2 时,则需要使用 1 X 1 的卷积核,步长设为 2,以使输出尺寸减半。以下是 ResNet 函数的一个示例实现:




图 14 展示了网络的结构图。


使用 draw_net.py 工具绘制网络结构图时,至少需要提供两个参数:prototxt 文件的路径和图像保存的位置。参数 --rankdir=TB 表示网络结构从上到下绘制,其他选项包括 BT、LR(从左到右)、RL。需要注意的是,某些高版本的 Caffe 可能会出现 ‘int’ object has no attribute '_values' 的错误,解决方法请参考:https://github.com/BVLC/caffe/issues/5324



图 15 显示了 draw_net.py 绘制的网络结构图。由于图像较长,这里仅展示了部分结构。接下来,我们需要生成 solver 的 prototxt 文件。


步骤 3:创建 solver 的 prototxt 文件


在 caffe-master/examples/pycaffe 目录下有一个 tools.py 文件,该文件可以帮助我们生成所需的 solver prototxt 文件。首先,在 /ResNet 目录下创建一个 tools 文件夹,并将 tools.py 文件复制到该文件夹中。为了确保系统能够找到 tools.py 文件,我们需要在 init_path.py 中添加以下代码:


tools_path = osp.join(this_dir, 'tools')
add_path(tools_path)

这样,ResNet/tools 路径就被添加到系统中,系统可以找到 tools.py 文件。接着,在 mydemo.py 文件的开头,在 import init_path 之后添加 import tools,以导入 tools.py 文件。最后,在 mydemo.py 文件的末尾,make_net() 函数之后添加以下代码:


# 将内容写入 res_net_model 文件夹中的 res_net_solver.prototxt
solver_dir = this_dir + '/res_net_model/res_net_solver.prototxt'
solver_prototxt = tools.CaffeSolver()
solver_prototxt.write(solver_dir)

生成的 res_net_solver.prototxt 文件包含了各种 solver 参数,如 base_lr(基础学习率)设置为 0.1,lr_policy 设置为 multistep 等。具体参数如下所示:


def __init__(self, testnet_prototxt_path=this_dir+"/../res_net_model/test.prototxt",
trainnet_prototxt_path=this_dir+"/../res_net_model/train.prototxt", debug=False):

self.sp = {}

# 关键参数
self.sp['base_lr'] = '0.1'
self.sp['momentum'] = '0.9'

# 速度相关
self.sp['test_iter'] = '100'
self.sp['test_interval'] = '500'

# 显示设置
self.sp['display'] = '100'
self.sp['snapshot'] = '2500'
self.sp['snapshot_prefix'] = '/home/your_name/ResNet/res_net_model/snapshot/snapshot'

# 学习率策略
self.sp['lr_policy'] = 'multistep'
self.sp['step_value'] = '32000'
self.sp['step_value1'] = '48000'

# 其他重要参数
self.sp['gamma'] = '0.1'
self.sp['weight_decay'] = '0.0001'
self.sp['train_net'] = '"' + trainnet_prototxt_path + '"'
self.sp['test_net'] = '"' + testnet_prototxt_path + '"'

# 很少更改的参数
self.sp['max_iter'] = '100000'
self.sp['test_initialization'] = 'false'
self.sp['average_loss'] = '25'
self.sp['iter_size'] = '1'

if debug:
self.sp['max_iter'] = '12'
self.sp['test_iter'] = '1'
self.sp['test_interval'] = '4'
self.sp['display'] = '1'

这些参数的设置依据论文中的推荐值。例如,基础学习率 base_lr 设为 0.1,学习率策略 lr_policy 设为 multistep,分别在 32000 和 48000 迭代次数时将学习率降低 10 倍。权重衰减 weight_decay 设为 0.0001,动量 momentum 设为 0.9。其他参数如 max_iter(最大迭代次数)、test_iter(测试迭代次数)、test_interval(测试间隔)等根据实际需求进行设置。


为了确保工具文件能够正确运行,我们还需要在 tools.py 文件的开头添加以下代码:


import os.path as osp
this_dir = osp.dirname(__file__)

最终,mydemo.py 文件的末尾代码应如下所示:


if __name__ == '__main__':

make_net()

# 定义生成 solver 的路径
solver_dir = this_dir + '/res_net_model/res_net_solver.prototxt'
solver_prototxt = tools.CaffeSolver()
# 将内容写入 res_net_model 文件夹中的 res_net_solver.prototxt
solver_prototxt.write(solver_dir)

执行 mydemo.py 后,生成的 solver prototxt 文件将按照上述设置生成。注意,生成后需要手动将 stepvalue1 改为 stepvalue,因为 tools.py 不能存储相同名称的参数,否则会导致参数被覆盖。


至此,我们已经成功生成了 solver 的 prototxt 文件。


推荐阅读
author-avatar
手机用户2702932962_848
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有