热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflow课堂笔记(四)

学习率学习率learning_rate:每次参数更新的幅度wn1wn-learning_rate▽相当于每次在梯度反方向减少的幅度,因为梯

学习率

"""
学习率 learning_rate:每次参数更新的幅度
wn+1 = wn - learning_rate▽
相当于每次在梯度反方向减少的幅度,因为梯度是增加最大的方向,我们要找到极小值
我们优化参数的目的就是让loss损失函数最小,所以每次都减少一点梯度方向的值
"""
#coding utf-8
import tensorflow as tf
#定义待优化参数w初值赋5
"""
w = tf.Variable(tf.constant(5, dtype=tf.float32))
#定义损失函数
loss = tf.square(w+1)
#定义反向传播方法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
学习率如果是1,则不收敛,如果为0.0001,则收敛速度很慢
指数衰减学习率
learning_rate = LEARNING_RATE_BASE*LEARNING_RATE_DECAY*(global_step/LEARNING_RATE_STEP)
右边参数依次为 学习率初始值,学习率衰减率(0,1),运行了几轮/多少轮更新一次学习率
global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
LEARNING_RATE_STEP,
LEARNING_RATE_DECAY,
staircase=True)
#生成会话,训练40轮
with tf.Session() as sess:init_op = tf.global_variables_initializer()sess.run(init_op)for i in range(40):sess.run(train_step)w_val = sess.run(w)loss_val = sess.run(loss)print("After %d steps: w is %f, loss is %f"%(i,w_val,loss_val))收敛结果:
After 39 steps: w is -1.000000, loss is 0.000000
"""
#下面是指数衰减学习率代码
LEARNING_RATE_BASE = 0.1 #最初学习率
LEARNING_RATE_DECAY = 0.99 #学习率衰减率
LEARNING_RATE_STEP = 1 #喂入多少轮BATCH_STEP后,更新依次学习率,总样本数/BATCH_SIZE
#运行了几轮BATCH_SIZE的计数器,初值给0,设为不被训练
global_step = tf.Variable(0, trainable=False)
#定义指数下降学习率
learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,LEARNING_RATE_STEP,\LEARNING_RATE_DECAY,staircase=True)
#定义待优化参数,赋初值为5
w = tf.Variable(tf.constant(5, dtype=tf.float32))
#定义损失函数
loss = tf.square(w+1)
#定义反向传播方法
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
#声称会话,训练40轮
with tf.Session() as sess:init_op = tf.global_variables_initializer()sess.run(init_op)STEPS=40for i in range(STEPS):sess.run(train_step)w_val = sess.run(w)loss_val = sess.run(loss)print("After %d steps,global_step is %f, w is %f,learning_rate is %f loss is %f"%(i+1,sess.run(global_step),w_val,sess.run(learning_rate),loss_val))
"""
训练结果:
After 40 steps,global_step is 40.000000, w is -0.995731,learning_rate is 0.066897 loss is 0.000018
"""

滑动平均

"""
滑动平均(影子值):记录了每个参数一段时间内过往值的平均,增加了模型的泛化性
影子=衰减率影子+(1-衰减率)*参数 影子初值=参数初值
ema = tf.train.ExponentialMovingAverage(
衰减率MOVING_AVERAGE_DECAY,
当前轮数global_step)
mea_op=ema.apple([])
ema_op=ema.apply(tf.trainable_variables())
with if.control_dependencies([train_step,ema_op]):train_op=tf.no_op(name='train')ema.average(参数名) 查看某参数的滑动平均值
"""
#coding utf-8
import tensorflow as tf
#1定义变量及滑动平均类
#定义一个32为浮点变量,初始值为0.0,这个代码就是不断更新w1参数,优化w1参数,滑动平均做了个w1的影子
w1 = tf.Variable(0, dtype=tf.float32)
#定义num_updates(NN的迭代轮数), 初始值为0, 不可被优化
global_step = tf.Variable(0, trainable=False)
#实例化滑动平均类,给删减率为0.99,当前轮数global_step
MOVING_AVERAGE_DECAY = 0.99
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
#ema.apply后的括号里是更新列表,每次运行sess.run(ema_op)时,对更新列表中元素取滑动平均
#在实际运用中会使用tf.trainable_variables()自动将所有待训练的参数汇总为列表
#ema_op = ema.apply([w1])
ema_op = ema.apply(tf.trainable_variables())#2查看不同迭代中变量取值的变化
with tf.Session() as sess:#初始化init_op = tf.global_variables_initializer()sess.run(init_op)#用ema.averge(w1)获取w1滑动平均值#打印参数w1和其滑动平均值print(sess.run([w1,ema.average(w1)]))#参数w1赋值为1sess.run(tf.assign(w1, 1))sess.run(ema_op)print(sess.run([w1, ema.average(w1)]))#更新step和w1的值,模拟出100轮迭代后,参数w1变为10sess.run(tf.assign(global_step, 100)) #赋值操作sess.run(tf.assign(w1, 10)) #赋值操作sess.run(ema_op) #更新一次滑动平均print(sess.run([w1, ema.average(w1)]))#每次sess.run会更新依次w1滑动平均值sess.run(ema_op)print(sess.run([w1, ema.average(w1)]))sess.run(ema_op)print(sess.run([w1, ema.average(w1)]))sess.run(ema_op)print(sess.run([w1, ema.average(w1)]))sess.run(ema_op)print(sess.run([w1, ema.average(w1)]))sess.run(ema_op)print(sess.run([w1, ema.average(w1)]))sess.run(ema_op)print(sess.run([w1, ema.average(w1)]))"""
运行结果:
[0.0, 0.0]
[1.0, 0.9]
[10.0, 1.6445453]
[10.0, 2.3281732]
[10.0, 2.955868]
[10.0, 3.532206]
[10.0, 4.061389]
[10.0, 4.547275]
[10.0, 4.9934072]
"""

 


推荐阅读
author-avatar
邱喷壶_381
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有