From 605e7f0849eab68deac0c1972441e24824ba1b63 Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 20 Oct 2021 13:30:11 +0800
Subject: [PATCH] fix pow2 decay (#36559)

---
 .../pow2_decay_with_linear_warmup_op.cc      |  4 +--
 .../pow2_decay_with_linear_warmup_op.h       | 28 ++++++++-----------
 python/paddle/fluid/contrib/layers/nn.py     |  7 ++---
 .../test_pow2_decay_with_linear_warmup_op.py | 18 ++++++------
 4 files changed, 24 insertions(+), 33 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
index 12362b1bc64..4d919c94f61 100644
--- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
@@ -54,8 +54,6 @@ class Pow2DecayWithLinearWarmupOpMaker
     AddAttr<int64_t>(
         "total_steps",
         "(int64_t) The total steps for changing the learning rate.");
-    AddAttr<float>("start_lr",
-                   "(float) The initial value of the learning rate.");
     AddAttr<float>("base_lr",
                    "(float) The final learning rate value after warmup.");
     AddAttr<float>("end_lr",
@@ -63,7 +61,7 @@ class Pow2DecayWithLinearWarmupOpMaker
     AddComment(R"DOC(
 The Pow2DecayWithLinearWarmup learning rate scheduler.
 
-When step_num < warmup_steps, lr = (base_lr - start_lr) * step_num / warmup_steps + start_lr
+When step_num < warmup_steps, lr = base_lr * step_num / warmup_steps
 
 When warmup_steps <= step_num <= total_steps,
 factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
index 41e07b0343e..74cf7627450 100644
--- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
@@ -28,31 +28,30 @@ struct Pow2DecayWithLinearWarmupFunctor {
   using RestrictPtr = U *PADDLE_RESTRICT;
 
  public:
-  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(
-      RestrictPtr<T> lr, RestrictPtr<int64_t> step, size_t warmup_steps,
-      size_t total_steps, AttrT start_lr, AttrT base_lr, AttrT end_lr)
+  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
+                                              RestrictPtr<int64_t> step,
+                                              size_t warmup_steps,
+                                              size_t total_steps, AttrT base_lr,
+                                              AttrT end_lr)
       : lr_(lr),
         step_(step),
         warmup_steps_(warmup_steps),
         total_steps_(total_steps),
-        start_lr_(start_lr),
         base_lr_(base_lr),
         end_lr_(end_lr) {}
 
   HOSTDEVICE void operator()(size_t) const {
-    size_t step = static_cast<size_t>(*step_);
-    *step_ = static_cast<int64_t>(step + 1);
-    if (step < warmup_steps_) {
-      auto new_lr =
-          static_cast<AttrT>(base_lr_ - start_lr_) * step / warmup_steps_ +
-          start_lr_;
+    size_t step = static_cast<size_t>(*step_) + 1;
+    *step_ = static_cast<int64_t>(step);
+    if (step <= warmup_steps_) {
+      auto new_lr = static_cast<AttrT>(step) / warmup_steps_ * base_lr_;
       *lr_ = static_cast<T>(new_lr);
     } else if (step < total_steps_) {
       auto factor = 1 -
                     static_cast<AttrT>(step - warmup_steps_) /
                         (total_steps_ - warmup_steps_);
       auto new_lr =
-          static_cast<AttrT>(base_lr_ - end_lr_) * factor * factor + end_lr_;
+          static_cast<AttrT>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
       *lr_ = static_cast<T>(new_lr);
     } else {
       *lr_ = static_cast<T>(end_lr_);
@@ -64,7 +63,6 @@ struct Pow2DecayWithLinearWarmupFunctor {
   RestrictPtr<int64_t> step_;
   size_t warmup_steps_;
   size_t total_steps_;
-  AttrT start_lr_;
   AttrT base_lr_;
   AttrT end_lr_;
 };
@@ -98,7 +96,6 @@ class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_LE(warmup_steps, total_steps,
                       platform::errors::InvalidArgument(
                           "warmup_steps must not be larger than total_steps."));
-    auto start_lr = ctx.Attr<float>("start_lr");
     auto base_lr = ctx.Attr<float>("base_lr");
     auto end_lr = ctx.Attr<float>("end_lr");
 
@@ -106,11 +103,10 @@ class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
     auto *step_data = step_out->data<int64_t>();
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
-    using AttrT = float;
+    using AttrT = double;
     Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
         lr_data, step_data, warmup_steps, total_steps,
-        static_cast<AttrT>(start_lr), static_cast<AttrT>(base_lr),
-        static_cast<AttrT>(end_lr));
+        static_cast<AttrT>(base_lr), static_cast<AttrT>(end_lr));
     for_range(functor);
   }
 };
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 0d0addb17e9..cb26f05b549 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -1936,18 +1936,18 @@ def fused_bn_add_act(x,
 
 def pow2_decay_with_linear_warmup(warmup_steps,
                                   total_steps,
-                                  start_lr,
                                   base_lr,
                                   end_lr,
                                   dtype='float32',
                                   name=None):
     if paddle.fluid.in_dygraph_mode():
         raise NotImplementedError(
-            "pow2_warmup does not support dygraph mode yet.")
+            "pow2_decay_with_linear_warmup does not support dygraph mode yet.")
 
     helper = LayerHelper("pow2_decay_with_linear_warmup", **locals())
     lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1])
-    helper.set_variable_initializer(lr, Constant(value=start_lr))
+    helper.set_variable_initializer(
+        lr, Constant(value=float(base_lr) / warmup_steps))
 
     step = helper.create_global_variable(
         persistable=True, dtype='int64', shape=[1])
@@ -1963,7 +1963,6 @@ def pow2_decay_with_linear_warmup(warmup_steps,
         attrs={
             "warmup_steps": warmup_steps,
             "total_steps": total_steps,
-            "start_lr": start_lr,
             "base_lr": base_lr,
             "end_lr": end_lr,
         })
diff --git a/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py b/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py
index 641ea3eccf8..056db5b8590 100644
--- a/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py
@@ -19,13 +19,12 @@ from paddle.optimizer.lr import PolynomialDecay
 import unittest
 
 
-def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
-                          place):
+def gen_pow2_warmup_op_lr(warmup_steps, total_steps, base_lr, end_lr, place):
     main = paddle.static.Program()
     startup = paddle.static.Program()
     with paddle.static.program_guard(main, startup):
-        lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, start_lr,
-                                           base_lr, end_lr)
+        lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, base_lr,
+                                           end_lr)
     exe = paddle.static.Executor(place)
     with paddle.static.scope_guard(paddle.static.Scope()):
         exe.run(startup)
@@ -35,7 +34,7 @@ def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
 
 
 class Pow2Warmup(LinearWarmup):
-    def __init__(self, warmup_steps, total_steps, start_lr, base_lr, end_lr):
+    def __init__(self, warmup_steps, total_steps, base_lr, end_lr):
         assert total_steps > warmup_steps
         lr_sch = PolynomialDecay(
             learning_rate=base_lr,
@@ -46,13 +45,13 @@ class Pow2Warmup(LinearWarmup):
         super(Pow2Warmup, self).__init__(
             learning_rate=lr_sch,
             warmup_steps=warmup_steps,
-            start_lr=start_lr,
+            start_lr=0.0,
             end_lr=base_lr)
 
 
-def gen_pow2_warmup_py_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
-                          place):
-    lr_sch = Pow2Warmup(warmup_steps, total_steps, start_lr, base_lr, end_lr)
+def gen_pow2_warmup_py_lr(warmup_steps, total_steps, base_lr, end_lr, place):
+    lr_sch = Pow2Warmup(warmup_steps, total_steps, base_lr, end_lr)
+    lr_sch.step()
     while True:
         yield lr_sch()
         lr_sch.step()
@@ -64,7 +63,6 @@ class TestPow2WarmupLRScheduler(unittest.TestCase):
         self.params = {
             'warmup_steps': 30,
             'total_steps': 100,
-            'start_lr': 0.01,
             'base_lr': 0.02,
             'end_lr': 0.001,
         }
-- 
GitLab
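
For readers who want to check the new schedule numerically, the formula from the op's DOC comment can be sketched in plain Python. This is only an illustrative re-implementation, not code from the patch (the helper name pow2_warmup_lr is invented for the example); the real computation happens in the C++ functor above.

def pow2_warmup_lr(step_num, warmup_steps, total_steps, base_lr, end_lr):
    # Mirrors the patched functor: the step counter is incremented first,
    # so step_num starts at 1 and there is no separate start_lr any more.
    if step_num <= warmup_steps:
        return float(step_num) / warmup_steps * base_lr
    elif step_num < total_steps:
        factor = 1.0 - float(step_num - warmup_steps) / (total_steps - warmup_steps)
        return (base_lr - end_lr) * (factor * factor) + end_lr
    else:
        return end_lr

# Parameters taken from the unit test in this patch.
for step_num in (1, 30, 65, 100, 200):
    print(step_num, pow2_warmup_lr(step_num, 30, 100, 0.02, 0.001))

With the test parameters (warmup_steps=30, total_steps=100, base_lr=0.02, end_lr=0.001), step 1 yields 0.02 / 30 ≈ 0.000667, step 30 yields 0.02, and every step from 100 onward yields 0.001. The step-1 value is also why nn.py now initializes the lr variable to base_lr / warmup_steps instead of start_lr.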