From 348565b04bd51a5af8ebebb39b2fbe75135041da Mon Sep 17 00:00:00 2001
From: huangjiyi <43315610+huangjiyi@users.noreply.github.com>
Date: Fri, 12 May 2023 11:40:06 +0800
Subject: [PATCH] move pow2_decay_with_linear_warmup kernel to phi (#53741)

* update

* update
---
 .../pow2_decay_with_linear_warmup_op.cc       |   7 -
 .../pow2_decay_with_linear_warmup_op.h        | 125 ------------------
 .../pow2_decay_with_linear_warmup_op_xpu.cc   |  84 ------------
 .../pow2_decay_with_linear_warmup_kernel.cc}  |  19 +--
 .../pow2_decay_with_linear_warmup_kernel.cu   |  25 ++++
 ...ow2_decay_with_linear_warmup_kernel_impl.h | 110 +++++++++++++++
 .../pow2_decay_with_linear_warmup_kernel.h    |  31 +++++
 .../pow2_decay_with_linear_warmup_kernel.cc   |  71 ++++++++++
 .../pow2_decay_with_linear_warmup_sig.cc      |  30 +++++
 9 files changed, 277 insertions(+), 225 deletions(-)
 delete mode 100644 paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
 delete mode 100644 paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
 rename paddle/{fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu => phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc} (54%)
 create mode 100644 paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
 create mode 100644 paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h
 create mode 100644 paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
 create mode 100644 paddle/phi/ops/compat/pow2_decay_with_linear_warmup_sig.cc

diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
index 8def9c961f7..105c698355e 100644
--- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
+++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
@@ -12,8 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
-
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -78,12 +76,7 @@ When step_num > total_steps, lr = end_lr
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-namespace plat = paddle::platform;
 
 REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup,
                              ops::Pow2DecayWithLinearWarmupOp,
                              ops::Pow2DecayWithLinearWarmupOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    pow2_decay_with_linear_warmup,
-    ops::Pow2DecayWithLinearWarmupOpKernel<phi::CPUContext, float>,
-    ops::Pow2DecayWithLinearWarmupOpKernel<phi::CPUContext, double>);
diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
deleted file mode 100644
index 8f3be79cd4c..00000000000
--- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/core/macros.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T, typename AttrT>
-struct Pow2DecayWithLinearWarmupFunctor {
-  template <typename U>
-  using RestrictPtr = U *PADDLE_RESTRICT;
-
- public:
-  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
-                                              RestrictPtr<int64_t> step,
-                                              size_t warmup_steps,
-                                              size_t total_steps,
-                                              AttrT base_lr,
-                                              AttrT end_lr)
-      : lr_(lr),
-        step_(step),
-        warmup_steps_(warmup_steps),
-        total_steps_(total_steps),
-        base_lr_(base_lr),
-        end_lr_(end_lr) {}
-
-  HOSTDEVICE void operator()(size_t) const {
-    size_t step = static_cast<size_t>(*step_) + 1;
-    *step_ = static_cast<int64_t>(step);
-    if (step <= warmup_steps_) {
-      auto new_lr = static_cast<double>(step) / warmup_steps_ * base_lr_;
-      *lr_ = static_cast<T>(new_lr);
-    } else if (step < total_steps_) {
-      auto factor = 1 - static_cast<double>(step - warmup_steps_) /
-                            (total_steps_ - warmup_steps_);
-      auto new_lr =
-          static_cast<double>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
-      *lr_ = static_cast<T>(new_lr);
-    } else {
-      *lr_ = static_cast<T>(end_lr_);
-    }
-  }
-
- private:
-  RestrictPtr<T> lr_;
-  RestrictPtr<int64_t> step_;
-  size_t warmup_steps_;
-  size_t total_steps_;
-  AttrT base_lr_;
-  AttrT end_lr_;
-};
-
-template <typename DeviceContext, typename T>
-class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    const auto *lr = ctx.Input<phi::DenseTensor>("LearningRate");
-    const auto *step = ctx.Input<phi::DenseTensor>("Step");
-    auto *lr_out = ctx.Output<phi::DenseTensor>("LearningRateOut");
-    auto *step_out = ctx.Output<phi::DenseTensor>("StepOut");
-    PADDLE_ENFORCE_EQ(
-        lr,
-        lr_out,
-        platform::errors::InvalidArgument("Input(LearningRate) and "
-                                          "Output(LearningRateOut) "
-                                          "must be the same."));
-    PADDLE_ENFORCE_NOT_NULL(lr,
-                            platform::errors::InvalidArgument(
-                                "Input(LearingRate) should not be nullptr."));
-    PADDLE_ENFORCE_EQ(step,
-                      step_out,
-                      platform::errors::InvalidArgument(
-                          "Input(Step) and Output(StepOut) must be the same."));
-    PADDLE_ENFORCE_NOT_NULL(step,
-                            platform::errors::InvalidArgument(
-                                "Input(Step) should not be nullptr."));
-    PADDLE_ENFORCE_EQ(
-        step->IsInitialized(),
-        true,
-        platform::errors::InvalidArgument("Input(Step) must be initialized."));
-
-    auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
-    auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
-    PADDLE_ENFORCE_LE(warmup_steps,
-                      total_steps,
-                      platform::errors::InvalidArgument(
-                          "warmup_steps must not be larger than total_steps."));
-    auto base_lr = ctx.Attr<float>("base_lr");
-    auto end_lr = ctx.Attr<float>("end_lr");
-
-    auto *lr_data = lr_out->data<T>();
-    auto *step_data = step_out->data<int64_t>();
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-    platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
-    using AttrT = double;
-    Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
-        lr_data,
-        step_data,
-        warmup_steps,
-        total_steps,
-        static_cast<AttrT>(base_lr),
-        static_cast<AttrT>(end_lr));
-    for_range(functor);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
deleted file mode 100644
index 543a4634c6d..00000000000
--- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
-#include "paddle/fluid/platform/macros.h"
-#include "paddle/phi/backends/xpu/enforce_xpu.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class Pow2DecayWithLinearWarmupXPUOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    const auto *lr = ctx.Input<phi::DenseTensor>("LearningRate");
-    const auto *step = ctx.Input<phi::DenseTensor>("Step");
-    auto *lr_out = ctx.Output<phi::DenseTensor>("LearningRateOut");
-    auto *step_out = ctx.Output<phi::DenseTensor>("StepOut");
-    PADDLE_ENFORCE_EQ(
-        lr,
-        lr_out,
-        platform::errors::InvalidArgument("Input(LearningRate) and "
-                                          "Output(LearningRateOut) "
-                                          "must be the same."));
-    PADDLE_ENFORCE_NOT_NULL(lr,
-                            platform::errors::InvalidArgument(
-                                "Input(LearingRate) should not be nullptr."));
-    PADDLE_ENFORCE_EQ(step,
-                      step_out,
-                      platform::errors::InvalidArgument(
-                          "Input(Step) and Output(StepOut) must be the same."));
-    PADDLE_ENFORCE_NOT_NULL(step,
-                            platform::errors::InvalidArgument(
-                                "Input(Step) should not be nullptr."));
-    PADDLE_ENFORCE_EQ(
-        step->IsInitialized(),
-        true,
-        platform::errors::InvalidArgument("Input(Step) must be initialized."));
-
-    auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
-    auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
-    PADDLE_ENFORCE_LE(warmup_steps,
-                      total_steps,
-                      platform::errors::InvalidArgument(
-                          "warmup_steps must not be larger than total_steps."));
-    auto base_lr = ctx.Attr<float>("base_lr");
-    auto end_lr = ctx.Attr<float>("end_lr");
-
-    auto *lr_data = lr_out->data<T>();
-    auto *step_data = step_out->data<int64_t>();
-    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
-    int r = xpu::pow2_decay_with_linear_warmup(dev_ctx.x_context(),
-                                               lr_data,
-                                               step_data,
-                                               warmup_steps,
-                                               total_steps,
-                                               base_lr,
-                                               end_lr);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow2_decay_with_linear_warmup");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(pow2_decay_with_linear_warmup,
-                       ops::Pow2DecayWithLinearWarmupXPUOpKernel<float>);
-#endif
diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu b/paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc
similarity index 54%
rename from paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu
rename to paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc
index 6419e524f71..9bdf6bb2c86 100644
--- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu
+++ b/paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,13 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
 
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"
 
-REGISTER_OP_CUDA_KERNEL(
-    pow2_decay_with_linear_warmup,
-    ops::Pow2DecayWithLinearWarmupOpKernel<phi::GPUContext, float>,
-    ops::Pow2DecayWithLinearWarmupOpKernel<phi::GPUContext, double>);
+PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::Pow2DecayWithLinearWarmupKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu b/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
new file mode 100644
index 00000000000..57162bc7fb2
--- /dev/null
+++ b/paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
@@ -0,0 +1,25 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"
+
+PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::Pow2DecayWithLinearWarmupKernel,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h b/paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
new file mode 100644
index 00000000000..bbca911b404
--- /dev/null
+++ b/paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
@@ -0,0 +1,110 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/macros.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+
+namespace phi {
+
+template <typename T, typename AttrT>
+struct Pow2DecayWithLinearWarmupFunctor {
+  template <typename U>
+  using RestrictPtr = U* PADDLE_RESTRICT;
+
+ public:
+  HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
+                                              RestrictPtr<int64_t> step,
+                                              size_t warmup_steps,
+                                              size_t total_steps,
+                                              AttrT base_lr,
+                                              AttrT end_lr)
+      : lr_(lr),
+        step_(step),
+        warmup_steps_(warmup_steps),
+        total_steps_(total_steps),
+        base_lr_(base_lr),
+        end_lr_(end_lr) {}
+
+  HOSTDEVICE void operator()(size_t) const {
+    size_t step = static_cast<size_t>(*step_) + 1;
+    *step_ = static_cast<int64_t>(step);
+    if (step <= warmup_steps_) {
+      auto new_lr = static_cast<double>(step) / warmup_steps_ * base_lr_;
+      *lr_ = static_cast<T>(new_lr);
+    } else if (step < total_steps_) {
+      auto factor = 1 - static_cast<double>(step - warmup_steps_) /
+                            (total_steps_ - warmup_steps_);
+      auto new_lr =
+          static_cast<double>(base_lr_ - end_lr_) * (factor * factor) + end_lr_;
+      *lr_ = static_cast<T>(new_lr);
+    } else {
+      *lr_ = static_cast<T>(end_lr_);
+    }
+  }
+
+ private:
+  RestrictPtr<T> lr_;
+  RestrictPtr<int64_t> step_;
+  size_t warmup_steps_;
+  size_t total_steps_;
+  AttrT base_lr_;
+  AttrT end_lr_;
+};
+
+template <typename T, typename Context>
+void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
+                                     const DenseTensor& lr,
+                                     const DenseTensor& step,
+                                     int64_t warmup_steps,
+                                     int64_t total_steps,
+                                     float base_lr,
+                                     float end_lr,
+                                     DenseTensor* lr_out,
+                                     DenseTensor* step_out) {
+  PADDLE_ENFORCE_EQ(&lr,
+                    lr_out,
+                    phi::errors::InvalidArgument("Input(LearningRate) and "
+                                                 "Output(LearningRateOut) "
+                                                 "must be the same."));
+  PADDLE_ENFORCE_EQ(&step,
+                    step_out,
+                    phi::errors::InvalidArgument(
+                        "Input(Step) and Output(StepOut) must be the same."));
+  PADDLE_ENFORCE_EQ(
+      step.IsInitialized(),
+      true,
+      phi::errors::InvalidArgument("Input(Step) must be initialized."));
+
+  PADDLE_ENFORCE_LE(warmup_steps,
+                    total_steps,
+                    phi::errors::InvalidArgument(
+                        "warmup_steps must not be larger than total_steps."));
+
+  auto* lr_data = lr_out->data<T>();
+  auto* step_data = step_out->data<int64_t>();
+  phi::funcs::ForRange<Context> for_range(dev_ctx, 1);
+  using AttrT = double;
+  Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
+      lr_data,
+      step_data,
+      static_cast<size_t>(warmup_steps),
+      static_cast<size_t>(total_steps),
+      static_cast<AttrT>(base_lr),
+      static_cast<AttrT>(end_lr));
+  for_range(functor);
+}
+}  // namespace phi
diff --git a/paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h b/paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h
new file mode 100644
index 00000000000..549e8d0add1
--- /dev/null
+++ b/paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
+                                     const DenseTensor& lr,
+                                     const DenseTensor& step,
+                                     int64_t warmup_steps,
+                                     int64_t total_steps,
+                                     float base_lr,
+                                     float end_lr,
+                                     DenseTensor* lr_out,
+                                     DenseTensor* step_out);
+}  // namespace phi
diff --git a/paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc b/paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
new file mode 100644
index 00000000000..8661f2ac2cb
--- /dev/null
+++ b/paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/macros.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
+                                     const DenseTensor& lr,
+                                     const DenseTensor& step,
+                                     int64_t warmup_steps,
+                                     int64_t total_steps,
+                                     float base_lr,
+                                     float end_lr,
+                                     DenseTensor* lr_out,
+                                     DenseTensor* step_out) {
+  PADDLE_ENFORCE_EQ(&lr,
+                    lr_out,
+                    phi::errors::InvalidArgument("Input(LearningRate) and "
+                                                 "Output(LearningRateOut) "
+                                                 "must be the same."));
+  PADDLE_ENFORCE_EQ(&step,
+                    step_out,
+                    phi::errors::InvalidArgument(
+                        "Input(Step) and Output(StepOut) must be the same."));
+  PADDLE_ENFORCE_EQ(
+      step.IsInitialized(),
+      true,
+      phi::errors::InvalidArgument("Input(Step) must be initialized."));
+
+  PADDLE_ENFORCE_LE(warmup_steps,
+                    total_steps,
+                    phi::errors::InvalidArgument(
+                        "warmup_steps must not be larger than total_steps."));
+
+  auto* lr_data = lr_out->data<T>();
+  auto* step_data = step_out->data<int64_t>();
+  int r = xpu::pow2_decay_with_linear_warmup(dev_ctx.x_context(),
+                                             lr_data,
+                                             step_data,
+                                             static_cast<size_t>(warmup_steps),
+                                             static_cast<size_t>(total_steps),
+                                             base_lr,
+                                             end_lr);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow2_decay_with_linear_warmup");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::Pow2DecayWithLinearWarmupKernel,
+                   float) {}
diff --git a/paddle/phi/ops/compat/pow2_decay_with_linear_warmup_sig.cc b/paddle/phi/ops/compat/pow2_decay_with_linear_warmup_sig.cc
new file mode 100644
index 00000000000..9a3323d90bf
--- /dev/null
+++ b/paddle/phi/ops/compat/pow2_decay_with_linear_warmup_sig.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature Pow2DecayWithLinearWarmupOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("pow2_decay_with_linear_warmup",
+                         {"LearningRate", "Step"},
+                         {"warmup_steps", "total_steps", "base_lr", "end_lr"},
+                         {"LearningRateOut", "StepOut"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(pow2_decay_with_linear_warmup,
+                           phi::Pow2DecayWithLinearWarmupOpArgumentMapping);
--
GitLab
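
For reference, the learning-rate schedule computed by Pow2DecayWithLinearWarmupFunctor is unchanged by this move: the rate ramps linearly from 0 to base_lr over warmup_steps, decays quadratically ("pow2") from base_lr to end_lr between warmup_steps and total_steps, and stays at end_lr afterwards. The following is a minimal standalone C++ sketch of that arithmetic for sanity-checking the kernel; the free function and the driver in main() are illustrative only and are not part of this patch or of Paddle's API.

#include <cstddef>
#include <cstdio>

// Mirrors the branch structure of Pow2DecayWithLinearWarmupFunctor for a
// 1-based step counter: linear warmup, quadratic decay, then constant end_lr.
double Pow2DecayWithLinearWarmup(size_t step,
                                 size_t warmup_steps,
                                 size_t total_steps,
                                 double base_lr,
                                 double end_lr) {
  if (step <= warmup_steps) {
    // Linear warmup: 0 -> base_lr over warmup_steps.
    return static_cast<double>(step) / warmup_steps * base_lr;
  } else if (step < total_steps) {
    // Quadratic decay: factor falls from 1 to 0, so lr goes
    // base_lr -> end_lr along (base_lr - end_lr) * factor^2 + end_lr.
    double factor = 1.0 - static_cast<double>(step - warmup_steps) /
                              (total_steps - warmup_steps);
    return (base_lr - end_lr) * factor * factor + end_lr;
  }
  // Past total_steps the rate is pinned at end_lr.
  return end_lr;
}

int main() {
  // Hypothetical settings: 5 warmup steps of 20 total, decaying 0.1 -> 0.001.
  for (size_t step = 1; step <= 25; ++step) {
    std::printf("step %2zu  lr %.6f\n",
                step,
                Pow2DecayWithLinearWarmup(step, 5, 20, 0.1, 0.001));
  }
  return 0;
}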