diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..12362b1bc6401cdd9488cf2edc49afa581549603
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+class Pow2DecayWithLinearWarmupOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    auto dim = framework::make_ddim({1});
+    ctx->SetOutputDim("LearningRateOut", dim);
+    ctx->SetOutputDim("StepOut", dim);
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "LearningRate");
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+class Pow2DecayWithLinearWarmupOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("LearningRate", "(Tensor) The input learning rate Tensor.");
+    AddInput("Step", "(Tensor) The input global step Tensor.");
+    AddOutput("LearningRateOut",
+              "(Tensor) The output learning rate Tensor. Same as "
+              "Input(LearningRate).");
+    AddOutput("StepOut",
+              "(Tensor) The output step Tensor. Same as Input(Step).");
+    AddAttr<int64_t>("warmup_steps", "(int64_t) The warmup steps.");
+    AddAttr<int64_t>(
+        "total_steps",
+        "(int64_t) The total steps for changing the learning rate.");
+    AddAttr<float>("start_lr",
+                   "(float) The initial value of the learning rate.");
+    AddAttr<float>("base_lr",
+                   "(float) The learning rate value reached after warmup_steps.");
+    AddAttr<float>("end_lr",
+                   "(float) The final learning rate value after total_steps.");
+    AddComment(R"DOC(
+The Pow2DecayWithLinearWarmup learning rate scheduler.
+
+When step_num < warmup_steps, lr = (base_lr - start_lr) * step_num / warmup_steps + start_lr
+
+When warmup_steps <= step_num <= total_steps,
+  factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
+  lr = (base_lr - end_lr) * factor * factor + end_lr
+
+When step_num > total_steps, lr = end_lr
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup,
+                             ops::Pow2DecayWithLinearWarmupOp,
+                             ops::Pow2DecayWithLinearWarmupOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    pow2_decay_with_linear_warmup,
+    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, float>,
+    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, double>);
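For reference, the three-branch schedule spelled out in the DOC comment above can be sketched in a few lines of Python. The helper name `pow2_warmup_lr` is hypothetical and not part of this patch; `step_num` is the 0-based step value the functor reads before incrementing:

```python
def pow2_warmup_lr(step_num, warmup_steps, total_steps,
                   start_lr, base_lr, end_lr):
    # Linear warmup: ramp from start_lr up to base_lr over warmup_steps.
    if step_num < warmup_steps:
        return (base_lr - start_lr) * step_num / warmup_steps + start_lr
    # Pow2 decay: quadratic slide from base_lr down to end_lr.
    if step_num <= total_steps:
        factor = 1.0 - (step_num - warmup_steps) / (total_steps - warmup_steps)
        return (base_lr - end_lr) * factor * factor + end_lr
    # Past total_steps: hold at end_lr.
    return end_lr
```

With the unit test's parameters (warmup_steps=30, total_steps=100, start_lr=0.01, base_lr=0.02, end_lr=0.001), this yields 0.01 at step 0, 0.02 at step 30, and 0.001 from step 100 onward.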
diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6695778dbac063fdfdaed0afb4cad5ae74b112f9
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu
@@ -0,0 +1,24 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(
+    pow2_decay_with_linear_warmup,
+    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, float>,
+    ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..41e07b0343e728280639cc7d44ab0ea1cf893357
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
@@ -0,0 +1,119 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, typename AttrT>
+struct Pow2DecayWithLinearWarmupFunctor {
+  template <typename U>
+  using RestrictPtr = U *PADDLE_RESTRICT;
+
+ public:
+  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(
+      RestrictPtr<T> lr, RestrictPtr<int64_t> step, size_t warmup_steps,
+      size_t total_steps, AttrT start_lr, AttrT base_lr, AttrT end_lr)
+      : lr_(lr),
+        step_(step),
+        warmup_steps_(warmup_steps),
+        total_steps_(total_steps),
+        start_lr_(start_lr),
+        base_lr_(base_lr),
+        end_lr_(end_lr) {}
+
+  HOSTDEVICE void operator()(size_t) const {
+    // Read the current step, advance it in place, and compute the new
+    // learning rate from the pre-increment value.
+    size_t step = static_cast<size_t>(*step_);
+    *step_ = static_cast<int64_t>(step + 1);
+    if (step < warmup_steps_) {
+      auto new_lr =
+          static_cast<double>(base_lr_ - start_lr_) * step / warmup_steps_ +
+          start_lr_;
+      *lr_ = static_cast<T>(new_lr);
+    } else if (step < total_steps_) {
+      auto factor = 1 -
+                    static_cast<double>(step - warmup_steps_) /
+                        (total_steps_ - warmup_steps_);
+      auto new_lr =
+          static_cast<double>(base_lr_ - end_lr_) * factor * factor + end_lr_;
+      *lr_ = static_cast<T>(new_lr);
+    } else {
+      *lr_ = static_cast<T>(end_lr_);
+    }
+  }
+
+ private:
+  RestrictPtr<T> lr_;
+  RestrictPtr<int64_t> step_;
+  size_t warmup_steps_;
+  size_t total_steps_;
+  AttrT start_lr_;
+  AttrT base_lr_;
+  AttrT end_lr_;
+};
+
+template <typename DeviceContext, typename T>
+class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *lr = ctx.Input<framework::Tensor>("LearningRate");
+    const auto *step = ctx.Input<framework::Tensor>("Step");
+    auto *lr_out = ctx.Output<framework::Tensor>("LearningRateOut");
+    auto *step_out = ctx.Output<framework::Tensor>("StepOut");
+    PADDLE_ENFORCE_EQ(
+        lr, lr_out,
+        platform::errors::InvalidArgument("Input(LearningRate) and "
+                                          "Output(LearningRateOut) "
+                                          "must be the same."));
+    PADDLE_ENFORCE_NOT_NULL(lr,
+                            platform::errors::InvalidArgument(
+                                "Input(LearningRate) should not be nullptr."));
+    PADDLE_ENFORCE_EQ(step, step_out,
+                      platform::errors::InvalidArgument(
+                          "Input(Step) and Output(StepOut) must be the same."));
+    PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument(
+                                      "Input(Step) should not be nullptr."));
+    PADDLE_ENFORCE_EQ(
+        step->IsInitialized(), true,
+        platform::errors::InvalidArgument("Input(Step) must be initialized."));
+
+    auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
+    auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
+    PADDLE_ENFORCE_LE(warmup_steps, total_steps,
+                      platform::errors::InvalidArgument(
+                          "warmup_steps must not be larger than total_steps."));
+    auto start_lr = ctx.Attr<float>("start_lr");
+    auto base_lr = ctx.Attr<float>("base_lr");
+    auto end_lr = ctx.Attr<float>("end_lr");
+
+    auto *lr_data = lr_out->data<T>();
+    auto *step_data = step_out->data<int64_t>();
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
+    using AttrT = float;
+    Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
+        lr_data, step_data, warmup_steps, total_steps,
+        static_cast<AttrT>(start_lr), static_cast<AttrT>(base_lr),
+        static_cast<AttrT>(end_lr));
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
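The functor above is launched over a single element via `platform::ForRange`, so each execution of the op performs one stateful update: read the current step, write back `step + 1`, and compute the new learning rate from the pre-increment value. That is also why the kernel enforces that `LearningRate`/`LearningRateOut` and `Step`/`StepOut` alias the same tensors. A rough Python sketch of these semantics, reusing the hypothetical `pow2_warmup_lr` helper from the earlier sketch:

```python
# Mirrors the persistable LearningRate and Step tensors the op mutates in place.
state = {"step": 0, "lr": 0.01}

def run_pow2_decay_op(state, **sched_params):
    step_num = state["step"]        # read the pre-increment step
    state["step"] = step_num + 1    # StepOut = Step + 1, written in place
    state["lr"] = pow2_warmup_lr(step_num, **sched_params)
    return state["lr"]
```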
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 99ede353c1081e11e36919ddebadb2730a560dc3..0d0addb17e9ae6846b1c6bf34cae9303f6ab794d 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -1932,3 +1932,39 @@ def fused_bn_add_act(x,
         attrs=attrs)
 
     return batch_norm_out
+
+
+def pow2_decay_with_linear_warmup(warmup_steps,
+                                  total_steps,
+                                  start_lr,
+                                  base_lr,
+                                  end_lr,
+                                  dtype='float32',
+                                  name=None):
+    """
+    Create a learning rate variable updated in place by the
+    pow2_decay_with_linear_warmup operator (see the operator's DOC comment
+    for the schedule formula).
+    """
+    assert warmup_steps <= total_steps, "warmup_steps cannot be larger than total_steps"
+    if paddle.fluid.in_dygraph_mode():
+        raise NotImplementedError(
+            "pow2_decay_with_linear_warmup does not support dygraph mode yet.")
+
+    helper = LayerHelper("pow2_decay_with_linear_warmup", **locals())
+    lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1])
+    helper.set_variable_initializer(lr, Constant(value=start_lr))
+
+    step = helper.create_global_variable(
+        persistable=True, dtype='int64', shape=[1])
+    helper.set_variable_initializer(step, Constant(value=0))
+
+    helper.append_op(
+        type="pow2_decay_with_linear_warmup",
+        inputs={"LearningRate": lr,
+                "Step": step},
+        outputs={"LearningRateOut": lr,
+                 "StepOut": step},
+        attrs={
+            "warmup_steps": warmup_steps,
+            "total_steps": total_steps,
+            "start_lr": start_lr,
+            "base_lr": base_lr,
+            "end_lr": end_lr,
+        })
+    return lr
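A sketch of how this contrib layer might be driven from a static-graph program; the executor loop below is illustrative only and not part of the patch:

```python
import paddle
from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    # lr is a persistable shape-[1] variable; every run of `main`
    # advances the step and rewrites lr in place.
    lr = pow2_decay_with_linear_warmup(
        warmup_steps=30, total_steps=100,
        start_lr=0.01, base_lr=0.02, end_lr=0.001)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
for _ in range(3):
    print(exe.run(main, fetch_list=[lr])[0][0])  # 0.01, then ramping toward 0.02
```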
diff --git a/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py b/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..641ea3eccf8d2b259d78afd87779cce2cbf9639a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_pow2_decay_with_linear_warmup_op.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup
+from paddle.optimizer.lr import LinearWarmup
+from paddle.optimizer.lr import PolynomialDecay
+
+
+def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
+                          place):
+    main = paddle.static.Program()
+    startup = paddle.static.Program()
+    with paddle.static.program_guard(main, startup):
+        lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, start_lr,
+                                           base_lr, end_lr)
+    exe = paddle.static.Executor(place)
+    with paddle.static.scope_guard(paddle.static.Scope()):
+        exe.run(startup)
+        while True:
+            lr_np = exe.run(main, fetch_list=[lr])[0]
+            yield lr_np[0]
+
+
+class Pow2Warmup(LinearWarmup):
+    def __init__(self, warmup_steps, total_steps, start_lr, base_lr, end_lr):
+        assert total_steps > warmup_steps
+        lr_sch = PolynomialDecay(
+            learning_rate=base_lr,
+            decay_steps=total_steps - warmup_steps,
+            end_lr=end_lr,
+            power=2)
+
+        super(Pow2Warmup, self).__init__(
+            learning_rate=lr_sch,
+            warmup_steps=warmup_steps,
+            start_lr=start_lr,
+            end_lr=base_lr)
+
+
+def gen_pow2_warmup_py_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
+                          place):
+    lr_sch = Pow2Warmup(warmup_steps, total_steps, start_lr, base_lr, end_lr)
+    while True:
+        yield lr_sch()
+        lr_sch.step()
+
+
+class TestPow2WarmupLRScheduler(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+        self.params = {
+            'warmup_steps': 30,
+            'total_steps': 100,
+            'start_lr': 0.01,
+            'base_lr': 0.02,
+            'end_lr': 0.001,
+        }
+        self.step_num = 1000
+
+    def check_with_place(self, place):
+        kwargs = dict(self.params)
+        kwargs['place'] = place
+        lr_sch_op = gen_pow2_warmup_op_lr(**kwargs)
+        lr_sch_py = gen_pow2_warmup_py_lr(**kwargs)
+        for i, (lr_op, lr_py) in enumerate(zip(lr_sch_op, lr_sch_py)):
+            self.assertLess(abs(lr_op - lr_py), 1e-6)
+            if i > self.step_num:
+                break
+
+    def test_main(self):
+        self.check_with_place(paddle.CPUPlace())
+        if paddle.is_compiled_with_cuda():
+            self.check_with_place(paddle.CUDAPlace(0))
+
+
+if __name__ == "__main__":
+    unittest.main()
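The test's reference implementation relies on the identity that pow2 decay is `PolynomialDecay` with `power=2`, stacked under a `LinearWarmup` whose ramp ends at `base_lr`. A quick standalone sanity check at the two boundary steps, again using the hypothetical `pow2_warmup_lr` helper sketched earlier:

```python
params = dict(warmup_steps=30, total_steps=100,
              start_lr=0.01, base_lr=0.02, end_lr=0.001)

# At step_num == warmup_steps the ramp has just finished: factor == 1,
# so the decay branch returns (base_lr - end_lr) + end_lr == base_lr.
assert abs(pow2_warmup_lr(30, **params) - 0.02) < 1e-12
# At step_num == total_steps the decay bottoms out: factor == 0 -> end_lr.
assert abs(pow2_warmup_lr(100, **params) - 0.001) < 1e-12
```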