From 6fc5d88a378892ad0936626222983d8958aea24c Mon Sep 17 00:00:00 2001 From: Linjie Chen <40840292+linjieccc@users.noreply.github.com> Date: Thu, 24 Feb 2022 14:15:20 +0800 Subject: [PATCH] [phi] move bce_loss to phi (#39868) * move bce_loss to phi * refine PADDLE_ENFORCE * revert PADDLE_ENFORCE * fix ci --- paddle/fluid/operators/bce_loss_op.cc | 52 ++------- paddle/fluid/operators/bce_loss_op.cu | 109 ------------------ paddle/fluid/operators/bce_loss_op.h | 85 -------------- paddle/fluid/operators/bce_loss_op_npu.cc | 2 +- paddle/phi/infermeta/binary.cc | 38 ++++++ paddle/phi/infermeta/binary.h | 4 + paddle/phi/kernels/bce_loss_grad_kernel.h | 28 +++++ paddle/phi/kernels/bce_loss_kernel.h | 27 +++++ .../phi/kernels/cpu/bce_loss_grad_kernel.cc | 47 ++++++++ paddle/phi/kernels/cpu/bce_loss_kernel.cc | 59 ++++++++++ .../phi/kernels/gpu/bce_loss_grad_kernel.cu | 59 ++++++++++ paddle/phi/kernels/gpu/bce_loss_kernel.cu | 64 ++++++++++ paddle/phi/ops/compat/bce_loss_sig.cc | 29 +++++ 13 files changed, 364 insertions(+), 239 deletions(-) delete mode 100644 paddle/fluid/operators/bce_loss_op.cu delete mode 100644 paddle/fluid/operators/bce_loss_op.h create mode 100644 paddle/phi/kernels/bce_loss_grad_kernel.h create mode 100644 paddle/phi/kernels/bce_loss_kernel.h create mode 100644 paddle/phi/kernels/cpu/bce_loss_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/bce_loss_kernel.cc create mode 100644 paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/bce_loss_kernel.cu create mode 100644 paddle/phi/ops/compat/bce_loss_sig.cc diff --git a/paddle/fluid/operators/bce_loss_op.cc b/paddle/fluid/operators/bce_loss_op.cc index 1c390923d0..55bb57466c 100644 --- a/paddle/fluid/operators/bce_loss_op.cc +++ b/paddle/fluid/operators/bce_loss_op.cc @@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/bce_loss_op.h" #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/binary.h" + namespace paddle { namespace operators { @@ -26,41 +29,6 @@ class BCELossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELoss"); - OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELoss"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss"); - - auto x_dims = ctx->GetInputDim("X"); - auto labels_dims = ctx->GetInputDim("Label"); - - int rank = x_dims.size(); - PADDLE_ENFORCE_EQ(rank, labels_dims.size(), - platform::errors::InvalidArgument( - "Input(X) and Input(Label) shall have the same rank." - "But received: the rank of Input(X) is [%d], " - "the rank of Input(Label) is [%d].", - rank, labels_dims.size())); - - bool check = true; - if ((!ctx->IsRuntime()) && - (phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) { - check = false; - } - - if (check) { - PADDLE_ENFORCE_EQ(x_dims, labels_dims, - platform::errors::InvalidArgument( - "Input(X) and Input(Label) shall have the same " - "shape. But received: the shape of Input(X) is " - "[%s], the shape of Input(Label) is [%s].", - x_dims, labels_dims)); - } - - ctx->ShareDim("X", "Out"); - ctx->ShareLoD("X", "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -170,16 +138,12 @@ DECLARE_INPLACE_OP_INFERER(BCELossGradInplaceInferer, } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(bce_loss, BCELossInferShapeFunctor, + PT_INFER_META(phi::BCELossInferMeta)); + REGISTER_OPERATOR(bce_loss, ops::BCELossOp, ops::BCELossOpMaker, ops::BCELossGradOpMaker, ops::BCELossGradOpMaker, - ops::BCELossInplaceInferer); + ops::BCELossInplaceInferer, BCELossInferShapeFunctor); REGISTER_OPERATOR(bce_loss_grad, ops::BCELossGradOp, ops::BCELossGradInplaceInferer); -REGISTER_OP_CPU_KERNEL( - bce_loss, ops::BCELossOpKernel, - ops::BCELossOpKernel); -REGISTER_OP_CPU_KERNEL( - bce_loss_grad, - ops::BCELossGradOpKernel, - ops::BCELossGradOpKernel); diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu deleted file mode 100644 index f71fbbdc6b..0000000000 --- a/paddle/fluid/operators/bce_loss_op.cu +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include -#include "paddle/fluid/operators/bce_loss_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/operators/math.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/core/hostdevice.h" - -namespace paddle { -namespace operators { -template -struct BCELossFunctor { - T one; - T neg_100; - - HOSTDEVICE inline BCELossFunctor() { - one = static_cast(1.0f); - neg_100 = static_cast(-100.); - } - - HOSTDEVICE inline T operator()(const T x, const T label) const { - PADDLE_ENFORCE( - (x >= static_cast(0)) && (x <= one), - "Input is expected to be within the interval [0, 1], but recieved %f.", - x); - T term1 = max(real_log(x), neg_100); - T term2 = max(real_log(one - x), neg_100); - return (((label - one) * term2) - (label * term1)); - } -}; - -template -struct BCELossGradFunctor { - T one; - T eps; - - HOSTDEVICE inline BCELossGradFunctor() { - one = static_cast(1.0f); - eps = static_cast(1e-12); - } - - HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const { - T term1 = max((one - x) * x, eps); - return (dout * (x - label) / term1); - } -}; - -using Tensor = framework::Tensor; - -template -class BCELossCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* labels = ctx.Input("Label"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - std::vector ins = {x, labels}; - std::vector outs = {out}; - auto& dev_ctx = ctx.template device_context(); - auto functor = BCELossFunctor(); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } -}; - -template -class BCELossGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* labels = ctx.Input("Label"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); - std::vector ins = {x, labels, dout}; - std::vector outs = {dx}; - auto& dev_ctx = ctx.template device_context(); - auto functor = BCELossGradFunctor(); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - bce_loss, - ops::BCELossCUDAKernel, - ops::BCELossCUDAKernel); -REGISTER_OP_CUDA_KERNEL( - bce_loss_grad, - ops::BCELossGradCUDAKernel, - ops::BCELossGradCUDAKernel); diff --git a/paddle/fluid/operators/bce_loss_op.h b/paddle/fluid/operators/bce_loss_op.h deleted file mode 100644 index dd87b69efe..0000000000 --- a/paddle/fluid/operators/bce_loss_op.h +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include // for max -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class BCELossOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* labels = ctx.Input("Label"); - auto* out = ctx.Output("Out"); - - auto x_data = x->data(); - auto label_data = labels->data(); - auto out_data = out->mutable_data(ctx.GetPlace()); - auto x_numel = x->numel(); - - // out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 - - // x) - label * ln(x) - for (int64_t i = 0; i < x_numel; ++i) { - PADDLE_ENFORCE_GE( - x_data[i], static_cast(0), - platform::errors::InvalidArgument( - "Illegal input, input must be greater than or equal to 0")); - PADDLE_ENFORCE_LE( - x_data[i], static_cast(1), - platform::errors::InvalidArgument( - "Illegal input, input must be less than or equal to 1")); - out_data[i] = - (label_data[i] - static_cast(1)) * - std::max(real_log(static_cast(1) - x_data[i]), (T)(-100)) - - label_data[i] * std::max(real_log(x_data[i]), (T)(-100)); - } - } -}; - -template -class BCELossGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* labels = ctx.Input("Label"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - - auto dx_data = dx->mutable_data(ctx.GetPlace()); - auto dout_data = dout->data(); - auto x_data = x->data(); - auto label_data = labels->data(); - - int x_numel = x->numel(); - - // dx = dout * ((x - label)/(x - x^2)) - for (int i = 0; i < x_numel; ++i) { - dx_data[i] = - dout_data[i] * ((x_data[i] - label_data[i]) / - std::max((static_cast(1) - x_data[i]) * x_data[i], - static_cast(1e-12))); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/bce_loss_op_npu.cc b/paddle/fluid/operators/bce_loss_op_npu.cc index 46e8a36d2e..c3cee6a7b0 100644 --- a/paddle/fluid/operators/bce_loss_op_npu.cc +++ b/paddle/fluid/operators/bce_loss_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/bce_loss_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index e94926a9c1..ab1fe5433f 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -230,4 +230,42 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { out->set_dims(in_dims); } +void BCELossInferMeta(const MetaTensor& input, + const MetaTensor& label, + MetaTensor* out, + MetaConfig config) { + auto input_dims = input.dims(); + auto label_dims = label.dims(); + + int rank = input_dims.size(); + PADDLE_ENFORCE_EQ(rank, + label_dims.size(), + phi::errors::InvalidArgument( + "Input(X) and Input(Label) shall have the same rank." + "But received: the rank of Input(X) is [%d], " + "the rank of Input(Label) is [%d].", + rank, + label_dims.size())); + + bool check = true; + if ((!config.is_runtime) && + (phi::product(input_dims) <= 0 || phi::product(label_dims) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(input_dims, + label_dims, + phi::errors::InvalidArgument( + "Input(X) and Input(Label) shall have the same " + "shape. But received: the shape of Input(X) is " + "[%s], the shape of Input(Label) is [%s].", + input_dims, + label_dims)); + } + + out->set_dims(input_dims); + out->share_lod(input); +} + } // namespace phi diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index f23382be89..effa18c567 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -54,4 +54,8 @@ void HuberLossInferMeta(const MetaTensor& input_meta, MetaConfig config = MetaConfig()); void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); +void BCELossInferMeta(const MetaTensor& input, + const MetaTensor& label, + MetaTensor* out, + MetaConfig config = MetaConfig()); } // namespace phi diff --git a/paddle/phi/kernels/bce_loss_grad_kernel.h b/paddle/phi/kernels/bce_loss_grad_kernel.h new file mode 100644 index 0000000000..14bf52196a --- /dev/null +++ b/paddle/phi/kernels/bce_loss_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void BCELossGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + const DenseTensor& out_grad, + DenseTensor* input_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/bce_loss_kernel.h b/paddle/phi/kernels/bce_loss_kernel.h new file mode 100644 index 0000000000..6459ea9116 --- /dev/null +++ b/paddle/phi/kernels/bce_loss_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void BCELossKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/bce_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/bce_loss_grad_kernel.cc new file mode 100644 index 0000000000..6859451e8b --- /dev/null +++ b/paddle/phi/kernels/cpu/bce_loss_grad_kernel.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bce_loss_grad_kernel.h" + +#include // for max +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void BCELossGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + const DenseTensor& out_grad, + DenseTensor* input_grad) { + auto dx_data = dev_ctx.template Alloc(input_grad); + auto dout_data = out_grad.data(); + auto x_data = input.data(); + auto label_data = label.data(); + + int x_numel = input.numel(); + + // dx = dout * ((x - label)/(x - x^2)) + for (int i = 0; i < x_numel; ++i) { + dx_data[i] = + dout_data[i] * ((x_data[i] - label_data[i]) / + std::max((static_cast(1) - x_data[i]) * x_data[i], + static_cast(1e-12))); + } +} +} // namespace phi + +PD_REGISTER_KERNEL( + bce_loss_grad, CPU, ALL_LAYOUT, phi::BCELossGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/bce_loss_kernel.cc b/paddle/phi/kernels/cpu/bce_loss_kernel.cc new file mode 100644 index 0000000000..76b9793651 --- /dev/null +++ b/paddle/phi/kernels/cpu/bce_loss_kernel.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bce_loss_kernel.h" + +#include // for max +#include "paddle/fluid/operators/math.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void BCELossKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + DenseTensor* out) { + auto x_data = input.data(); + auto label_data = label.data(); + auto out_data = dev_ctx.template Alloc(out); + auto x_numel = input.numel(); + + // out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 - + // x) - label * ln(x) + for (int64_t i = 0; i < x_numel; ++i) { + PADDLE_ENFORCE_GE( + x_data[i], + static_cast(0), + phi::errors::InvalidArgument( + "Illegal input, input must be greater than or equal to 0")); + PADDLE_ENFORCE_LE( + x_data[i], + static_cast(1), + phi::errors::InvalidArgument( + "Illegal input, input must be less than or equal to 1")); + out_data[i] = + (label_data[i] - static_cast(1)) * + std::max(paddle::operators::real_log(static_cast(1) - x_data[i]), + (T)(-100)) - + label_data[i] * + std::max(paddle::operators::real_log(x_data[i]), (T)(-100)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + bce_loss, CPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu new file mode 100644 index 0000000000..94eabac4d1 --- /dev/null +++ b/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bce_loss_grad_kernel.h" + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" + +namespace phi { + +template +struct BCELossGradFunctor { + T one; + T eps; + + HOSTDEVICE inline BCELossGradFunctor() { + one = static_cast(1.0f); + eps = static_cast(1e-12); + } + + HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const { + T term1 = max((one - x) * x, eps); + return (dout * (x - label) / term1); + } +}; + +template +void BCELossGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + const DenseTensor& out_grad, + DenseTensor* input_grad) { + dev_ctx.template Alloc(input_grad); + std::vector ins = {&input, &label, &out_grad}; + std::vector outs = {input_grad}; + auto functor = BCELossGradFunctor(); + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + bce_loss_grad, GPU, ALL_LAYOUT, phi::BCELossGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/bce_loss_kernel.cu b/paddle/phi/kernels/gpu/bce_loss_kernel.cu new file mode 100644 index 0000000000..adbcd3b2b6 --- /dev/null +++ b/paddle/phi/kernels/gpu/bce_loss_kernel.cu @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/bce_loss_kernel.h" + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +namespace phi { + +template +struct BCELossFunctor { + T one; + T neg_100; + + HOSTDEVICE inline BCELossFunctor() { + one = static_cast(1.0f); + neg_100 = static_cast(-100.); + } + + HOSTDEVICE inline T operator()(const T x, const T label) const { + PADDLE_ENFORCE( + (x >= static_cast(0)) && (x <= one), + "Input is expected to be within the interval [0, 1], but recieved %f.", + x); + T term1 = max(phi::kps::details::Log(x), neg_100); + T term2 = max(phi::kps::details::Log(one - x), neg_100); + return (((label - one) * term2) - (label * term1)); + } +}; + +template +void BCELossKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + DenseTensor* out) { + dev_ctx.template Alloc(out); + std::vector ins = {&input, &label}; + std::vector outs = {out}; + auto functor = BCELossFunctor(); + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + bce_loss, GPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {} diff --git a/paddle/phi/ops/compat/bce_loss_sig.cc b/paddle/phi/ops/compat/bce_loss_sig.cc new file mode 100644 index 0000000000..17f76067d2 --- /dev/null +++ b/paddle/phi/ops/compat/bce_loss_sig.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature BCELossGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("bce_loss_grad", + {"X", "Label", GradVarName("Out")}, + {}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(bce_loss_grad, phi::BCELossGradOpArgumentMapping); -- GitLab