From d3ec3fe324fb47a89ca2140e1c31d0de484affc5 Mon Sep 17 00:00:00 2001
From: Weilong Wu
Date: Mon, 29 Aug 2022 10:20:15 +0800
Subject: [PATCH] [XPU] migrate bce_loss to phi;test=kunlun (#45459)

---
 paddle/fluid/operators/bce_loss_op_xpu.cc       | 76 -------------------
 .../phi/kernels/xpu/bce_loss_grad_kernel.cc     | 46 +++++++++++
 paddle/phi/kernels/xpu/bce_loss_kernel.cc       | 42 ++++++++++
 3 files changed, 88 insertions(+), 76 deletions(-)
 delete mode 100644 paddle/fluid/operators/bce_loss_op_xpu.cc
 create mode 100644 paddle/phi/kernels/xpu/bce_loss_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/xpu/bce_loss_kernel.cc

diff --git a/paddle/fluid/operators/bce_loss_op_xpu.cc b/paddle/fluid/operators/bce_loss_op_xpu.cc
deleted file mode 100644
index 9a9fbc2243f..00000000000
--- a/paddle/fluid/operators/bce_loss_op_xpu.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/device/device_wrapper.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class XPUBCELossKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<Tensor>("X");
-    auto* labels = context.Input<Tensor>("Label");
-    auto* out = context.Output<Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-
-    auto x_numel = x->numel();
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::bce_loss<T>(dev_ctx.x_context(),
-                             x->data<T>(),
-                             labels->data<T>(),
-                             out->data<T>(),
-                             x_numel);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "bce_loss");
-  }
-};
-
-template <typename DeviceContext, typename T>
-class XPUBCELossGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<Tensor>("X");
-    auto* labels = context.Input<Tensor>("Label");
-    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
-    dx->mutable_data<T>(context.GetPlace());
-
-    auto x_numel = x->numel();
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::bce_loss_grad<T>(dev_ctx.x_context(),
-                                  x->data<T>(),
-                                  labels->data<T>(),
-                                  dout->data<T>(),
-                                  dx->data<T>(),
-                                  x_numel);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "bce_loss_grad");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(
-    bce_loss, ops::XPUBCELossKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    bce_loss_grad,
-    ops::XPUBCELossGradKernel<paddle::platform::XPUDeviceContext, float>);
-
-#endif  // PADDLE_WITH_XPU
diff --git a/paddle/phi/kernels/xpu/bce_loss_grad_kernel.cc b/paddle/phi/kernels/xpu/bce_loss_grad_kernel.cc
new file mode 100644
index 00000000000..04dee2cf0c6
--- /dev/null
+++ b/paddle/phi/kernels/xpu/bce_loss_grad_kernel.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/bce_loss_grad_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BCELossGradKernel(const Context& dev_ctx,
+                       const DenseTensor& input,
+                       const DenseTensor& label,
+                       const DenseTensor& out_grad,
+                       DenseTensor* input_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  dev_ctx.template Alloc<T>(input_grad);
+
+  auto x_numel = input.numel();
+  int r = xpu::bce_loss_grad<XPUType>(
+      dev_ctx.x_context(),
+      reinterpret_cast<const XPUType*>(input.data<T>()),
+      reinterpret_cast<const XPUType*>(label.data<T>()),
+      reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+      reinterpret_cast<XPUType*>(input_grad->data<T>()),
+      x_numel);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "bce_loss_grad");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    bce_loss_grad, XPU, ALL_LAYOUT, phi::BCELossGradKernel, float) {}
diff --git a/paddle/phi/kernels/xpu/bce_loss_kernel.cc b/paddle/phi/kernels/xpu/bce_loss_kernel.cc
new file mode 100644
index 00000000000..480fc154167
--- /dev/null
+++ b/paddle/phi/kernels/xpu/bce_loss_kernel.cc
@@ -0,0 +1,42 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/bce_loss_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BCELossKernel(const Context& dev_ctx,
+                   const DenseTensor& input,
+                   const DenseTensor& label,
+                   DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  dev_ctx.template Alloc<T>(out);
+
+  auto x_numel = input.numel();
+  int r =
+      xpu::bce_loss<XPUType>(dev_ctx.x_context(),
+                             reinterpret_cast<const XPUType*>(input.data<T>()),
+                             reinterpret_cast<const XPUType*>(label.data<T>()),
+                             reinterpret_cast<XPUType*>(out->data<T>()),
+                             x_numel);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "bce_loss");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(bce_loss, XPU, ALL_LAYOUT, phi::BCELossKernel, float) {}
--
GitLab
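Reviewer note (not part of the patch): for reference, below is a minimal host-side sketch of the elementwise math that both the removed fluid kernel and the new phi kernel delegate to xpu::bce_loss / xpu::bce_loss_grad on device. The "_ref" function names and the epsilon clamp in the backward pass are illustrative assumptions, not the XDNN implementation.

// Reference-only sketch of the BCE loss math; assumes x holds probabilities.
#include <algorithm>
#include <cmath>
#include <cstdint>

// Forward: out[i] = -(label * log(x) + (1 - label) * log(1 - x))
void bce_loss_ref(const float* x, const float* label, float* out,
                  std::int64_t n) {
  for (std::int64_t i = 0; i < n; ++i) {
    out[i] = -(label[i] * std::log(x[i]) +
               (1.0f - label[i]) * std::log(1.0f - x[i]));
  }
}

// Backward: dx[i] = dout[i] * (x - label) / (x * (1 - x)), with an assumed
// epsilon clamp on the denominator to avoid division by zero.
void bce_loss_grad_ref(const float* x, const float* label, const float* dout,
                       float* dx, std::int64_t n) {
  const float eps = 1e-12f;  // illustrative; the device kernel may differ
  for (std::int64_t i = 0; i < n; ++i) {
    dx[i] = dout[i] * (x[i] - label[i]) /
            std::max(x[i] * (1.0f - x[i]), eps);
  }
}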