Unverified commit 305b99a0 authored by Zeng Jinle, committed by GitHub

Add pow2_decay_with_linear_warmup op (#36421)

* add pow2_warmup op

* remove contrib __all__

* add AttrT

* rename

* follow comments

* fix duplicate PADDLE_RESTRICT
Parent 10f0a0f6
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
class Pow2DecayWithLinearWarmupOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
auto dim = framework::make_ddim({1});
ctx->SetOutputDim("LearningRateOut", dim);
ctx->SetOutputDim("StepOut", dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "LearningRate");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class Pow2DecayWithLinearWarmupOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("LearningRate", "(Tensor) The input learning rate Tensor.");
AddInput("Step", "(Tensor) The input global step Tensor.");
AddOutput("LearningRateOut",
"(Tensor) The output learning rate Tensor. Same with "
"Input(LearningRate).");
AddOutput(
"StepOut",
"(Tensor) The output learning rate Tensor. Same with Input(Step).");
AddAttr<int64_t>("warmup_steps", "(int64_t) The warmup steps.");
AddAttr<int64_t>(
"total_steps",
"(int64_t) The total steps for changing the learning rate.");
AddAttr<float>("start_lr",
"(float) The initial value of the learning rate.");
AddAttr<float>("base_lr",
"(float) The final learning rate value after warmup.");
AddAttr<float>("end_lr",
"(float) The final learning rate value after total_steps.");
AddComment(R"DOC(
The Pow2DecayWithLinearWarmup learning rate scheduler.
When step_num < warmup_steps, lr = (base_lr - start_lr) * step_num / warmup_steps + start_lr
When warmup_steps <= step_num <= total_steps,
factor = 1 - (step_num - warmup_steps) / (total_steps - warmup_steps)
lr = (base_lr - end_lr) * factor * factor + end_lr
When step_num > total_steps, lr = end_lr
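A worked example, using the same illustrative values as the unit test
(warmup_steps = 30, total_steps = 100, start_lr = 0.01, base_lr = 0.02, end_lr = 0.001):
  step_num = 15:  lr = (0.02 - 0.01) * 15 / 30 + 0.01 = 0.015
  step_num = 65:  factor = 1 - 35 / 70 = 0.5, lr = (0.02 - 0.001) * 0.5 * 0.5 + 0.001 = 0.00575
  step_num = 150: lr = 0.001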
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOp,
ops::Pow2DecayWithLinearWarmupOpMaker);
REGISTER_OP_CPU_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CPUDeviceContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<plat::CUDADeviceContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
template <typename T, typename AttrT>
struct Pow2DecayWithLinearWarmupFunctor {
template <typename U>
using RestrictPtr = U *PADDLE_RESTRICT;
public:
HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(
RestrictPtr<T> lr, RestrictPtr<int64_t> step, size_t warmup_steps,
size_t total_steps, AttrT start_lr, AttrT base_lr, AttrT end_lr)
: lr_(lr),
step_(step),
warmup_steps_(warmup_steps),
total_steps_(total_steps),
start_lr_(start_lr),
base_lr_(base_lr),
end_lr_(end_lr) {}
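// Invoked once per op run: reads the current step count, writes back
// step + 1, and computes the learning rate for the pre-increment step
// following the warmup / pow2-decay schedule documented in the op maker.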
HOSTDEVICE void operator()(size_t) const {
size_t step = static_cast<size_t>(*step_);
*step_ = static_cast<int64_t>(step + 1);
if (step < warmup_steps_) {
auto new_lr =
static_cast<double>(base_lr_ - start_lr_) * step / warmup_steps_ +
start_lr_;
*lr_ = static_cast<T>(new_lr);
} else if (step < total_steps_) {
auto factor = 1 -
static_cast<double>(step - warmup_steps_) /
(total_steps_ - warmup_steps_);
auto new_lr =
static_cast<double>(base_lr_ - end_lr_) * factor * factor + end_lr_;
*lr_ = static_cast<T>(new_lr);
} else {
*lr_ = static_cast<T>(end_lr_);
}
}
private:
RestrictPtr<T> lr_;
RestrictPtr<int64_t> step_;
size_t warmup_steps_;
size_t total_steps_;
AttrT start_lr_;
AttrT base_lr_;
AttrT end_lr_;
};
template <typename DeviceContext, typename T>
class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *lr = ctx.Input<framework::Tensor>("LearningRate");
const auto *step = ctx.Input<framework::Tensor>("Step");
auto *lr_out = ctx.Output<framework::Tensor>("LearningRateOut");
auto *step_out = ctx.Output<framework::Tensor>("StepOut");
PADDLE_ENFORCE_EQ(
lr, lr_out, platform::errors::InvalidArgument("Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."));
PADDLE_ENFORCE_NOT_NULL(lr,
platform::errors::InvalidArgument(
"Input(LearingRate) should not be nullptr."));
PADDLE_ENFORCE_EQ(step, step_out,
platform::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_NOT_NULL(step, platform::errors::InvalidArgument(
"Input(Step) should not be nullptr."));
PADDLE_ENFORCE_EQ(
step->IsInitialized(), true,
platform::errors::InvalidArgument("Input(Step) must be initialized."));
auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
PADDLE_ENFORCE_LE(warmup_steps, total_steps,
platform::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));
auto start_lr = ctx.Attr<float>("start_lr");
auto base_lr = ctx.Attr<float>("base_lr");
auto end_lr = ctx.Attr<float>("end_lr");
auto *lr_data = lr_out->data<T>();
auto *step_data = step_out->data<int64_t>();
auto &dev_ctx = ctx.template device_context<DeviceContext>();
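// Launch the functor over a range of length 1, i.e. run it exactly once on
// the op's device; the step counter and learning rate are updated in place.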
platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
using AttrT = float;
Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
lr_data, step_data, warmup_steps, total_steps,
static_cast<AttrT>(start_lr), static_cast<AttrT>(base_lr),
static_cast<AttrT>(end_lr));
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
@@ -1932,3 +1932,39 @@ def fused_bn_add_act(x,
attrs=attrs)
return batch_norm_out
def pow2_decay_with_linear_warmup(warmup_steps,
total_steps,
start_lr,
base_lr,
end_lr,
dtype='float32',
name=None):
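    """
    Applies the pow2-decay-with-linear-warmup learning rate schedule
    (static graph mode only; dygraph is not supported yet).

    The returned learning rate variable is updated in place by the
    pow2_decay_with_linear_warmup op every time the program runs: it warms up
    linearly from start_lr to base_lr during the first warmup_steps steps,
    decays quadratically from base_lr to end_lr until total_steps, and stays
    at end_lr afterwards.

    Args:
        warmup_steps (int): number of linear warmup steps.
        total_steps (int): total number of steps of the schedule; must be no
            less than warmup_steps.
        start_lr (float): learning rate at step 0.
        base_lr (float): learning rate at the end of warmup.
        end_lr (float): learning rate after total_steps.
        dtype (str): data type of the learning rate variable. Default: 'float32'.
        name (str, optional): name prefix for the created variables. Default: None.

    Returns:
        Variable: a persistable learning rate tensor of shape [1].
    """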
if paddle.fluid.in_dygraph_mode():
raise NotImplementedError(
"pow2_warmup does not support dygraph mode yet.")
helper = LayerHelper("pow2_decay_with_linear_warmup", **locals())
lr = helper.create_global_variable(persistable=True, dtype=dtype, shape=[1])
helper.set_variable_initializer(lr, Constant(value=start_lr))
step = helper.create_global_variable(
persistable=True, dtype='int64', shape=[1])
helper.set_variable_initializer(step, Constant(value=0))
assert warmup_steps <= total_steps, "warmup_steps cannot be larger than total_steps"
helper.append_op(
type="pow2_decay_with_linear_warmup",
inputs={"LearningRate": lr,
"Step": step},
outputs={"LearningRateOut": lr,
"StepOut": step},
attrs={
"warmup_steps": warmup_steps,
"total_steps": total_steps,
"start_lr": start_lr,
"base_lr": base_lr,
"end_lr": end_lr,
})
return lr
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.contrib.layers.nn import pow2_decay_with_linear_warmup
from paddle.optimizer.lr import LinearWarmup
from paddle.optimizer.lr import PolynomialDecay
import unittest
def gen_pow2_warmup_op_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
place):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
lr = pow2_decay_with_linear_warmup(warmup_steps, total_steps, start_lr,
base_lr, end_lr)
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(paddle.static.Scope()):
exe.run(startup)
while True:
lr_np = exe.run(main, fetch_list=[lr])[0]
yield lr_np[0]
class Pow2Warmup(LinearWarmup):
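    # Reference schedule built from the built-in LR schedulers: linear warmup
    # from start_lr to base_lr over warmup_steps, then power-2 polynomial decay
    # from base_lr to end_lr over the remaining (total_steps - warmup_steps)
    # steps. Its output should match the op's learning rate step by step.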
def __init__(self, warmup_steps, total_steps, start_lr, base_lr, end_lr):
assert total_steps > warmup_steps
lr_sch = PolynomialDecay(
learning_rate=base_lr,
decay_steps=total_steps - warmup_steps,
end_lr=end_lr,
power=2)
super(Pow2Warmup, self).__init__(
learning_rate=lr_sch,
warmup_steps=warmup_steps,
start_lr=start_lr,
end_lr=base_lr)
def gen_pow2_warmup_py_lr(warmup_steps, total_steps, start_lr, base_lr, end_lr,
place):
lr_sch = Pow2Warmup(warmup_steps, total_steps, start_lr, base_lr, end_lr)
while True:
yield lr_sch()
lr_sch.step()
class TestPow2WarmupLRScheduler(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.params = {
'warmup_steps': 30,
'total_steps': 100,
'start_lr': 0.01,
'base_lr': 0.02,
'end_lr': 0.001,
}
self.step_num = 1000
def check_with_place(self, place):
kwargs = dict(self.params)
kwargs['place'] = place
lr_sch_op = gen_pow2_warmup_op_lr(**kwargs)
lr_sch_py = gen_pow2_warmup_py_lr(**kwargs)
for i, (lr_op, lr_py) in enumerate(zip(lr_sch_op, lr_sch_py)):
self.assertLess(abs(lr_op - lr_py), 1e-6)
if i > self.step_num:
break
def test_main(self):
self.check_with_place(paddle.CPUPlace())
if paddle.is_compiled_with_cuda():
self.check_with_place(paddle.CUDAPlace(0))
if __name__ == "__main__":
unittest.main()