Unverified commit 6fc5d88a, authored by Linjie Chen, committed by GitHub

[phi] move bce_loss to phi (#39868)

* move bce_loss to phi

* refine PADDLE_ENFORCE

* revert PADDLE_ENFORCE

* fix ci
Parent eb4ad509
...@@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bce_loss_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
...@@ -26,41 +29,6 @@ class BCELossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELoss");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELoss");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
platform::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank, labels_dims.size()));
bool check = true;
if ((!ctx->IsRuntime()) &&
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(x_dims, labels_dims,
platform::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same "
"shape. But received: the shape of Input(X) is "
"[%s], the shape of Input(Label) is [%s].",
x_dims, labels_dims));
}
ctx->ShareDim("X", "Out");
ctx->ShareLoD("X", "Out");
}
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
...@@ -170,16 +138,12 @@ DECLARE_INPLACE_OP_INFERER(BCELossGradInplaceInferer,
}  // namespace paddle
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(bce_loss, BCELossInferShapeFunctor,
PT_INFER_META(phi::BCELossInferMeta));
REGISTER_OPERATOR(bce_loss, ops::BCELossOp, ops::BCELossOpMaker,
                  ops::BCELossGradOpMaker<paddle::framework::OpDesc>,
                  ops::BCELossGradOpMaker<paddle::imperative::OpBase>,
                  ops::BCELossInplaceInferer, BCELossInferShapeFunctor);
REGISTER_OPERATOR(bce_loss_grad, ops::BCELossGradOp,
                  ops::BCELossGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
bce_loss, ops::BCELossOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::BCELossOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
bce_loss_grad,
ops::BCELossGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::BCELossGradOpKernel<paddle::platform::CPUDeviceContext, double>);
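For reference, the loss that the deleted fluid kernels above and the new phi kernels below both compute is the standard binary cross entropy, with each logarithm clamped below at -100 so that predictions at the endpoints of [0, 1] do not produce -inf:

$$\mathrm{out}_i = -\bigl[y_i \max(\ln x_i,\,-100) + (1 - y_i)\max(\ln(1 - x_i),\,-100)\bigr]$$

This is the `(label - 1) * ln(1 - x) - label * ln(x)` form written out in the kernel comments, with x the prediction and y the label.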
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/bce_loss_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/hostdevice.h"
namespace paddle {
namespace operators {
template <typename T>
struct BCELossFunctor {
T one;
T neg_100;
HOSTDEVICE inline BCELossFunctor() {
one = static_cast<T>(1.0f);
neg_100 = static_cast<T>(-100.);
}
HOSTDEVICE inline T operator()(const T x, const T label) const {
PADDLE_ENFORCE(
(x >= static_cast<T>(0)) && (x <= one),
"Input is expected to be within the interval [0, 1], but recieved %f.",
x);
T term1 = max(real_log(x), neg_100);
T term2 = max(real_log(one - x), neg_100);
return (((label - one) * term2) - (label * term1));
}
};
template <typename T>
struct BCELossGradFunctor {
T one;
T eps;
HOSTDEVICE inline BCELossGradFunctor() {
one = static_cast<T>(1.0f);
eps = static_cast<T>(1e-12);
}
HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const {
T term1 = max((one - x) * x, eps);
return (dout * (x - label) / term1);
}
};
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class BCELossCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {x, labels};
std::vector<framework::Tensor*> outs = {out};
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto functor = BCELossFunctor<T>();
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
}
};
template <typename DeviceContext, typename T>
class BCELossGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
std::vector<const framework::Tensor*> ins = {x, labels, dout};
std::vector<framework::Tensor*> outs = {dx};
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto functor = BCELossGradFunctor<T>();
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
bce_loss,
ops::BCELossCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::BCELossCUDAKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
bce_loss_grad,
ops::BCELossGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::BCELossGradCUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm> // for max
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class BCELossOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* out = ctx.Output<Tensor>("Out");
auto x_data = x->data<T>();
auto label_data = labels->data<T>();
auto out_data = out->mutable_data<T>(ctx.GetPlace());
auto x_numel = x->numel();
// out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 -
// x) - label * ln(x)
for (int64_t i = 0; i < x_numel; ++i) {
PADDLE_ENFORCE_GE(
x_data[i], static_cast<T>(0),
platform::errors::InvalidArgument(
"Illegal input, input must be greater than or equal to 0"));
PADDLE_ENFORCE_LE(
x_data[i], static_cast<T>(1),
platform::errors::InvalidArgument(
"Illegal input, input must be less than or equal to 1"));
out_data[i] =
(label_data[i] - static_cast<T>(1)) *
std::max(real_log(static_cast<T>(1) - x_data[i]), (T)(-100)) -
label_data[i] * std::max(real_log(x_data[i]), (T)(-100));
}
}
};
template <typename DeviceContext, typename T>
class BCELossGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto dout_data = dout->data<T>();
auto x_data = x->data<T>();
auto label_data = labels->data<T>();
int x_numel = x->numel();
// dx = dout * ((x - label)/(x - x^2))
for (int i = 0; i < x_numel; ++i) {
dx_data[i] =
dout_data[i] * ((x_data[i] - label_data[i]) /
std::max((static_cast<T>(1) - x_data[i]) * x_data[i],
static_cast<T>(1e-12)));
}
}
};
} // namespace operators
} // namespace paddle
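Differentiating the unclamped loss gives the backward formula that every grad kernel in this commit implements; the denominator is guarded with a max against 1e-12 to avoid dividing by zero at x = 0 or x = 1:

$$\frac{\partial\,\mathrm{out}_i}{\partial x_i} = \frac{x_i - y_i}{x_i(1 - x_i)}, \qquad dx_i = dout_i \cdot \frac{x_i - y_i}{\max\bigl(x_i(1 - x_i),\; 10^{-12}\bigr)}$$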
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
...@@ -230,4 +230,42 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
  out->set_dims(in_dims);
}
void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label,
MetaTensor* out,
MetaConfig config) {
auto input_dims = input.dims();
auto label_dims = label.dims();
int rank = input_dims.size();
PADDLE_ENFORCE_EQ(rank,
label_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank,
label_dims.size()));
bool check = true;
if ((!config.is_runtime) &&
(phi::product(input_dims) <= 0 || phi::product(label_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(input_dims,
label_dims,
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same "
"shape. But received: the shape of Input(X) is "
"[%s], the shape of Input(Label) is [%s].",
input_dims,
label_dims));
}
out->set_dims(input_dims);
out->share_lod(input);
}
}  // namespace phi
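Note that BCELossInferMeta reproduces the InferShape method deleted from BCELossOp above: ctx->IsRuntime() becomes config.is_runtime, and ShareDim/ShareLoD become set_dims/share_lod on the output MetaTensor, so compile-time shape checking behaves as before.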
...@@ -54,4 +54,8 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
                        MetaConfig config = MetaConfig());
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label,
MetaTensor* out,
MetaConfig config = MetaConfig());
}  // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BCELossGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& out_grad,
DenseTensor* input_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BCELossKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/bce_loss_grad_kernel.h"
#include <algorithm> // for max
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void BCELossGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& out_grad,
DenseTensor* input_grad) {
auto dx_data = dev_ctx.template Alloc<T>(input_grad);
auto dout_data = out_grad.data<T>();
auto x_data = input.data<T>();
auto label_data = label.data<T>();
int x_numel = input.numel();
// dx = dout * ((x - label)/(x - x^2))
for (int i = 0; i < x_numel; ++i) {
dx_data[i] =
dout_data[i] * ((x_data[i] - label_data[i]) /
std::max((static_cast<T>(1) - x_data[i]) * x_data[i],
static_cast<T>(1e-12)));
}
}
} // namespace phi
PD_REGISTER_KERNEL(
bce_loss_grad, CPU, ALL_LAYOUT, phi::BCELossGradKernel, float, double) {}
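A minimal standalone sketch of the backward formula, handy for spot-checking the kernel above (plain C++, independent of the phi runtime; `bce_loss_grad_ref` is a hypothetical helper name, not part of this commit):

```cpp
#include <algorithm>
#include <cstdio>

// Reference for the per-element math in BCELossGradKernel:
// dx = dout * (x - label) / max((1 - x) * x, 1e-12)
float bce_loss_grad_ref(float x, float label, float dout) {
  float denom = std::max((1.0f - x) * x, 1e-12f);
  return dout * (x - label) / denom;
}

int main() {
  // x = 0.9, label = 1: the gradient is negative, pushing x toward 1.
  std::printf("%f\n", bce_loss_grad_ref(0.9f, 1.0f, 1.0f));  // ~ -1.1111
  return 0;
}
```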
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/bce_loss_kernel.h"
#include <algorithm> // for max
#include "paddle/fluid/operators/math.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void BCELossKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
DenseTensor* out) {
auto x_data = input.data<T>();
auto label_data = label.data<T>();
auto out_data = dev_ctx.template Alloc<T>(out);
auto x_numel = input.numel();
// out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 -
// x) - label * ln(x)
for (int64_t i = 0; i < x_numel; ++i) {
PADDLE_ENFORCE_GE(
x_data[i],
static_cast<T>(0),
phi::errors::InvalidArgument(
"Illegal input, input must be greater than or equal to 0"));
PADDLE_ENFORCE_LE(
x_data[i],
static_cast<T>(1),
phi::errors::InvalidArgument(
"Illegal input, input must be less than or equal to 1"));
out_data[i] =
(label_data[i] - static_cast<T>(1)) *
std::max(paddle::operators::real_log(static_cast<T>(1) - x_data[i]),
(T)(-100)) -
label_data[i] *
std::max(paddle::operators::real_log(x_data[i]), (T)(-100));
}
}
} // namespace phi
PD_REGISTER_KERNEL(
bce_loss, CPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {}
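As a standalone sanity check of the forward math (plain C++, outside the phi runtime; `bce_loss_ref` is a hypothetical helper name; the -100 clamp mirrors the kernel above):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Reference for the per-element math in BCELossKernel:
// out = (label - 1) * max(log(1 - x), -100) - label * max(log(x), -100)
float bce_loss_ref(float x, float label) {
  float term1 = std::max(std::log(x), -100.0f);
  float term2 = std::max(std::log(1.0f - x), -100.0f);
  return (label - 1.0f) * term2 - label * term1;
}

int main() {
  std::printf("%f\n", bce_loss_ref(0.9f, 1.0f));  // label 1: small loss, ~0.1054
  std::printf("%f\n", bce_loss_ref(0.9f, 0.0f));  // label 0: large loss, ~2.3026
  return 0;
}
```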
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/bce_loss_grad_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace phi {
template <typename T>
struct BCELossGradFunctor {
T one;
T eps;
HOSTDEVICE inline BCELossGradFunctor() {
one = static_cast<T>(1.0f);
eps = static_cast<T>(1e-12);
}
HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const {
T term1 = max((one - x) * x, eps);
return (dout * (x - label) / term1);
}
};
template <typename T, typename Context>
void BCELossGradKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
const DenseTensor& out_grad,
DenseTensor* input_grad) {
dev_ctx.template Alloc<T>(input_grad);
std::vector<const DenseTensor*> ins = {&input, &label, &out_grad};
std::vector<DenseTensor*> outs = {input_grad};
auto functor = BCELossGradFunctor<T>();
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(
bce_loss_grad, GPU, ALL_LAYOUT, phi::BCELossGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/bce_loss_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
namespace phi {
template <typename T>
struct BCELossFunctor {
T one;
T neg_100;
HOSTDEVICE inline BCELossFunctor() {
one = static_cast<T>(1.0f);
neg_100 = static_cast<T>(-100.);
}
HOSTDEVICE inline T operator()(const T x, const T label) const {
PADDLE_ENFORCE(
(x >= static_cast<T>(0)) && (x <= one),
"Input is expected to be within the interval [0, 1], but recieved %f.",
x);
T term1 = max(phi::kps::details::Log(x), neg_100);
T term2 = max(phi::kps::details::Log(one - x), neg_100);
return (((label - one) * term2) - (label * term1));
}
};
template <typename T, typename Context>
void BCELossKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& label,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&input, &label};
std::vector<DenseTensor*> outs = {out};
auto functor = BCELossFunctor<T>();
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(
bce_loss, GPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {}
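Both GPU kernels follow the same pattern as the deleted .cu file: allocate the output, gather inputs and outputs into vectors, and hand a stateless per-element functor to phi::funcs::ElementwiseKernel, which takes over the role of the fluid-side LaunchSameDimsElementwiseCudaKernel.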
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature BCELossGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("bce_loss_grad",
{"X", "Label", GradVarName("Out")},
{},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(bce_loss_grad, phi::BCELossGradOpArgumentMapping);
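This argument mapping routes the fluid operator's X, Label, and Out@GRAD inputs to the phi bce_loss_grad kernel declared in bce_loss_grad_kernel.h, writing the result to X@GRAD. No forward mapping appears in this commit, presumably because {X, Label} -> Out already lines up with phi::BCELossKernel's signature.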