Unverified commit 605e7f08 authored by Zeng Jinle, committed by GitHub

fix pow2 decay (#36559)

Parent 7325c9fb
......@@ -54,8 +54,6 @@ class Pow2DecayWithLinearWarmupOpMaker
AddAttr<int64_t>(
"total_steps",
"(int64_t) The total steps for changing the learning rate.");
AddAttr<float>("start_lr",
"(float) The initial value of the learning rate.");
AddAttr<float>("base_lr",
"(float) The final learning rate value after warmup.");
AddAttr<float>("end_lr",
......@@ -63,7 +61,7 @@ class Pow2DecayWithLinearWarmupOpMaker
AddComment(R"DOC(
The Pow2DecayWithLinearWarmup learning rate scheduler.
When step_num < warmup_steps, lr = (base_lr - start_lr) * step_num / warmup_steps + start_lr
When step_num < warmup_steps, lr = base_lr * step_num / warmup_steps
When warmup_steps <= step_num <= total_steps,
factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
......
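The DOC string above describes the schedule after start_lr is removed: the warmup segment now ramps linearly from 0 to base_lr, and the decay segment falls off as a squared polynomial toward end_lr. A minimal Python sketch of that formula (an illustrative reimplementation of the documented behavior, not the operator code itself; the boundary handling follows the DOC string):

def pow2_lr(step_num, warmup_steps, total_steps, base_lr, end_lr):
    # Linear warmup: 0 -> base_lr over warmup_steps.
    if step_num < warmup_steps:
        return base_lr * step_num / warmup_steps
    # Squared polynomial decay: base_lr -> end_lr over the remaining steps.
    if step_num <= total_steps:
        factor = 1.0 - (step_num - warmup_steps) / (total_steps - warmup_steps)
        return (base_lr - end_lr) * factor * factor + end_lr
    # After total_steps the learning rate stays at end_lr.
    return end_lr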
......@@ -28,31 +28,30 @@ struct Pow2DecayWithLinearWarmupFunctor {
using RestrictPtr = U *PADDLE_RESTRICT;
public:
HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(
RestrictPtr<T> lr, RestrictPtr<int64_t> step, size_t warmup_steps,
size_t total_steps, AttrT start_lr, AttrT base_lr, AttrT end_lr)
HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
RestrictPtr<int64_t> step,
size_t warmup_steps,
size_t total_steps, AttrT base_lr,
AttrT end_lr)
: lr_(lr),
step_(step),
warmup_steps_(warmup_steps),
total_steps_(total_steps),
start_lr_(start_lr),
base_lr_(base_lr),
end_lr_(end_lr) {}
HOSTDEVICE void operator()(size_t) const {
size_t step = static_cast<size_t>(*step_);
*step_ = static_cast<int64_t>(step + 1);
if (step < warmup_steps_) {
auto new_lr =
static_cast<double>(base_lr_ - start_lr_) * step / warmup_steps_ +
start_lr_;
size_t step = static_cast<size_t>(*step_) + 1;
*step_ = static_cast<int64_t>(step);
if (step <= warmup_steps_) {
auto new_lr = static_cast<double>(step) / warmup_steps_ * base_lr_;
*lr_ = static_cast<T>(new_lr);
} else if (step < total_steps_) {
auto factor = 1 -
static_cast<double>(step - warmup_steps_) /
(total_steps_ - warmup_steps_);
auto new_lr =
static_cast<double>(base_lr_ - end_lr_) * factor * factor + end_lr_;
static_cast<double>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
*lr_ = static_cast<T>(new_lr);
} else {
*lr_ = static_cast<T>(end_lr_);
......@@ -64,7 +63,6 @@ struct Pow2DecayWithLinearWarmupFunctor {
RestrictPtr<int64_t> step_;
size_t warmup_steps_;
size_t total_steps_;
AttrT start_lr_;
AttrT base_lr_;
AttrT end_lr_;
};
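For reference, a small Python simulation of the rewritten functor's per-call behavior, assuming the step counter starts at 0 as the Python helper below initializes it (the actual computation runs in the HOSTDEVICE functor above):

def simulate_functor(num_calls, warmup_steps, total_steps, base_lr, end_lr):
    step = 0
    lrs = []
    for _ in range(num_calls):
        step += 1  # *step_ is incremented before the range checks
        if step <= warmup_steps:
            lr = float(step) / warmup_steps * base_lr
        elif step < total_steps:
            factor = 1.0 - float(step - warmup_steps) / (total_steps - warmup_steps)
            lr = (base_lr - end_lr) * (factor * factor) + end_lr
        else:
            lr = end_lr
        lrs.append(lr)
    return lrs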
......@@ -98,7 +96,6 @@ class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(warmup_steps, total_steps,
platform::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));
auto start_lr = ctx.Attr<float>("start_lr");
auto base_lr = ctx.Attr<float>("base_lr");
auto end_lr = ctx.Attr<float>("end_lr");
......@@ -106,11 +103,10 @@ class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
auto *step_data = step_out->data<int64_t>();
auto &dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
using AttrT = float;
using AttrT = double;
Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
lr_data, step_data, warmup_steps, total_steps,
static_cast<AttrT>(start_lr), static_cast<AttrT>(base_lr),
static_cast<AttrT>(end_lr));
static_cast<AttrT>(base_lr), static_cast<AttrT>(end_lr));
for_range(functor);
}
};
......
......@@ -1936,18 +1936,18 @@ def fused_bn_add_act(x,
def pow2_decay_with_linear_warmup(warmup_steps,
total_steps,
start_lr,
base_lr,
end_lr,
dtype='float32',
name=None):
if paddle.fluid.in_dygraph_mode():
raise NotImplementedError(
"pow2_warmup does not support dygraph mode yet.")
"pow2_decay_with_linear_warmup does not support dygraph mode yet.")
helper = LayerHelper("pow2_decay_with_linear_warmup", **locals())
lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1])
helper.set_variable_initializer(lr, Constant(value=start_lr))
helper.set_variable_initializer(
lr, Constant(value=float(base_lr) / warmup_steps))
step = helper.create_global_variable(
persistable=True, dtype='int64', shape=[1])
......@@ -1963,7 +1963,6 @@ def pow2_decay_with_linear_warmup(warmup_steps,
attrs={
"warmup_steps": warmup_steps,
"total_steps": total_steps,
"start_lr": start_lr,
"base_lr": base_lr,
"end_lr": end_lr,
})
......
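A usage sketch of the updated Python helper after the signature change (no start_lr; the LR variable is now initialized to base_lr / warmup_steps). The import path and parameter values here are assumptions for illustration:

import paddle
from paddle.fluid.contrib.layers import pow2_decay_with_linear_warmup  # path assumed

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    lr = pow2_decay_with_linear_warmup(
        warmup_steps=30, total_steps=100, base_lr=0.02, end_lr=0.001)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
lr_value, = exe.run(main, fetch_list=[lr])  # each run advances the step counter once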
......@@ -19,13 +19,12 @@ from paddle.optimizer.lr import PolynomialDecay
import unittest
def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
place):
def gen_pow2_warmup_op_lr(warmup_steps, total_steps, base_lr, end_lr, place):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, start_lr,
base_lr, end_lr)
lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, base_lr,
end_lr)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
......@@ -35,7 +34,7 @@ def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
class Pow2Warmup(LinearWarmup):
def __init__(self, warmup_steps, total_steps, start_lr, base_lr, end_lr):
def __init__(self, warmup_steps, total_steps, base_lr, end_lr):
assert total_steps > warmup_steps
lr_sch = PolynomialDecay(
learning_rate=base_lr,
......@@ -46,13 +45,13 @@ class Pow2Warmup(LinearWarmup):
super(Pow2Warmup, self).__init__(
learning_rate=lr_sch,
warmup_steps=warmup_steps,
start_lr=start_lr,
start_lr=0.0,
end_lr=base_lr)
def gen_pow2_warmup_py_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
place):
lr_sch = Pow2Warmup(warmup_steps, total_steps, start_lr, base_lr, end_lr)
def gen_pow2_warmup_py_lr(warmup_steps, total_steps, base_lr, end_lr, place):
lr_sch = Pow2Warmup(warmup_steps, total_steps, base_lr, end_lr)
lr_sch.step()
while True:
yield lr_sch()
lr_sch.step()
......@@ -64,7 +63,6 @@ class TestPow2WarmupLRScheduler(unittest.TestCase):
self.params = {
'warmup_steps': 30,
'total_steps': 100,
'start_lr': 0.01,
'base_lr': 0.02,
'end_lr': 0.001,
}
......
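With the start_lr key dropped from self.params, the op-based and Python-based generators take the same arguments and can be compared directly. A sketch of such a check, assuming the two generator functions defined above (the tolerance is illustrative):

def check_pow2_warmup(params, place, num_steps=120):
    op_lrs = gen_pow2_warmup_op_lr(place=place, **params)
    py_lrs = gen_pow2_warmup_py_lr(place=place, **params)
    for _ in range(num_steps):
        assert abs(next(op_lrs) - next(py_lrs)) < 1e-6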