From 1137677af0c18f505fbec62edbec51717bda896b Mon Sep 17 00:00:00 2001
From: houj04 <35131887+houj04@users.noreply.github.com>
Date: Tue, 6 Sep 2022 09:52:35 +0800
Subject: [PATCH] [XPU] rmsprop to phi. (#45734)

---
 .../operators/optimizers/rmsprop_op_xpu.cc    | 145 ------------------
 paddle/phi/kernels/xpu/rmsprop_kernel.cc      |  85 ++++++++++
 2 files changed, 85 insertions(+), 145 deletions(-)
 delete mode 100644 paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc
 create mode 100644 paddle/phi/kernels/xpu/rmsprop_kernel.cc

diff --git a/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc b/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc
deleted file mode 100644
index 6addb7c2fe..0000000000
--- a/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc
+++ /dev/null
@@ -1,145 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-
-#include <gflags/gflags.h>
-
-#include <iostream>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/device/device_wrapper.h"
-
-namespace paddle {
-namespace operators {
-
-static inline float GetAttrFromTensor(const framework::Tensor* tensor) {
-  const float* tensor_data = tensor->data<float>();
-  framework::Tensor cpu_tensor;
-  if (platform::is_gpu_place(tensor->place()) ||
-      platform::is_xpu_place(tensor->place())) {
-    paddle::framework::TensorCopySync(
-        *tensor, platform::CPUPlace(), &cpu_tensor);
-    tensor_data = cpu_tensor.data<float>();
-  }
-  return tensor_data[0];
-}
-
-using framework::OpKernelType;
-using framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class RmspropOpXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    using paddle::framework::LoDTensor;
-
-    // check Param & Grad tensor type
-    const auto* param_var = ctx.InputVar("Param");
-    PADDLE_ENFORCE_EQ(param_var->IsType<LoDTensor>(),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Tensor holds the wrong type, Expected Var(%s)'s "
-                          "type is LoDTensor, "
-                          "but the received is %s",
-                          ctx.InputNames("Param").front(),
-                          framework::ToTypeName(param_var->Type())));
-
-    const auto* grad_var = ctx.InputVar("Grad");
-    PADDLE_ENFORCE_EQ(grad_var->IsType<LoDTensor>(),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Tensor holds the wrong type, Expected Var(%s)'s "
-                          "type is LoDTensor, "
-                          "but the received is %s",
-                          ctx.InputNames("Grad").front(),
-                          framework::ToTypeName(grad_var->Type())));
-
-    // inputs
-    auto& param = GET_DATA_SAFELY(
-        ctx.Input<LoDTensor>("Param"), "Input", "Param", "Rmsprop");
-    auto& meanSquare = GET_DATA_SAFELY(
-        ctx.Input<LoDTensor>("MeanSquare"), "Input", "MeanSquare", "Rmsprop");
-    auto& grad = GET_DATA_SAFELY(
-        ctx.Input<LoDTensor>("Grad"), "Input", "Grad", "Rmsprop");
-    auto& mom = GET_DATA_SAFELY(
-        ctx.Input<LoDTensor>("Moment"), "Input", "Moment", "Rmsprop");
-
-    auto* learning_rate = ctx.Input<Tensor>("LearningRate");
-    PADDLE_ENFORCE_EQ(learning_rate->dims().size(),
-                      1,
-                      platform::errors::InvalidArgument(
-                          "learning rate should have dimension = 1."
- " But received learning rate dim [%s] ", - learning_rate->dims().size())); - T lr = static_cast(GetAttrFromTensor(learning_rate)); - - // constants - T epsilon = static_cast(ctx.Attr("epsilon")); - T decay = static_cast(ctx.Attr("decay")); - T momentum = static_cast(ctx.Attr("momentum")); - - bool centered = ctx.Attr("centered"); - PADDLE_ENFORCE_EQ(centered, - false, - platform::errors::Unimplemented( - "centered=True is not supported in the xpu kernel of " - "rmsprop. use XPU_BLACK_LIST to disable this op.")); - /* - TODO(houj04): when XDNN api supports 'center', add input of - mean_grad_input and output of mean_grad_output. auto *mean_grad_input = - ctx.Input("MeanGrad"); auto *mean_grad_output = - ctx.Output("MeanGradOut"); - */ - - // outputs - auto& param_out = GET_DATA_SAFELY( - ctx.Output("ParamOut"), "Output", "ParamOut", "Rmsprop"); - auto& mom_out = GET_DATA_SAFELY( - ctx.Output("MomentOut"), "Output", "MomentOut", "Rmsprop"); - auto& mom_sqrt_out = GET_DATA_SAFELY(ctx.Output("MeanSquareOut"), - "Output", - "MeanSquareOut", - "Rmsprop"); - auto& dev_ctx = ctx.template device_context(); - - // int rmsprop(Context* ctx, const T* g, const T* p, const float* ms, const - // float* mom, T* p_out, float* ms_out, float* mom_out, float epsilon, float - // rho, float momentum, float lr, int n); - int r = xpu::rmsprop(dev_ctx.x_context(), - grad.template data(), - param.template data(), - meanSquare.template data(), - mom.template data(), - param_out.template mutable_data(ctx.GetPlace()), - mom_sqrt_out.template mutable_data(ctx.GetPlace()), - mom_out.template mutable_data(ctx.GetPlace()), - epsilon, - decay, - momentum, - lr, - param.numel()); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - rmsprop, - ops::RmspropOpXPUKernel); -#endif diff --git a/paddle/phi/kernels/xpu/rmsprop_kernel.cc b/paddle/phi/kernels/xpu/rmsprop_kernel.cc new file mode 100644 index 0000000000..c95076933c --- /dev/null +++ b/paddle/phi/kernels/xpu/rmsprop_kernel.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rmsprop_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/memory/memcpy.h" + +namespace phi { + +template +void RmspropDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& mean_square, + const DenseTensor& grad, + const DenseTensor& moment, + const DenseTensor& learning_rate, + const paddle::optional& mean_grad, + float epsilon, + float decay, + float momentum, + bool centered, + DenseTensor* param_out, + DenseTensor* moment_out, + DenseTensor* mean_square_out, + DenseTensor* mean_grad_out) { + // check input + PADDLE_ENFORCE_EQ(centered, + false, + errors::Unimplemented( + "centered=True is not supported in the xpu kernel of " + "rmsprop. 
+                        "rmsprop. Use XPU_BLACK_LIST to disable this op."));
+  // copy learning_rate to cpu
+  PADDLE_ENFORCE_EQ(
+      learning_rate.dims().size(),
+      1,
+      errors::InvalidArgument("learning rate should have dimension = 1."
+                              " But received learning rate dim [%s] ",
+                              learning_rate.dims().size()));
+  T learning_rate_cpu = 0.0f;
+  paddle::memory::Copy(CPUPlace(),
+                       static_cast<void*>(&learning_rate_cpu),
+                       dev_ctx.GetPlace(),
+                       static_cast<const void*>(learning_rate.data<T>()),
+                       sizeof(T));
+
+  // alloc output
+  dev_ctx.template Alloc<T>(param_out);
+  dev_ctx.template Alloc<T>(moment_out);
+  dev_ctx.template Alloc<T>(mean_square_out);
+
+  // int rmsprop(Context* ctx, const T* g, const T* p, const float* ms, const
+  // float* mom, T* p_out, float* ms_out, float* mom_out, float epsilon, float
+  // rho, float momentum, float lr, int n);
+  int r = xpu::rmsprop(dev_ctx.x_context(),
+                       grad.data<T>(),
+                       param.data<T>(),
+                       mean_square.data<T>(),
+                       moment.data<T>(),
+                       param_out->data<T>(),
+                       mean_square_out->data<T>(),
+                       moment_out->data<T>(),
+                       epsilon,
+                       decay,
+                       momentum,
+                       learning_rate_cpu,
+                       param.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "rmsprop");
+}
}  // namespace phi

+PD_REGISTER_KERNEL(rmsprop, XPU, ALL_LAYOUT, phi::RmspropDenseKernel, float) {}
--
GitLab
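
Reviewer note: the xpu::rmsprop signature quoted in the comments above maps onto the usual non-centered RMSProp-with-momentum update, with the op attribute "decay" passed as XDNN's "rho" and "centered" rejected up front. Below is a minimal plain-C++ sketch of the per-element semantics one such call is expected to have, assuming the standard formulation used by Paddle's CPU/GPU rmsprop kernels; the name rmsprop_reference is illustrative only, and the actual XDNN internals are not part of this patch.

#include <cmath>
#include <cstddef>

// Illustrative reference only: the expected per-element effect of one
// xpu::rmsprop call, assuming the standard non-centered RMSProp update.
void rmsprop_reference(const float* g, const float* p, const float* ms,
                       const float* mom, float* p_out, float* ms_out,
                       float* mom_out, float epsilon, float rho,
                       float momentum, float lr, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    // Exponential moving average of the squared gradient ("MeanSquareOut").
    ms_out[i] = rho * ms[i] + (1.0f - rho) * g[i] * g[i];
    // Momentum accumulator fed by the RMS-normalized gradient ("MomentOut").
    mom_out[i] =
        momentum * mom[i] + lr * g[i] / std::sqrt(ms_out[i] + epsilon);
    // Parameter update ("ParamOut").
    p_out[i] = p[i] - mom_out[i];
  }
}

Under this reading, the argument order in both kernels (grad, param, mean_square, moment, then param_out, mean_square_out, moment_out) lines up with the g/p/ms/mom pointers above, which is why the migrated phi kernel can keep the call site essentially unchanged.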