From df7600ab2aed331bfb64a41d36302bdd7752051a Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Wed, 31 Aug 2022 10:30:01 +0800 Subject: [PATCH] Move XPU mean and mean_grad to phi (#45512) * Move XPU mean and mean_grad to phi, test=kunlun * Fix stream, test=kunlun * Replace ENFORCE, test=kunlun --- paddle/fluid/operators/mean_op_xpu.cc | 105 ------------------ .../phi/kernels/xpu/mean_all_grad_kernel.cc | 58 ++++++++++ paddle/phi/kernels/xpu/mean_all_kernel.cc | 48 ++++++++ 3 files changed, 106 insertions(+), 105 deletions(-) delete mode 100644 paddle/fluid/operators/mean_op_xpu.cc create mode 100644 paddle/phi/kernels/xpu/mean_all_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/mean_all_kernel.cc diff --git a/paddle/fluid/operators/mean_op_xpu.cc b/paddle/fluid/operators/mean_op_xpu.cc deleted file mode 100644 index 8f1300b1524..00000000000 --- a/paddle/fluid/operators/mean_op_xpu.cc +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class MeanXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - auto& dev_ctx = context.template device_context(); - const T* x_data = input->data(); - T* y_data = output->data(); - std::vector x_shape; - x_shape.push_back(1); - x_shape.push_back(input->numel()); - std::vector rdims = {1}; - int r = xpu::reduce_mean(dev_ctx.x_context(), - reinterpret_cast(x_data), - reinterpret_cast(y_data), - x_shape, - rdims); - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU reduce_mean kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; -template -class MeanGradXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& context) const override { - auto OG = context.Input(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ( - OG->numel(), - 1, - platform::errors::InvalidArgument("Mean Gradient should be scalar")); - auto IG = context.Output(framework::GradVarName("X")); - IG->mutable_data(context.GetPlace()); - auto& dev_ctx = context.template device_context(); - - XPUType* dx = reinterpret_cast(IG->data()); - - const T* dy = OG->data(); - T dy0_value; - xpu_wait(dev_ctx.x_context()->xpu_stream); - memory::Copy(platform::CPUPlace(), &dy0_value, OG->place(), dy, sizeof(T)); - float dy0_fp32 = static_cast(dy0_value); - dy0_fp32 = dy0_fp32 / static_cast(IG->numel()); - - int r = xpu::constant( - dev_ctx.x_context(), dx, IG->numel(), static_cast(dy0_fp32)); - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "XPU constant kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - mean, - ops::MeanXPUKernel, - ops::MeanXPUKernel); -REGISTER_OP_XPU_KERNEL( - mean_grad, - ops::MeanGradXPUKernel, - ops::MeanGradXPUKernel); -#endif diff --git a/paddle/phi/kernels/xpu/mean_all_grad_kernel.cc b/paddle/phi/kernels/xpu/mean_all_grad_kernel.cc new file mode 100644 index 00000000000..cd7a0c4a106 --- /dev/null +++ b/paddle/phi/kernels/xpu/mean_all_grad_kernel.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mean_all_grad_kernel.h" + +#include "paddle/fluid/memory/memory.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void MeanAllGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + using XPUType = typename XPUTypeTrait::Type; + + auto OG = &out_grad; + PADDLE_ENFORCE_EQ( + OG->numel(), + 1, + phi::errors::InvalidArgument("Mean Gradient should be scalar")); + auto IG = x_grad; + dev_ctx.template Alloc(IG); + + XPUType* dx = reinterpret_cast(IG->data()); + + const T* dy = OG->data(); + T dy0_value; + xpu_wait(dev_ctx.x_context()->xpu_stream); + paddle::memory::Copy(phi::CPUPlace(), &dy0_value, OG->place(), dy, sizeof(T)); + float dy0_fp32 = static_cast(dy0_value); + dy0_fp32 = dy0_fp32 / static_cast(IG->numel()); + + int r = xpu::constant( + dev_ctx.x_context(), dx, IG->numel(), static_cast(dy0_fp32)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "mean_all_grad"); +} +} // namespace phi + +PD_REGISTER_KERNEL(mean_all_grad, + XPU, + ALL_LAYOUT, + phi::MeanAllGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/mean_all_kernel.cc b/paddle/phi/kernels/xpu/mean_all_kernel.cc new file mode 100644 index 00000000000..c80fb62dd8e --- /dev/null +++ b/paddle/phi/kernels/xpu/mean_all_kernel.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/mean_all_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void MeanAllKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + + auto* input = &x; + auto* output = out; + dev_ctx.template Alloc(out); + const T* x_data = input->data(); + T* y_data = output->data(); + std::vector x_shape; + x_shape.push_back(1); + x_shape.push_back(input->numel()); + std::vector rdims = {1}; + int r = xpu::reduce_mean(dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + x_shape, + rdims); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "mean_all"); +} +} // namespace phi + +PD_REGISTER_KERNEL( + mean_all, XPU, ALL_LAYOUT, phi::MeanAllKernel, float, phi::dtype::float16) { +} -- GitLab