From d15b490ad6ce251a1c3ef1386f73e7da824a807c Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Thu, 14 Jul 2022 10:43:44 +0800
Subject: [PATCH] [operator migration] Migrate merged momentum cpu/gpu kernels
 (#44300)

---
 .../optimizers/merged_momentum_op.cc          |   6 +-
 .../operators/optimizers/merged_momentum_op.h | 370 ----------------
 .../optimizers/merged_momentum_op_mlu.cc      |   8 +-
 .../optimizers/merged_momentum_op_npu.cc      |   7 +-
 .../pow2_decay_with_linear_warmup_op.h        |   2 +-
 paddle/fluid/platform/macros.h                |   6 -
 paddle/phi/core/macros.h                      |   6 +
 .../kernels/cpu/merged_momentum_kernel.cc}    |  20 +-
 .../phi/kernels/gpu/merged_momentum_kernel.cu |  25 ++
 .../phi/kernels/impl/merged_momentum_impl.h   | 400 ++++++++++++++++++
 paddle/phi/kernels/merged_momentum_kernel.h   |  42 ++
 paddle/phi/ops/compat/merged_momentum_sig.cc  |  40 ++
 12 files changed, 538 insertions(+), 394 deletions(-)
 delete mode 100644 paddle/fluid/operators/optimizers/merged_momentum_op.h
 rename paddle/{fluid/operators/optimizers/merged_momentum_op.cu => phi/kernels/cpu/merged_momentum_kernel.cc} (55%)
 create mode 100644 paddle/phi/kernels/gpu/merged_momentum_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/merged_momentum_impl.h
 create mode 100644 paddle/phi/kernels/merged_momentum_kernel.h
 create mode 100644 paddle/phi/ops/compat/merged_momentum_sig.cc

diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
index e6aec5cec9e..220c0be9ddf 100644
--- a/paddle/fluid/operators/optimizers/merged_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"

 namespace paddle {
 namespace operators {
@@ -103,7 +103,3 @@ namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
                              ops::MergedMomentumOp,
                              ops::MergedMomentumOpMaker);
-
-REGISTER_OP_CPU_KERNEL(
-    merged_momentum,
-    ops::MergedMomentumOpKernel<plat::CPUDeviceContext, float>,
-    ops::MergedMomentumOpKernel<plat::CPUDeviceContext, double>);
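Note: this hunk removes only the fluid CPU kernel registration; the operator definition itself stays in fluid, and the kernels are re-registered through phi later in this patch. The new-style registration, quoted from the hunks below, is:

    PD_REGISTER_KERNEL(merged_momentum,          // kernel name
                       CPU,                      // backend
                       ALL_LAYOUT,               // any tensor layout
                       phi::MergedMomentumKernel,
                       float,                    // one instantiation per dtype
                       double) {}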
diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.h b/paddle/fluid/operators/optimizers/merged_momentum_op.h
deleted file mode 100644
index 77c8f3dbd35..00000000000
--- a/paddle/fluid/operators/optimizers/merged_momentum_op.h
+++ /dev/null
@@ -1,370 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
-#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
-
-template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
-struct MergedMomentumMasterParams {
-  MT *PADDLE_RESTRICT master_params[kParamNum];
-
-  HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
-  HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
-};
-
-template <typename MT, uint32_t kParamNum>
-struct MergedMomentumMasterParams<MT, kParamNum, false> {
-  HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
-  HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
-};
-
-template <typename T,
-          typename MT,
-          bool kHasMasterParams,
-          uint32_t kParamNum = kHasMasterParams ? 55 : 110>
-struct MergedMomentumKernelParam
-    : public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
-  static constexpr auto N = kParamNum;
-  size_t sizes[N];
-  T *PADDLE_RESTRICT params[N];
-  const T *PADDLE_RESTRICT grads[N];
-  MT *PADDLE_RESTRICT velocitys[N];
-  const MultiPrecisionType<MT> *PADDLE_RESTRICT lr;
-  MT mu;
-  MT rescale_grad;
-  uint32_t param_num;
-
-  HOSTDEVICE void operator()(size_t i) const {
-    const MT lr_val = static_cast<MT>(*lr);
-    for (uint32_t idx = 0; idx < param_num; ++idx) {
-      auto size = sizes[idx];
-      if (i >= size) continue;
-
-      auto param_p = params[idx];
-      auto grad_p = grads[idx];
-      auto velocity_p = velocitys[idx];
-      auto master_param_p = this->MasterParam(idx);
-
-      const MT param =
-          master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
-      const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
-      const MT velocity = velocity_p[i];
-      const MT velocity_out = velocity * mu + grad;
-      const MT param_out = param - lr_val * velocity_out;
-      velocity_p[i] = velocity_out;
-      param_p[i] = static_cast<T>(param_out);
-      if (master_param_p) {
-        master_param_p[i] = param_out;
-      }
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class MergedMomentumOpKernel : public framework::OpKernel<T> {
-  using MPType = typename operators::details::MPTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const bool multi_precision = ctx.Attr<bool>("multi_precision");
-    if (multi_precision) {
-      InnerCompute<MPType>(ctx, multi_precision);
-    } else {
-      InnerCompute<T>(ctx, multi_precision);
-    }
-  }
-
- private:
-  template <typename MT>
-  void InnerCompute(const framework::ExecutionContext &ctx,
-                    const bool multi_precision) const {
-    auto params = ctx.MultiInput<framework::Tensor>("Param");
-    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
-    size_t n = params.size();
-    PADDLE_ENFORCE_EQ(n,
-                      params_out.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Output(ParamOut) must be equal to "
-                          "Input(Param), but got the size of Output(ParamOut) "
-                          "is %d, the size of Input(Param) is %d.",
-                          params_out.size(),
-                          n));
-    for (size_t i = 0; i < n; ++i) {
-      PADDLE_ENFORCE_EQ(params[i],
-                        params_out[i],
-                        platform::errors::InvalidArgument(
-                            "The size of Input(Param) and Output(ParamOut) "
-                            "must be the same Tensors."));
-    }
-
-    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
-    PADDLE_ENFORCE_EQ(
-        n,
-        grads.size(),
-        platform::errors::InvalidArgument(
-            "The size of Input(Grad) must be equal to Input(Param), but got "
-            "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
-            grads.size(),
-            n));
-
-    auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
-    PADDLE_ENFORCE_EQ(n,
-                      velocitys.size(),
-                      platform::errors::InvalidArgument(
-                          "The size of Input(Velocity) must be equal to "
-                          "Input(Param), but got the size of Input(Velocity) "
-                          "is %d, the size of Input(Param) is %d.",
-                          velocitys.size(),
-                          n));
-
-    auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
-    PADDLE_ENFORCE_EQ(
-        n,
-        velocitys_out.size(),
-        platform::errors::InvalidArgument(
-            "The size of Output(VelocityOut) must be "
-            "equal to Input(Param), but got the size of Output(VelocityOut) is "
-            "%d, the size of Input(Param) is %d.",
-            velocitys_out.size(),
-            n));
-    for (size_t i = 0; i < n; ++i) {
-      PADDLE_ENFORCE_EQ(velocitys[i],
-                        velocitys_out[i],
-                        platform::errors::InvalidArgument(
-                            "Input(Velocity) and Output(VelocityOut) must be "
-                            "the same Tensors."));
-    }
-
-    auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
-    auto master_params_out =
-        ctx.MultiOutput<framework::Tensor>("MasterParamOut");
-    if (multi_precision) {
-      PADDLE_ENFORCE_EQ(
-          n,
-          master_params.size(),
-          platform::errors::InvalidArgument(
-              "The size of Input(MasterParam) must be "
-              "equal to Input(Param), but got the size of Input(MasterParam) "
-              "is %d, the size of Input(Param) is %d.",
-              master_params.size(),
-              n));
-      PADDLE_ENFORCE_EQ(
-          n,
-          master_params_out.size(),
-          platform::errors::InvalidArgument(
-              "The size of Output(MasterParamOut) must be equal to "
-              "Input(MasterParam), but got the size of Output(MasterParamOut) "
-              "is %d, the size of Input(Param) is %d.",
-              master_params_out.size(),
-              n));
-      for (size_t i = 0; i < n; ++i) {
-        PADDLE_ENFORCE_EQ(master_params[i],
-                          master_params_out[i],
-                          platform::errors::InvalidArgument(
-                              "Input(MasterParam) and Output(MasterParamOut) "
-                              "must be the same Tensors."));
-        PADDLE_ENFORCE_NOT_NULL(master_params[i],
-                                platform::errors::InvalidArgument(
-                                    "Input(MasterParam) must be provided when "
-                                    "multi_precision=True."));
-      }
-    } else {
-      master_params.clear();
-      master_params_out.clear();
-    }
-
-    auto mu = ctx.Attr<float>("mu");
-    auto rescale_grad = ctx.Attr<float>("rescale_grad");
-    auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
-    if (lrs.size() != 1) {
-      PADDLE_ENFORCE_EQ(
-          n,
-          lrs.size(),
-          platform::errors::InvalidArgument(
-              "If the size of Input(LearningRate) is not 1, the size of "
-              "Input(LearningRate) must be "
-              "equal to Input(Param), but got the size of Input(LearningRate) "
-              "is %d, the size of Input(Param) is %d.",
-              lrs.size(),
-              n));
-    }
-    auto use_nesterov = ctx.Attr<bool>("use_nesterov");
-    auto regularization_methods =
-        ctx.Attr<std::vector<std::string>>("regularization_method");
-    auto regularization_coeffs =
-        ctx.Attr<std::vector<float>>("regularization_coeff");
-    if (regularization_methods.size() != 0) {
-      PADDLE_ENFORCE_EQ(
-          n,
-          regularization_methods.size(),
-          platform::errors::InvalidArgument(
-              "The size of Attr(regularization_method) must be equal "
-              "to Input(Param), but got the size of "
-              "Attr(regularization_method) is %d, the size of Input(Param) is "
-              "%d.",
-              regularization_methods.size(),
-              n));
-      PADDLE_ENFORCE_EQ(
-          n,
-          regularization_coeffs.size(),
-          platform::errors::InvalidArgument(
-              "The size of Attr(regularization_coeff) must be equal "
-              "to Input(Param), but got the size of Attr(regularization_coeff) "
-              "is %d, the size of Input(Param) is %d.",
-              regularization_coeffs.size(),
-              n));
-    }
-
-    VLOG(5) << "use_nesterov: " << use_nesterov
-            << ", regularization_methods.size(): "
-            << regularization_methods.size()
-            << ", regularization_coeffs.size(): "
-            << regularization_coeffs.size();
-
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-
-    if (lrs.size() == 1 && use_nesterov == false &&
-        regularization_methods.size() == 0) {
-#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision)             \
-  MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params;        \
-  constexpr auto kMaxMergedNum = decltype(kernel_params)::N;              \
-  size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum;            \
-  kernel_params.mu = static_cast<MT>(mu);                                 \
-  kernel_params.rescale_grad = static_cast<MT>(rescale_grad);             \
-  kernel_params.lr = lrs[0]->data<MultiPrecisionType<MT>>();              \
-  for (size_t i = 0; i < kernel_num; ++i) {                               \
-    size_t start = i * kMaxMergedNum;                                     \
-    size_t end = std::min((i + 1) * kMaxMergedNum, n);                    \
-    kernel_params.param_num = static_cast<uint32_t>(end - start);         \
-    size_t max_size = 0;                                                  \
-    for (size_t j = 0; j < kernel_params.param_num; ++j) {                \
-      auto size = static_cast<size_t>(params_out[j + start]->numel());    \
-      max_size = std::max(max_size, size);                                \
-      kernel_params.sizes[j] = size;                                      \
-      kernel_params.params[j] = params_out[j + start]->data<T>();         \
-      kernel_params.grads[j] = grads[j + start]->data<T>();               \
-      kernel_params.velocitys[j] = velocitys_out[j + start]->data<MT>();  \
-      kernel_params.SetMasterParam(                                       \
-          j,                                                              \
-          kMultiPrecision ? master_params_out[j + start]->data<MT>()      \
-                          : nullptr);                                     \
-    }                                                                     \
-    platform::ForRange<DeviceContext> for_range(dev_ctx, max_size);       \
-    for_range(kernel_params);                                             \
-    VLOG(10) << "Launch MergedMomentum kernel " << i << " "               \
-             << kernel_params.param_num;                                  \
-  }
-      if (multi_precision) {
-        PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
-      } else {
-        PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
-      }
-#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
-    } else {
-      for (size_t idx = 0; idx < n; idx++) {
-        phi::RegularizationType regularization_flag =
-            regularization_methods.size() > 0 &&
-                    regularization_methods[idx] == "l2_decay"
-                ? phi::RegularizationType::kL2DECAY
-                : phi::RegularizationType::kNONE;
-
-        MT regularization_coeff = static_cast<MT>(0.0);
-        if (regularization_coeffs.size() != 0) {
-          regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
-        }
-        auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
-
-        const MT *master_in_data =
-            multi_precision ? master_params[idx]->data<MT>() : nullptr;
-        MT *master_out_data =
-            multi_precision ?
-            multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
-        if (platform::is_cpu_place(ctx.GetPlace())) {
-          phi::CPUDenseMomentumFunctor<MT> functor;
-          functor(params[idx],
-                  grads[idx],
-                  velocitys[idx],
-                  lr_temp,
-                  static_cast<MT>(mu),
-                  use_nesterov,
-                  regularization_flag,
-                  regularization_coeff,
-                  params_out[idx],
-                  velocitys_out[idx]);
-          VLOG(10) << "Launch MergedMomentum cpu kernel.";
-        } else if (platform::is_gpu_place(ctx.GetPlace())) {
-          platform::ForRange<DeviceContext> for_range(
-              static_cast<const DeviceContext &>(ctx.device_context()),
-              params[idx]->numel());
-#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type)  \
-  phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(    \
-      params[idx]->data<T>(),                                          \
-      grads[idx]->data<T>(),                                           \
-      velocitys[idx]->data<MT>(),                                      \
-      lr_temp->data<MPType>(),                                         \
-      master_in_data,                                                  \
-      static_cast<MT>(mu),                                             \
-      static_cast<MT>(rescale_grad),                                   \
-      params[idx]->numel(),                                            \
-      regularization_coeff,                                            \
-      params_out[idx]->data<T>(),                                      \
-      velocitys_out[idx]->data<MT>(),                                  \
-      master_out_data);                                                \
-  for_range(functor);
-          if (use_nesterov) {
-            if (regularization_flag == phi::RegularizationType::kL2DECAY) {
-              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
-                  phi::UseNesterov, phi::RegularizationType::kL2DECAY);
-              VLOG(10)
-                  << "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
-            } else {
-              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
-                  phi::UseNesterov, phi::RegularizationType::kNONE);
-              VLOG(10)
-                  << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
-            }
-          } else {
-            if (regularization_flag == phi::RegularizationType::kL2DECAY) {
-              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
-                  phi::NoNesterov, phi::RegularizationType::kL2DECAY);
-              VLOG(10)
-                  << "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
-            } else {
-              PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
-                  phi::NoNesterov, phi::RegularizationType::kNONE);
-              VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
-            }
-          }
-        }
-      }
-      VLOG(10)
-          << "Launch MergedMomentum kernel with multi_lr and regularization.";
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc
index 32af057ecd4..90faf8f389a 100644
--- a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc
@@ -12,8 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"

 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc
index ff131138e8a..38479d6dba2 100644
--- a/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc
+++ b/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc
@@ -12,8 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" +#include "paddle/fluid/platform/for_range.h" +#include "paddle/fluid/platform/macros.h" #include "paddle/phi/kernels/impl/momentum_kernel_impl.h" namespace paddle { diff --git a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h index 60274f6b667..d3d2e48fdcd 100644 --- a/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h +++ b/paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/for_range.h" -#include "paddle/fluid/platform/macros.h" +#include "paddle/phi/core/macros.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/macros.h b/paddle/fluid/platform/macros.h index 9eede99b7b7..2ea58a7bb0c 100644 --- a/paddle/fluid/platform/macros.h +++ b/paddle/fluid/platform/macros.h @@ -29,9 +29,3 @@ limitations under the License. */ #define FLT_MAX __FLT_MAX__ #endif // __FLT_MAX__ #endif // PADDLE_WITH_MUSL - -#if defined(__NVCC__) || defined(__HIPCC__) -#define PADDLE_RESTRICT __restrict__ -#else -#define PADDLE_RESTRICT -#endif diff --git a/paddle/phi/core/macros.h b/paddle/phi/core/macros.h index 8049d027a77..e48f7342e45 100644 --- a/paddle/phi/core/macros.h +++ b/paddle/phi/core/macros.h @@ -53,4 +53,10 @@ namespace phi { #define PD_CONCATENATE2(arg1, arg2) arg1##arg2 #define PD_EXPAND(x) x +#if defined(__NVCC__) || defined(__HIPCC__) +#define PADDLE_RESTRICT __restrict__ +#else +#define PADDLE_RESTRICT +#endif + } // namespace phi diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cu b/paddle/phi/kernels/cpu/merged_momentum_kernel.cc similarity index 55% rename from paddle/fluid/operators/optimizers/merged_momentum_op.cu rename to paddle/phi/kernels/cpu/merged_momentum_kernel.cc index 7e4bbd98079..0751711ef64 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op.cu +++ b/paddle/phi/kernels/cpu/merged_momentum_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/merged_momentum_impl.h" -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - merged_momentum, - ops::MergedMomentumOpKernel, - ops::MergedMomentumOpKernel, - ops::MergedMomentumOpKernel); +PD_REGISTER_KERNEL(merged_momentum, + CPU, + ALL_LAYOUT, + phi::MergedMomentumKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/merged_momentum_kernel.cu b/paddle/phi/kernels/gpu/merged_momentum_kernel.cu new file mode 100644 index 00000000000..c6883caecd1 --- /dev/null +++ b/paddle/phi/kernels/gpu/merged_momentum_kernel.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/merged_momentum_impl.h" + +PD_REGISTER_KERNEL(merged_momentum, + GPU, + ALL_LAYOUT, + phi::MergedMomentumKernel, + phi::dtype::float16, + float, + double) {} diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h new file mode 100644 index 00000000000..2972a93d108 --- /dev/null +++ b/paddle/phi/kernels/impl/merged_momentum_impl.h @@ -0,0 +1,400 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h
new file mode 100644
index 00000000000..2972a93d108
--- /dev/null
+++ b/paddle/phi/kernels/impl/merged_momentum_impl.h
@@ -0,0 +1,400 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/macros.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
+#include "paddle/phi/kernels/merged_momentum_kernel.h"
+
+namespace phi {
+
+template <typename T>
+using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;
+
+template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
+struct MergedMomentumMasterParams {
+  MT *PADDLE_RESTRICT master_params[kParamNum];
+
+  HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
+  HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
+};
+
+template <typename MT, uint32_t kParamNum>
+struct MergedMomentumMasterParams<MT, kParamNum, false> {
+  HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
+  HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
+};
+
+template <typename T,
+          typename MT,
+          bool kHasMasterParams,
+          uint32_t kParamNum = kHasMasterParams ? 55 : 110>
+struct MergedMomentumKernelParam
+    : public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
+  static constexpr auto N = kParamNum;
+  size_t sizes[N];
+  T *PADDLE_RESTRICT params[N];
+  const T *PADDLE_RESTRICT grads[N];
+  MT *PADDLE_RESTRICT velocitys[N];
+  const MultiPrecisionType<MT> *PADDLE_RESTRICT lr;
+  MT mu;
+  MT rescale_grad;
+  uint32_t param_num;
+
+  HOSTDEVICE void operator()(size_t i) const {
+    const MT lr_val = static_cast<MT>(*lr);
+    for (uint32_t idx = 0; idx < param_num; ++idx) {
+      auto size = sizes[idx];
+      if (i >= size) continue;
+
+      auto param_p = params[idx];
+      auto grad_p = grads[idx];
+      auto velocity_p = velocitys[idx];
+      auto master_param_p = this->MasterParam(idx);
+
+      const MT param =
+          master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
+      const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
+      const MT velocity = velocity_p[i];
+      const MT velocity_out = velocity * mu + grad;
+      const MT param_out = param - lr_val * velocity_out;
+      velocity_p[i] = velocity_out;
+      param_p[i] = static_cast<T>(param_out);
+      if (master_param_p) {
+        master_param_p[i] = param_out;
+      }
+    }
+  }
+};
+
+template <typename T, typename MT, typename Context>
+void MergedMomentumInnerCompute(
+    const Context &ctx,
+    const std::vector<const DenseTensor *> &params,
+    const std::vector<const DenseTensor *> &grads,
+    const std::vector<const DenseTensor *> &velocitys,
+    const std::vector<const DenseTensor *> &lrs,
+    const paddle::optional<std::vector<const DenseTensor *>> &master_params_opt,
+    float mu,
+    bool use_nesterov,
+    const std::vector<std::string> &regularization_methods,
+    const std::vector<float> &regularization_coeffs,
+    float rescale_grad,
+    const bool multi_precision,
+    std::vector<DenseTensor *> params_out,
+    std::vector<DenseTensor *> velocitys_out,
+    std::vector<DenseTensor *> master_params_out) {
+  size_t n = params.size();
+  PADDLE_ENFORCE_EQ(n,
+                    params_out.size(),
+                    phi::errors::InvalidArgument(
+                        "The size of Output(ParamOut) must be equal to "
+                        "Input(Param), but got the size of Output(ParamOut) "
+                        "is %d, the size of Input(Param) is %d.",
+                        params_out.size(),
+                        n));
+  for (size_t i = 0; i < n; ++i) {
+    PADDLE_ENFORCE_EQ(
+        params[i],
+        params_out[i],
+        phi::errors::InvalidArgument("Input(Param) and Output(ParamOut) "
+                                     "must be the same Tensors."));
+  }
+
+  PADDLE_ENFORCE_EQ(
+      n,
+      grads.size(),
+      phi::errors::InvalidArgument(
+          "The size of Input(Grad) must be equal to Input(Param), but got "
+          "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
+          grads.size(),
+          n));
+
+  PADDLE_ENFORCE_EQ(n,
+                    velocitys.size(),
+                    phi::errors::InvalidArgument(
+                        "The size of Input(Velocity) must be equal to "
+                        "Input(Param), but got the size of Input(Velocity) "
+                        "is %d, the size of Input(Param) is %d.",
+                        velocitys.size(),
+                        n));
+
+  PADDLE_ENFORCE_EQ(
+      n,
+      velocitys_out.size(),
+      phi::errors::InvalidArgument(
+          "The size of Output(VelocityOut) must be "
+          "equal to Input(Param), but got the size of Output(VelocityOut) is "
+          "%d, the size of Input(Param) is %d.",
+          velocitys_out.size(),
+          n));
+  for (size_t i = 0; i < n; ++i) {
+    PADDLE_ENFORCE_EQ(velocitys[i],
+                      velocitys_out[i],
+                      phi::errors::InvalidArgument(
+                          "Input(Velocity) and Output(VelocityOut) must be "
+                          "the same Tensors."));
+  }
+
+  if (multi_precision) {
+    auto master_params = master_params_opt.get();
+    PADDLE_ENFORCE_EQ(
+        n,
+        master_params.size(),
+        phi::errors::InvalidArgument(
+            "The size of Input(MasterParam) must be "
+            "equal to Input(Param), but got the size of Input(MasterParam) "
+            "is %d, the size of Input(Param) is %d.",
+            master_params.size(),
+            n));
+    PADDLE_ENFORCE_EQ(
+        n,
+        master_params_out.size(),
+        phi::errors::InvalidArgument(
+            "The size of Output(MasterParamOut) must be equal to "
+            "Input(MasterParam), but got the size of Output(MasterParamOut) "
+            "is %d, the size of Input(Param) is %d.",
+            master_params_out.size(),
+            n));
+    for (size_t i = 0; i < n; ++i) {
+      PADDLE_ENFORCE_EQ(master_params[i],
+                        master_params_out[i],
+                        phi::errors::InvalidArgument(
+                            "Input(MasterParam) and Output(MasterParamOut) "
+                            "must be the same Tensors."));
+      PADDLE_ENFORCE_NOT_NULL(master_params[i],
+                              phi::errors::InvalidArgument(
+                                  "Input(MasterParam) must be provided when "
+                                  "multi_precision=True."));
+    }
+  } else {
+    master_params_out.clear();
+  }
+
+  if (lrs.size() != 1) {
+    PADDLE_ENFORCE_EQ(
+        n,
+        lrs.size(),
+        phi::errors::InvalidArgument(
+            "If the size of Input(LearningRate) is not 1, the size of "
+            "Input(LearningRate) must be "
+            "equal to Input(Param), but got the size of Input(LearningRate) "
+            "is %d, the size of Input(Param) is %d.",
+            lrs.size(),
+            n));
+  }
+  if (regularization_methods.size() != 0) {
+    PADDLE_ENFORCE_EQ(
+        n,
+        regularization_methods.size(),
+        phi::errors::InvalidArgument(
+            "The size of Attr(regularization_method) must be equal "
+            "to Input(Param), but got the size of "
+            "Attr(regularization_method) is %d, the size of Input(Param) is "
+            "%d.",
+            regularization_methods.size(),
+            n));
+    PADDLE_ENFORCE_EQ(
+        n,
+        regularization_coeffs.size(),
+        phi::errors::InvalidArgument(
+            "The size of Attr(regularization_coeff) must be equal "
+            "to Input(Param), but got the size of Attr(regularization_coeff) "
+            "is %d, the size of Input(Param) is %d.",
+            regularization_coeffs.size(),
+            n));
+  }
+
+  VLOG(5) << "use_nesterov: " << use_nesterov
+          << ", regularization_methods.size(): "
+          << regularization_methods.size()
+          << ", regularization_coeffs.size(): "
+          << regularization_coeffs.size();
+
+  if (lrs.size() == 1 && use_nesterov == false &&
+      regularization_methods.size() == 0) {
+#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision)             \
+  MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params;        \
+  constexpr auto kMaxMergedNum = decltype(kernel_params)::N;              \
+  size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum;            \
+  kernel_params.mu = static_cast<MT>(mu);                                 \
+  kernel_params.rescale_grad = static_cast<MT>(rescale_grad);             \
+  kernel_params.lr = lrs[0]->data<MultiPrecisionType<MT>>();              \
+  for (size_t i = 0; i < kernel_num; ++i) {                               \
+    size_t start = i * kMaxMergedNum;                                     \
+    size_t end = std::min((i + 1) * kMaxMergedNum, n);                    \
+    kernel_params.param_num = static_cast<uint32_t>(end - start);         \
+    size_t max_size = 0;                                                  \
+    for (size_t j = 0; j < kernel_params.param_num; ++j) {                \
+      auto size = static_cast<size_t>(params_out[j + start]->numel());    \
+      max_size = std::max(max_size, size);                                \
+      kernel_params.sizes[j] = size;                                      \
+      kernel_params.params[j] = params_out[j + start]->data<T>();         \
+      kernel_params.grads[j] = grads[j + start]->data<T>();               \
+      kernel_params.velocitys[j] = velocitys_out[j + start]->data<MT>();  \
+      kernel_params.SetMasterParam(                                       \
+          j,                                                              \
+          kMultiPrecision ? master_params_out[j + start]->data<MT>()      \
+                          : nullptr);                                     \
+    }                                                                     \
+    phi::funcs::ForRange<Context> for_range(ctx, max_size);               \
+    for_range(kernel_params);                                             \
+    VLOG(10) << "Launch MergedMomentum kernel " << i << " "               \
+             << kernel_params.param_num;                                  \
+  }
+    if (multi_precision) {
+      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
+    } else {
+      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
+    }
+#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
+  } else {
+    for (size_t idx = 0; idx < n; idx++) {
+      phi::RegularizationType regularization_flag =
+          regularization_methods.size() > 0 &&
+                  regularization_methods[idx] == "l2_decay"
+              ? phi::RegularizationType::kL2DECAY
+              : phi::RegularizationType::kNONE;
+
+      MT regularization_coeff = static_cast<MT>(0.0);
+      if (regularization_coeffs.size() != 0) {
+        regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
+      }
+      auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
+
+      const MT *master_in_data =
+          multi_precision ? master_params_opt.get()[idx]->data<MT>() : nullptr;
+      MT *master_out_data =
+          multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
+      if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
+        phi::CPUDenseMomentumFunctor<MT> functor;
+        functor(params[idx],
+                grads[idx],
+                velocitys[idx],
+                lr_temp,
+                static_cast<MT>(mu),
+                use_nesterov,
+                regularization_flag,
+                regularization_coeff,
+                params_out[idx],
+                velocitys_out[idx]);
+        VLOG(10) << "Launch MergedMomentum cpu kernel.";
+      } else if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
+        phi::funcs::ForRange<Context> for_range(
+            static_cast<const Context &>(ctx), params[idx]->numel());
+#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type)  \
+  phi::DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor(    \
+      params[idx]->data<T>(),                                          \
+      grads[idx]->data<T>(),                                           \
+      velocitys[idx]->data<MT>(),                                      \
+      lr_temp->data<MultiPrecisionType<MT>>(),                         \
+      master_in_data,                                                  \
+      static_cast<MT>(mu),                                             \
+      static_cast<MT>(rescale_grad),                                   \
+      params[idx]->numel(),                                            \
+      regularization_coeff,                                            \
+      params_out[idx]->data<T>(),                                      \
+      velocitys_out[idx]->data<MT>(),                                  \
+      master_out_data);                                                \
+  for_range(functor);
+        if (use_nesterov) {
+          if (regularization_flag == phi::RegularizationType::kL2DECAY) {
+            PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
+                phi::UseNesterov, phi::RegularizationType::kL2DECAY);
+            VLOG(10)
+                << "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
+          } else {
+            PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
+                phi::UseNesterov, phi::RegularizationType::kNONE);
+            VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
+          }
+        } else {
+          if (regularization_flag == phi::RegularizationType::kL2DECAY) {
+            PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
+                phi::NoNesterov, phi::RegularizationType::kL2DECAY);
+            VLOG(10)
+                << "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
+          } else {
+            PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
+                phi::NoNesterov, phi::RegularizationType::kNONE);
+            VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
+          }
+        }
+      }
+    }
+    VLOG(10)
+        << "Launch MergedMomentum kernel with multi_lr and regularization.";
+  }
+}
+
+template <typename T, typename Context>
+void MergedMomentumKernel(
+    const Context &dev_ctx,
+    const std::vector<const DenseTensor *> &param,
+    const std::vector<const DenseTensor *> &grad,
+    const std::vector<const DenseTensor *> &velocity,
+    const std::vector<const DenseTensor *> &learning_rate,
+    const paddle::optional<std::vector<const DenseTensor *>> &master_param,
+    float mu,
+    bool use_nesterov,
+    const std::vector<std::string> &regularization_method,
+    const std::vector<float> &regularization_coeff,
+    bool multi_precision,
+    float rescale_grad,
+    std::vector<DenseTensor *> param_out,
+    std::vector<DenseTensor *> velocity_out,
+    std::vector<DenseTensor *> master_param_out) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  if (multi_precision) {
+    MergedMomentumInnerCompute<T, MPType>(
+        dev_ctx,
+        param,
+        grad,
+        velocity,
+        learning_rate,
+        master_param,
+        mu,
+        use_nesterov,
+        regularization_method,
+        regularization_coeff,
+        rescale_grad,
+        multi_precision,
+        param_out,
+        velocity_out,
+        master_param_out);
+  } else {
+    MergedMomentumInnerCompute<T, T>(dev_ctx,
+                                     param,
+                                     grad,
+                                     velocity,
+                                     learning_rate,
+                                     master_param,
+                                     mu,
+                                     use_nesterov,
+                                     regularization_method,
+                                     regularization_coeff,
+                                     rescale_grad,
+                                     multi_precision,
+                                     param_out,
+                                     velocity_out,
+                                     master_param_out);
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/merged_momentum_kernel.h b/paddle/phi/kernels/merged_momentum_kernel.h
new file mode 100644
index 00000000000..9f21b988b4b
--- /dev/null
+++ b/paddle/phi/kernels/merged_momentum_kernel.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MergedMomentumKernel(
+    const Context& dev_ctx,
+    const std::vector<const DenseTensor*>& param,
+    const std::vector<const DenseTensor*>& grad,
+    const std::vector<const DenseTensor*>& velocity,
+    const std::vector<const DenseTensor*>& learning_rate,
+    const paddle::optional<std::vector<const DenseTensor*>>& master_param,
+    float mu,
+    bool use_nesterov,
+    const std::vector<std::string>& regularization_method,
+    const std::vector<float>& regularization_coeff,
+    bool multi_precision,
+    float rescale_grad,
+    std::vector<DenseTensor*> param_out,
+    std::vector<DenseTensor*> velocity_out,
+    std::vector<DenseTensor*> master_param_out);
+
+}  // namespace phi
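For orientation, a minimal call sketch against the header above (hypothetical, not part of the patch; assumes initialized DenseTensors p, g, v, lr and a constructed CPU context, and note the implementation requires param/param_out and velocity/velocity_out to be the same tensors):

    phi::CPUContext ctx;
    std::vector<const phi::DenseTensor*> params{&p}, grads{&g},
        velocities{&v}, lrs{&lr};
    std::vector<phi::DenseTensor*> params_out{&p}, velocities_out{&v},
        master_out;
    phi::MergedMomentumKernel<float, phi::CPUContext>(
        ctx, params, grads, velocities, lrs,
        paddle::none,  // MasterParam is unused when multi_precision=false
        /*mu=*/0.9f, /*use_nesterov=*/false,
        /*regularization_method=*/{}, /*regularization_coeff=*/{},
        /*multi_precision=*/false, /*rescale_grad=*/1.0f,
        params_out, velocities_out, master_out);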
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature MergedMomentumOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "merged_momentum", + {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}, + {"mu", + "use_nesterov", + "regularization_method", + "regularization_coeff", + "multi_precision", + "rescale_grad"}, + { + "ParamOut", + "VelocityOut", + "MasterParamOut", + }); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(merged_momentum, + phi::MergedMomentumOpArgumentMapping); -- GitLab