热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

CS231n作业笔记2.3:优化算法Momentum,RMSProp,Adam

CS231n简介详见CS231n课程笔记1:Introduction。本文都是作者自己的思考,正确性未经过验证,欢迎指教。作业笔记本部分实现的是Momentum,RMSProb,

CS231n简介

详见 CS231n课程笔记1:Introduction。
本文都是作者自己的思考,正确性未经过验证,欢迎指教。

作业笔记

本部分实现的是Momentum,RMSProb, Adam三种优化算法,优化算法是用于从随机点出发,逐渐找到局部最优点的算法。关于各种优化算法的详细介绍,请参考CS231n课程笔记6.1:优化迭代算法之SGD,Momentum,Netsterov Momentum,AdaGrad,RMSprop,Adam。

1. Momentum

方程:

v = mu*v - learning_rate*dx
x += v

代码:

  v = v*config['momentum']-config['learning_rate']*dw
next_w = w + v

2. RMSProp

方程:

cache = cache*decay_rate + (1-decay_rate)*dx*dx
x -= learning_rate * dx/(sqrt(cache)+1e-7)

代码:

  config['cache'] = config['cache']*config['decay_rate'] + (1-config['decay_rate'])*dx*dx
next_x = x - config['learning_rate']*dx/np.sqrt(config['cache']+config['epsilon'])

3. Adam

此算法需要注意的是ppt中的方程是错误的,正确方法如下图,主要区别在于bias correction的部分,不更新m和v,详见Adam: A Method for Stochastic Optimization
还要注意t的更新,此部分也没有显示的写在ppt里。
Adam
代码:

  m = config['m']*config['beta1']+(1-config['beta1'])*dx
v = config['v']*config['beta2']+(1-config['beta2'])*dx*dx
config['t'] += 1
mb = m / (1 - config['beta1']**config['t'])
vb = v / (1 - config['beta2']**config['t'])
next_x = x - config['learning_rate']*mb/(np.sqrt(vb)+config['epsilon'])
config['m'] = m
config['v'] = v

推荐阅读
author-avatar
qt70ewi
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有