From d7807806dfab5cc2cdffc114c4dc5cabf6d04d48 Mon Sep 17 00:00:00 2001
From: WangZhen <23097963+0x45f@users.noreply.github.com>
Date: Wed, 31 Aug 2022 14:17:40 +0800
Subject: [PATCH] Move XPU momentum to phi (#45565)

* Move XPU momentum to phi, test=kunlun

* Fix mu type, test=kunlun
---
 .../operators/optimizers/momentum_op_xpu.cc  | 82 -------------------
 paddle/phi/kernels/xpu/momentum_kernel.cc    | 72 ++++++++++++++++
 2 files changed, 72 insertions(+), 82 deletions(-)
 delete mode 100644 paddle/fluid/operators/optimizers/momentum_op_xpu.cc
 create mode 100644 paddle/phi/kernels/xpu/momentum_kernel.cc

diff --git a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc
deleted file mode 100644
index bd62c7acaa8..00000000000
--- a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-#ifdef PADDLE_WITH_XPU
-#include <string>
-
-#include "paddle/fluid/operators/optimizers/sgd_op.h"
-#include "paddle/fluid/platform/device/device_wrapper.h"
-namespace paddle {
-namespace operators {
-
-template <typename DeviceContext, typename T>
-class MomentumOpXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    T mu = static_cast<T>(ctx.Attr<float>("mu"));
-    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
-
-    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    auto param = ctx.Input<framework::Tensor>("Param");
-    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto* velocity = ctx.Input<framework::Tensor>("Velocity");
-    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
-    param_out->mutable_data<T>(ctx.GetPlace());
-    velocity_out->mutable_data<T>(ctx.GetPlace());
-    auto* lr = learning_rate->data<float>();
-    auto regularization_method = ctx.Attr<std::string>("regularization_method");
-    auto regularization_coeff = ctx.Attr<float>("regularization_coeff");
-    if (regularization_method != "l2_decay") {
-      // only support l2_decay
-      regularization_coeff = 0.0f;
-    }
-    auto* grad_var = ctx.InputVar("Grad");
-    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(),
-                      true,
-                      platform::errors::PermissionDenied(
-                          "Unsupported Variable Type of Param & Grad in "
-                          "MomentumOp-XPU. Excepted "
-                          "LodTensor, But received [%s] and [%s]",
-                          paddle::framework::ToTypeName(grad_var->Type())));
-    auto grad = ctx.Input<framework::Tensor>("Grad");
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-
-    // int momentum(Context* ctx, const T* param, const T* velocity, const T*
-    // grad, T* param_out, T* velocity_out, int len, const float* lr, int
-    // use_nesterov, float mu, float l2_weight_decay);
-    int r = xpu::momentum(dev_ctx.x_context(),
-                          reinterpret_cast<const XPUType*>(param->data<T>()),
-                          reinterpret_cast<const XPUType*>(velocity->data<T>()),
-                          reinterpret_cast<const XPUType*>(grad->data<T>()),
-                          reinterpret_cast<XPUType*>(param_out->data<T>()),
-                          reinterpret_cast<XPUType*>(velocity_out->data<T>()),
-                          param_out->numel(),
-                          lr,
-                          use_nesterov,
-                          mu,
-                          regularization_coeff);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "momentum");
-  }
-};
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    momentum,
-    ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext, float>,
-    ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext,
-                             paddle::platform::float16>);
-#endif
diff --git a/paddle/phi/kernels/xpu/momentum_kernel.cc b/paddle/phi/kernels/xpu/momentum_kernel.cc
new file mode 100644
index 00000000000..ad9cb2e6ef8
--- /dev/null
+++ b/paddle/phi/kernels/xpu/momentum_kernel.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/momentum_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MomentumDenseKernel(const Context& dev_ctx,
+                         const DenseTensor& param,
+                         const DenseTensor& grad,
+                         const DenseTensor& velocity,
+                         const DenseTensor& learning_rate,
+                         const paddle::optional<DenseTensor>& master_param,
+                         float mu,
+                         bool use_nesterov,
+                         const std::string& regularization_method,
+                         float regularization_coeff,
+                         bool multi_precision,
+                         float rescale_grad,
+                         DenseTensor* param_out,
+                         DenseTensor* velocity_out,
+                         DenseTensor* master_param_out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  dev_ctx.template Alloc<T>(param_out);
+  dev_ctx.template Alloc<T>(velocity_out);
+  auto* lr = learning_rate.data<float>();
+
+  if (regularization_method != "l2_decay") {
+    // only support l2_decay
+    regularization_coeff = 0.0f;
+  }
+
+  // int momentum(Context* ctx, const T* param, const T* velocity, const T*
+  // grad, T* param_out, T* velocity_out, int len, const float* lr, int
+  // use_nesterov, float mu, float l2_weight_decay);
+  int r = xpu::momentum(dev_ctx.x_context(),
+                        reinterpret_cast<const XPUType*>(param.data<T>()),
+                        reinterpret_cast<const XPUType*>(velocity.data<T>()),
+                        reinterpret_cast<const XPUType*>(grad.data<T>()),
+                        reinterpret_cast<XPUType*>(param_out->data<T>()),
+                        reinterpret_cast<XPUType*>(velocity_out->data<T>()),
+                        param_out->numel(),
+                        lr,
+                        use_nesterov,
+                        static_cast<float>(mu),
+                        regularization_coeff);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "momentum");
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(momentum,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MomentumDenseKernel,
+                   float,
+                   phi::dtype::float16) {}
--
GitLab
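
Note on semantics: the xpu::momentum call is opaque on the device side, so below is a minimal host-side C++ sketch of the per-element update it is assumed to perform, inferred only from the argument list quoted in the kernel comment (standard heavy-ball momentum with optional Nesterov lookahead and L2 weight decay folded into the gradient). The function name momentum_reference and its layout are illustrative only; they are not part of Paddle or the XDNN library.

#include <cstddef>

// Hypothetical host-side reference for one momentum step over `len` elements,
// mirroring the xpu::momentum argument list quoted in the kernel above.
// Assumed semantics, not the actual XDNN device implementation.
void momentum_reference(const float* param, const float* velocity,
                        const float* grad, float* param_out,
                        float* velocity_out, std::ptrdiff_t len,
                        const float* lr, bool use_nesterov, float mu,
                        float l2_weight_decay) {
  for (std::ptrdiff_t i = 0; i < len; ++i) {
    // l2_decay folds the regularization term into the gradient; the kernel
    // forces regularization_coeff to 0.0f for any other method.
    float g = grad[i] + l2_weight_decay * param[i];
    velocity_out[i] = mu * velocity[i] + g;
    if (use_nesterov) {
      // Nesterov looks one step ahead along the updated velocity.
      param_out[i] = param[i] - lr[0] * (g + mu * velocity_out[i]);
    } else {
      param_out[i] = param[i] - lr[0] * velocity_out[i];
    }
  }
}

The fluid and phi kernels pass the same arguments to this call; the one visible change beyond the interface move is the mu argument, which the old kernel forwarded as T (possibly float16) and the new kernel passes as an explicit float, matching the "Fix mu type, test=kunlun" commit.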