tensorflow中的学习率调整策略

摘要:
name=无)learning_Rate全局初始学习率_步长当前训练多少次迭代_步长在每xxx步后更改学习率_速率用于计算更改后的学习率_ Step/decay_计算步长结果是浮动还是向下舍入的公式为:

通常为了模型能更好的收敛,随着训练的进行,希望能够减小学习率,以使得模型能够更好地收敛,找到loss最低的那个点.

tensorflow中提供了多种学习率的调整方式.在https://www.tensorflow.org/api_docs/python/tf/compat/v1/train搜索decay.可以看到有多种学习率的衰减策略.

  • cosine_decay
  • exponential_decay
  • inverse_time_decay
  • linear_cosine_decay
  • natural_exp_decay
  • noisy_linear_cosine_decay
  • polynomial_decay

本文介绍两种学习率衰减策略,指数衰减和多项式衰减.

tf.compat.v1.train.exponential_decay(
    learning_rate,
    global_step,
    decay_steps,
    decay_rate,
    staircase=False,
    name=None
)

learning_rate 初始学习率
global_step 当前总共训练多少个迭代
decay_steps 每xxx steps后变更一次学习率
decay_rate 用以计算变更后的学习率
staircase: global_step/decay_steps的结果是float型还是向下取整

学习率的计算公式为:decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

我们用一段测试代码来绘制一下学习率的变化情况.

#coding=utf-8
import matplotlib.pyplot as plt
import tensorflow as tf

x=[]
y=[]
N = 200 #总共训练200个迭代

num_epoch = tf.Variable(0, name='global_step', trainable=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for num_epoch in range(N):
        ##初始学习率0.5,每10个迭代更新一次学习率.
        learing_rate_decay = tf.train.exponential_decay(learning_rate=0.5, global_step=num_epoch, decay_steps=10, decay_rate=0.9, staircase=False)
        learning_rate = sess.run([learing_rate_decay])
        y.append(learning_rate)

#print(y)

x = range(N)
fig = plt.figure()
ax.set_xlabel('step')
ax.set_ylabel('learing rate')
plt.plot(x, y, 'r', linewidth=2)
plt.show()

结果如图:
tensorflow中的学习率调整策略第1张

  • 多项式衰减
tf.compat.v1.train.polynomial_decay(
    learning_rate,
    global_step,
    decay_steps,
    end_learning_rate=0.0001,
    power=1.0,
    cycle=False,
    name=None
)

设定一个初始学习率,一个终止学习率,然后线性衰减.cycle控制衰减到end_learning_rate后是否保持这个最小学习率不变,还是循环往复. 过小的学习率会导致收敛到局部最优解,循环往复可以一定程度上避免这个问题.
根据cycle是否为true,其计算方式不同,如下:
tensorflow中的学习率调整策略第2张

#coding=utf-8
import matplotlib.pyplot as plt
import tensorflow as tf

x=[]
y=[]
z=[]
N = 200 #总共训练200个迭代

num_epoch = tf.Variable(0, name='global_step', trainable=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for num_epoch in range(N):
        ##初始学习率0.5,每10个迭代更新一次学习率.
        learing_rate_decay = tf.train.polynomial_decay(learning_rate=0.5, global_step=num_epoch, decay_steps=10, end_learning_rate=0.0001, cycle=False)
        learning_rate = sess.run([learing_rate_decay])
        y.append(learning_rate)
        
        learing_rate_decay2 = tf.train.polynomial_decay(learning_rate=0.5, global_step=num_epoch, decay_steps=10, end_learning_rate=0.0001, cycle=True)
        learning_rate2 = sess.run([learing_rate_decay2])
        z.append(learning_rate2)
#print(y)

x = range(N)
fig = plt.figure()
ax.set_xlabel('step')
ax.set_ylabel('learing rate')
plt.plot(x, y, 'r', linewidth=2)
plt.plot(x, z, 'g', linewidth=2)
plt.show()

绘图结果如下:
tensorflow中的学习率调整策略第3张
cycle为false时对应红线,学习率下降到0.0001后不再下降. cycle=true时,下降到0.0001后再突变到一个更大的值,在继续衰减,循环往复.

在代码里,通常通过参数去控制不同的学习率策略,例如

def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if
  """
  # Note: when num_clones is > 1, this will actually have each clone to go
  # over each epoch FLAGS.num_epochs_per_decay times. This is different
  # behavior from sync replicas and is expected to produce different results.
  decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
                    FLAGS.batch_size)

  if FLAGS.sync_replicas:
    decay_steps /= FLAGS.replicas_to_aggregate

  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)

推荐一篇:https://blog.csdn.net/dcrmg/article/details/80017200 对各种学习率衰减策略描述的很详细.并且都有配图,可以很直观地看到各种衰减策略下学习率变换情况.

免责声明:文章转载自《tensorflow中的学习率调整策略》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇g_signal_connect 与 g_signal_connect_swappedC语言精要总结-内存地址对齐与struct大小判断篇下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

【Android Studio】为Android Studio设置HTTP代理

【Android Studio】为Android Studio设置HTTP代理  大陆的墙很厚很高,初次安装Android Studio下载SDK等必定失败,设置代理方法如下: 1.  到android studio安装目录,打开bin目录,编辑idea.properties, 在文件末尾添加: disable.android.first.run=tru...

博客基础_django_python从入门到实践_添加主题_添加条目_编辑条目

 要求及文件   用户可以添加新主题,添加新条目,以及编辑既有条目    forms.py  urls.py  views.py  html    new_topic.html  new_entry.html    edit_entry.html 添加新主题 new_topic.html   topics.html   添加新条目 new_entry.ht...

前端插件的使用

最近在做后台系统的图表报表显示 用到了很多的前端js插件 jquery插件,下面罗列一下 1.多选插件  bootstrap-multiselect  select2 bootstrap-select 我自己代码里用的是bootstrap-multiselect 和 bootstrap--select bootstrap-multiselect 插件 几个...

deeplab系列总结(deeplab v1& v2 & v3 & v3+)

deeplab系列总结(deeplab v1& v2 & v3 & v3+) Deeplab v1&v2paper: deeplab v1 && deeplab v2 远古版本的deeplab系列,就像RCNN一样,其实了解了后面的v3和v3+就可以不太管这些了(个人拙见)。但是为了完整性和连贯性,所以读了这...

数组中的filter函数,递归以及一些应用。

当我们用一个东西时候我们必须知道的是?why---where----how---when。一个东西我们为什么用?在哪用?怎么用?何时用?而不是被动的去接受一些东西。用在js里边我觉得也会试用。一直追求源生js,虽然也都背过好多东西,但是随着时间的流逝,工作的繁忙都忘了,有时甚至一点印象都没有,这让我开始思考我的学习方法了已经思维方式了。我们要记得不是简单的...

音视频文件的码率与大小计算

编码率/比特率直接与文件体积有关。且编码率与编码格式配合是否合适,直接关系到视频文件是否清晰。在视频编码领域,比特率常翻译为编码率,单位是Kbps,例如800Kbps其中, 1K=1024 1M=1024Kb 为 比特(bit) 这个就是电脑文件大小的计量单位,1KB=8Kb,区分大小写,B代表字节(Byte) s 为 秒(second) p 为 每(per...