Unverified commit 605e7f08, authored by Zeng Jinle, committed by GitHub

fix pow2 decay (#36559)

Parent 7325c9fb
@@ -54,8 +54,6 @@ class Pow2DecayWithLinearWarmupOpMaker
     AddAttr<int64_t>(
         "total_steps",
         "(int64_t) The total steps for changing the learning rate.");
-    AddAttr<float>("start_lr",
-                   "(float) The initial value of the learning rate.");
     AddAttr<float>("base_lr",
                    "(float) The final learning rate value after warmup.");
     AddAttr<float>("end_lr",
@@ -63,7 +61,7 @@ class Pow2DecayWithLinearWarmupOpMaker
     AddComment(R"DOC(
 The Pow2DecayWithLinearWarmup learning rate scheduler.
-When step_num < warmup_steps, lr = (base_lr - start_lr) * step_num / warmup_steps + start_lr
+When step_num < warmup_steps, lr = base_lr * step_num / warmup_steps
 When warmup_steps <= step_num <= total_steps,
   factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
...
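For reference, the updated schedule reads naturally as a plain Python function. This is a minimal sketch of the formula in the DOC string above; the name pow2_lr and the 1-based step_num are illustrative, not part of the op:

def pow2_lr(step_num, warmup_steps, total_steps, base_lr, end_lr):
    # Sketch of the documented schedule; step_num is 1-based, matching
    # the kernel's pre-incremented step counter (see the functor below).
    if step_num <= warmup_steps:
        # Warmup now ramps linearly from 0 to base_lr; start_lr is gone.
        return step_num / warmup_steps * base_lr
    elif step_num < total_steps:
        # Quadratic ("pow2") decay from base_lr down to end_lr.
        factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
        return (base_lr - end_lr) * factor * factor + end_lr
    else:
        return end_lr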
@@ -28,31 +28,30 @@ struct Pow2DecayWithLinearWarmupFunctor {
   using RestrictPtr = U *PADDLE_RESTRICT;

  public:
-  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(
-      RestrictPtr<T> lr, RestrictPtr<int64_t> step, size_t warmup_steps,
-      size_t total_steps, AttrT start_lr, AttrT base_lr, AttrT end_lr)
+  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
+                                              RestrictPtr<int64_t> step,
+                                              size_t warmup_steps,
+                                              size_t total_steps, AttrT base_lr,
+                                              AttrT end_lr)
       : lr_(lr),
         step_(step),
         warmup_steps_(warmup_steps),
         total_steps_(total_steps),
-        start_lr_(start_lr),
         base_lr_(base_lr),
         end_lr_(end_lr) {}

   HOSTDEVICE void operator()(size_t) const {
-    size_t step = static_cast<size_t>(*step_);
-    *step_ = static_cast<int64_t>(step + 1);
-    if (step < warmup_steps_) {
-      auto new_lr =
-          static_cast<double>(base_lr_ - start_lr_) * step / warmup_steps_ +
-          start_lr_;
+    size_t step = static_cast<size_t>(*step_) + 1;
+    *step_ = static_cast<int64_t>(step);
+    if (step <= warmup_steps_) {
+      auto new_lr = static_cast<double>(step) / warmup_steps_ * base_lr_;
       *lr_ = static_cast<T>(new_lr);
     } else if (step < total_steps_) {
       auto factor = 1 -
                     static_cast<double>(step - warmup_steps_) /
                         (total_steps_ - warmup_steps_);
       auto new_lr =
-          static_cast<double>(base_lr_ - end_lr_) * factor * factor + end_lr_;
+          static_cast<double>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
       *lr_ = static_cast<T>(new_lr);
     } else {
       *lr_ = static_cast<T>(end_lr_);
@@ -64,7 +63,6 @@ struct Pow2DecayWithLinearWarmupFunctor {
   RestrictPtr<int64_t> step_;
   size_t warmup_steps_;
   size_t total_steps_;
-  AttrT start_lr_;
   AttrT base_lr_;
   AttrT end_lr_;
 };
@@ -98,7 +96,6 @@ class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_LE(warmup_steps, total_steps,
                       platform::errors::InvalidArgument(
                           "warmup_steps must not be larger than total_steps."));
-    auto start_lr = ctx.Attr<float>("start_lr");
     auto base_lr = ctx.Attr<float>("base_lr");
     auto end_lr = ctx.Attr<float>("end_lr");
@@ -106,11 +103,10 @@ class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
     auto *step_data = step_out->data<int64_t>();
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
-    using AttrT = float;
+    using AttrT = double;
     Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
         lr_data, step_data, warmup_steps, total_steps,
-        static_cast<AttrT>(start_lr), static_cast<AttrT>(base_lr),
-        static_cast<AttrT>(end_lr));
+        static_cast<AttrT>(base_lr), static_cast<AttrT>(end_lr));
     for_range(functor);
   }
 };
...
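Two behavioral details of the rewritten functor are worth calling out: the step counter is now incremented before the learning rate is computed, so the first applied lr is base_lr / warmup_steps rather than start_lr, and AttrT is widened from float to double, so the scalar attributes participate in the arithmetic at double precision. Enumerating a few steps with the pow2_lr sketch above, using the values from the test at the bottom of this commit (warmup_steps=30, total_steps=100, base_lr=0.02, end_lr=0.001):

for step_num in (1, 30, 31, 99, 100):
    print(step_num, pow2_lr(step_num, 30, 100, 0.02, 0.001))
# step 1   -> 0.02 / 30 ~= 0.000667, the first lr ever applied
# step 30  -> 0.02, warmup ends exactly at base_lr (note the <= boundary)
# step 31  -> (0.02 - 0.001) * (69/70)**2 + 0.001, decay begins
# step 100 -> 0.001, clamped at end_lr from total_steps onward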
@@ -1936,18 +1936,18 @@ def fused_bn_add_act(x,
 def pow2_decay_with_linear_warmup(warmup_steps,
                                   total_steps,
-                                  start_lr,
                                   base_lr,
                                   end_lr,
                                   dtype='float32',
                                   name=None):
     if paddle.fluid.in_dygraph_mode():
         raise NotImplementedError(
-            "pow2_warmup does not support dygraph mode yet.")
+            "pow2_decay_with_linear_warmup does not support dygraph mode yet.")

     helper = LayerHelper("pow2_decay_with_linear_warmup", **locals())
     lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1])
-    helper.set_variable_initializer(lr, Constant(value=start_lr))
+    helper.set_variable_initializer(
+        lr, Constant(value=float(base_lr) / warmup_steps))

     step = helper.create_global_variable(
         persistable=True, dtype='int64', shape=[1])
@@ -1963,7 +1963,6 @@ def pow2_decay_with_linear_warmup(warmup_steps,
         attrs={
             "warmup_steps": warmup_steps,
             "total_steps": total_steps,
-            "start_lr": start_lr,
             "base_lr": base_lr,
             "end_lr": end_lr,
         })
...
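With start_lr removed, the Python API takes four schedule parameters. A usage sketch under the new signature (static graph only, per the NotImplementedError above; the import path follows the file being edited here, and CPUPlace is just an illustrative choice):

import paddle
from paddle.fluid.contrib.layers import pow2_decay_with_linear_warmup

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    # start_lr is gone: warmup now implicitly ramps from 0 to base_lr.
    lr = pow2_decay_with_linear_warmup(
        warmup_steps=30, total_steps=100, base_lr=0.02, end_lr=0.001)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
lr_value, = exe.run(main, fetch_list=[lr])
# The first fetched value is base_lr / warmup_steps, matching the new
# Constant initializer; each run advances the schedule by one step.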
@@ -19,13 +19,12 @@ from paddle.optimizer.lr import PolynomialDecay
 import unittest


-def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
-                          place):
+def gen_pow2_warmup_op_lr(warmup_steps, total_steps, base_lr, end_lr, place):
     main = paddle.static.Program()
     startup = paddle.static.Program()
     with paddle.static.program_guard(main, startup):
-        lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, start_lr,
-                                           base_lr, end_lr)
+        lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, base_lr,
+                                           end_lr)
     exe = paddle.static.Executor(place)
     with paddle.static.scope_guard(paddle.static.Scope()):
         exe.run(startup)
@@ -35,7 +34,7 @@ def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
 class Pow2Warmup(LinearWarmup):
-    def __init__(self, warmup_steps, total_steps, start_lr, base_lr, end_lr):
+    def __init__(self, warmup_steps, total_steps, base_lr, end_lr):
         assert total_steps > warmup_steps
         lr_sch = PolynomialDecay(
             learning_rate=base_lr,
@@ -46,13 +45,13 @@ class Pow2Warmup(LinearWarmup):
         super(Pow2Warmup, self).__init__(
             learning_rate=lr_sch,
             warmup_steps=warmup_steps,
-            start_lr=start_lr,
+            start_lr=0.0,
             end_lr=base_lr)


-def gen_pow2_warmup_py_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
-                          place):
-    lr_sch = Pow2Warmup(warmup_steps, total_steps, start_lr, base_lr, end_lr)
+def gen_pow2_warmup_py_lr(warmup_steps, total_steps, base_lr, end_lr, place):
+    lr_sch = Pow2Warmup(warmup_steps, total_steps, base_lr, end_lr)
+    lr_sch.step()
     while True:
         yield lr_sch()
         lr_sch.step()
@@ -64,7 +63,6 @@ class TestPow2WarmupLRScheduler(unittest.TestCase):
         self.params = {
             'warmup_steps': 30,
             'total_steps': 100,
-            'start_lr': 0.01,
             'base_lr': 0.02,
             'end_lr': 0.001,
         }
...
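The test's assertion loop sits in collapsed context above; a hedged sketch of how the op-based and Python-based generators could be compared (check_pow2 is hypothetical, the generator names come from the test file):

import itertools

def check_pow2(params, place):
    # After this commit both sequences should match step for step: both
    # start warmup at base_lr / warmup_steps (note the lr_sch.step() added
    # before the first yield) and both advance once per drawn value.
    op_lrs = gen_pow2_warmup_op_lr(place=place, **params)
    py_lrs = gen_pow2_warmup_py_lr(place=place, **params)
    for op_lr, py_lr in itertools.islice(
            zip(op_lrs, py_lrs), 2 * params['total_steps']):
        assert abs(float(op_lr) - float(py_lr)) < 1e-6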