Unverified commit 602cb6a5, authored by qingqing01 and committed by GitHub

Enhance linear_lr_warmup (#18463)

* Make it support a float/int learning rate as input.
Parent 74538573
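In short, `linear_lr_warmup` now accepts a plain Python number for `learning_rate` in addition to a `Variable`. A minimal sketch of both call styles (the `boundaries`/`values` below are illustrative placeholders, not from this commit):

```python
import paddle.fluid as fluid

warmup_steps = 10
start_lr = 0.1 / 3.
end_lr = 0.1

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    # Variable input (already supported): warm up the output of another
    # scheduler such as piecewise_decay.
    boundaries = [100, 200]     # hypothetical step boundaries
    values = [0.1, 0.05, 0.01]  # hypothetical learning rates
    decayed_lr = fluid.layers.linear_lr_warmup(
        fluid.layers.piecewise_decay(boundaries, values),
        warmup_steps, start_lr, end_lr)

    # Scalar input (enabled by this commit): a plain float or int works too.
    scalar_lr = fluid.layers.linear_lr_warmup(0.2, warmup_steps, start_lr,
                                              end_lr)
```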
@@ -406,7 +406,7 @@ paddle.fluid.layers.polynomial_decay (ArgSpec(args=['learning_rate', 'decay_step
paddle.fluid.layers.piecewise_decay (ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None), ('document', 'd9f654117542c6b702963dda107a247f'))
paddle.fluid.layers.noam_decay (ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None), ('document', 'fd57228fb76195e66bbcc8d8e42c494d'))
paddle.fluid.layers.cosine_decay (ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None), ('document', 'f0d65d8c89d0fe78051ca689daa15e35'))
-paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', '0b529386b62cc73d27b711a5f618f3e4'))
+paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', 'dc7292c456847ba41cfd318e9f7f4363'))
paddle.fluid.contrib.InitState ('paddle.fluid.contrib.decoder.beam_search_decoder.InitState', ('document', '3afd1f84232718e628e9e566941c5f05'))
paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.StateCell ('paddle.fluid.contrib.decoder.beam_search_decoder.StateCell', ('document', 'ecd0066c02867d445d7b461e28220c50'))
@@ -23,6 +23,7 @@ strategy according to this module.
from __future__ import print_function
import math
+import numbers
from . import control_flow
from . import nn
@@ -30,6 +31,7 @@ from . import ops
from . import tensor
from ..initializer import init_on_cpu
from ..framework import default_main_program, Parameter, unique_name, name_scope
+from ..framework import Variable
from ..dygraph import base as imperative_base
from ..dygraph import learning_rate_scheduler as imperate_lr
@@ -450,8 +452,8 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    Args:
        learning_rate (float | Variable): A float value or Variable.
        warmup_steps (int): The warmup steps.
-        start_lr (float): The start learning of warmup.
-        end_lr (float): The end learning of warmup.
+        start_lr (float): The start learning rate of warmup.
+        end_lr (float): The end learning rate of warmup.

    Returns:
        The decayed learning rate in warmup period.
@@ -470,14 +472,16 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
                                             warmup_steps, start_lr, end_lr)
    """
-    assert (isinstance(end_lr, float))
-    assert (isinstance(start_lr, float))
-    linear_step = end_lr - start_lr
+    dtype = 'float32'
+    if isinstance(learning_rate, Variable):
+        dtype = learning_rate.dtype
+
+    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():
        lr = tensor.create_global_var(
            shape=[1],
            value=0.0,
-            dtype='float32',
+            dtype=dtype,
            persistable=True,
            name="learning_rate_warmup")
@@ -489,5 +493,8 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
                                                   float(warmup_steps))
                tensor.assign(decayed_lr, lr)
            with switch.default():
+                if not isinstance(learning_rate, Variable):
+                    learning_rate = tensor.fill_constant(
+                        shape=[1], dtype=dtype, value=float(learning_rate))
                tensor.assign(learning_rate, lr)
    return lr
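For reference, during the first `warmup_steps` steps the Switch above linearly interpolates between `start_lr` and `end_lr`; afterwards it passes the wrapped learning rate through unchanged. A pure-Python sketch of the warmup branch (mirroring the `linear_lr_warmup` reference helper the test below compares against):

```python
def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr):
    # Same arithmetic as the switch.case branch above:
    # lr = start_lr + (end_lr - start_lr) * global_step / warmup_steps
    linear_step = float(end_lr) - float(start_lr)
    return start_lr + linear_step * (global_step / float(warmup_steps))
```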
@@ -185,7 +185,7 @@ class TestLinearWamrupLearningRateDecay(TestLearningRateDecay):
        startup_prog = fluid.Program()
        warmup_steps = 10
-        start_lr = 1. / 3.
+        start_lr = 0.1 / 3.
        end_lr = 0.1

        with fluid.program_guard(main_prog, startup_prog):
@@ -212,5 +212,59 @@ class TestLinearWamrupLearningRateDecay(TestLearningRateDecay):
                        str(step), str(python_decayed_lr), str(lr_val[0])))
+
+
+class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase):
+    def run_scalar_lr(self, place, lr, start_lr, end_lr):
+        main_prog = fluid.Program()
+        startup_prog = fluid.Program()
+
+        warmup_steps = 10
+
+        with fluid.program_guard(main_prog, startup_prog):
+            decayed_lr = layers.linear_lr_warmup(lr, warmup_steps, start_lr,
+                                                 end_lr)
+
+        exe = fluid.Executor(place)
+        exe.run(startup_prog)
+
+        for step in range(20):
+            lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr])
+            if step < warmup_steps:
+                expected_lr = linear_lr_warmup(
+                    float(step), warmup_steps, start_lr, end_lr)
+            else:
+                expected_lr = lr
+            self.assertAlmostEqual(
+                expected_lr,
+                lr_val[0],
+                msg='Test failed, step {0}, expected {1}, but got {2}'.format(
+                    step, expected_lr, lr_val[0]))
+
+    def test_scalar_lr(self):
+        def run_places(lr, start_lr, end_lr):
+            places = [fluid.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(fluid.CUDAPlace(0))
+            for p in places:
+                self.run_scalar_lr(p, lr, start_lr, end_lr)
+
+        # float
+        lr = 0.2
+        start_lr = 0.1 / 3.
+        end_lr = 0.2
+        run_places(lr, start_lr, end_lr)
+
+        # int end_lr
+        lr = 2.
+        start_lr = 0.1 / 3.
+        end_lr = 1
+        run_places(lr, start_lr, end_lr)
+
+        # int
+        lr = 1
+        start_lr = 0
+        end_lr = 1
+        run_places(lr, start_lr, end_lr)

if __name__ == '__main__':
    unittest.main()
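As a quick sanity check, plugging the test's constants (`warmup_steps = 10`, `start_lr = 0.1 / 3.`, `end_lr = 0.1`) into the reference helper sketched earlier gives the values the assertions expect:

```python
warmup_steps, start_lr, end_lr = 10, 0.1 / 3., 0.1
for step in (0, 5, 9):
    lr = start_lr + (end_lr - start_lr) * step / float(warmup_steps)
    print(step, lr)  # 0 -> 0.0333..., 5 -> 0.0666..., 9 -> 0.0933...
# From step 10 onward the scheduler returns the wrapped learning rate itself.
```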