From ab05cdc46e20b1768d473e4abebc9c0654ab6671 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Sat, 11 Apr 2020 15:36:16 +0800 Subject: [PATCH] Add bce_loss op (#23388) * add bce_loss * fix mistake * replace paddle_enforce,test=develop * fix,test=develop * update,test=develop * remove duplicate,test=develop * update,test=develop * update error,test=develop * update,test=develop * fix unittest, test=develop * update, test=develop --- paddle/fluid/operators/bce_loss_op.cc | 180 ++++++++++++++++++ paddle/fluid/operators/bce_loss_op.cu | 133 +++++++++++++ paddle/fluid/operators/bce_loss_op.h | 85 +++++++++ .../fluid/tests/unittests/test_bce_loss.py | 135 +++++++++++++ python/paddle/nn/__init__.py | 2 +- python/paddle/nn/layer/loss.py | 111 ++++++++++- 6 files changed, 644 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/bce_loss_op.cc create mode 100644 paddle/fluid/operators/bce_loss_op.cu create mode 100644 paddle/fluid/operators/bce_loss_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_bce_loss.py diff --git a/paddle/fluid/operators/bce_loss_op.cc b/paddle/fluid/operators/bce_loss_op.cc new file mode 100644 index 00000000000..4cbcd1dd775 --- /dev/null +++ b/paddle/fluid/operators/bce_loss_op.cc @@ -0,0 +1,180 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/bce_loss_op.h" +#include +#include +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class BCELossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::InvalidArgument("Input(X) should be not null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Label"), true, + platform::errors::InvalidArgument("Input(Label) should be not null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::InvalidArgument("Output(Out) should be not null.")); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ( + x_dims.size(), label_dims.size(), + platform::errors::InvalidArgument( + "Input(X) and Input(Label) shall have the same shape.")); + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(label_dims); + bool check = ctx->IsRuntime() || !contain_unknown_dim; + if (check) { + PADDLE_ENFORCE_EQ( + x_dims.size(), label_dims.size(), + platform::errors::InvalidArgument( + "ShapeError: Input(X) and Input(Label) shall have the same shape " + "But received: the shape of Input(X) is [%s], the shape of " + "Input(Label) is [%s].", + x_dims, label_dims)); + } + + ctx->ShareDim("X", "Out"); + ctx->ShareLoD("X", "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class BCELossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::InvalidArgument("Input(X) should be not null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Label"), true, + platform::errors::InvalidArgument("Input(Label) should be not null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::InvalidArgument( + "Input(Out@GRAD) shoudl be not null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::InvalidArgument( + "Output(X@GRAD) should be not null.")); + + auto x_dims = ctx->GetInputDim("X"); + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(dout_dims); + bool check = ctx->IsRuntime() || !contain_unknown_dim; + if (check) { + PADDLE_ENFORCE_EQ(x_dims, dout_dims, + platform::errors::InvalidArgument( + "ShapeError:The Input(X) and Input(Out@Grad) " + "should have the same " + "shape, But received: the shape of Input(X) is " + "[%s], the shape of " + "Input(Out@GRAD) is [%s].", + x_dims, dout_dims)); + } + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class BCELossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), the input is a tensor of logits" + "computed by the previous operator, which is always the result of" + "a sigmoid operator. Input must between in 0 and 1."); + AddInput("Label", + "(Tensor, default Tensor), have same shape with input" + "label should between in 0 and 1."); + AddOutput("Out", + "(Tensor, default Tensor), have same shape with" + "input"); + AddComment(R"DOC( +BinaryCrossEntropy operator. + +This measures the element-wise probability error in classification tasks +in which each class is independent. + +The logitstic loss is given as follows: + $$loss = -Label * \log(X) - (1 - Label) * \log(1 - X)$$ +)DOC"); + } +}; + +template +class BCELossGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("bce_loss_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Label", this->Input("Label")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + // op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_INPLACE_OP_INFERER(BCELossInplaceInferer, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(BCELossGradInplaceInferer, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(bce_loss, ops::BCELossOp, ops::BCELossOpMaker, + ops::BCELossGradOpMaker, + ops::BCELossGradOpMaker, + ops::BCELossInplaceInferer); +REGISTER_OPERATOR(bce_loss_grad, ops::BCELossGradOp, + ops::BCELossGradInplaceInferer); +REGISTER_OP_CPU_KERNEL( + bce_loss, ops::BCELossOpKernel, + ops::BCELossOpKernel); +REGISTER_OP_CPU_KERNEL( + bce_loss_grad, + ops::BCELossGradOpKernel, + ops::BCELossGradOpKernel); diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu new file mode 100644 index 00000000000..179e194a9c5 --- /dev/null +++ b/paddle/fluid/operators/bce_loss_op.cu @@ -0,0 +1,133 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include "cub/cub.cuh" +#include "paddle/fluid/operators/bce_loss_op.h" +#include "paddle/fluid/operators/math.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_config.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void GPUBCELossForward(const T* x_data, const T* label_data, + T* out_data, const int in_numel) { + CUDA_1D_KERNEL_LOOP(i, in_numel) { + T x = x_data[i]; + T label = label_data[i]; + T one = static_cast(1.); + T neg_100 = static_cast(-100.); + + T term1 = max(real_log(x), neg_100); + T term2 = max(real_log(one - x), neg_100); + + out_data[i] = ((label - one) * term2) - (label * term1); + } +} + +template +__global__ void GPUBCELossBackward(const T* x_data, const T* label_data, + const T* dout_data, T* dx_data, + const int in_numel) { + CUDA_1D_KERNEL_LOOP(i, in_numel) { + T x = x_data[i]; + T label = label_data[i]; + T dout = dout_data[i]; + T one = static_cast(1.); + T eps = static_cast(1e-12); + + T term1 = max((one - x) * x, eps); + + dx_data[i] = dout * (x - label) / term1; + } +} + +template +class BCELossCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* out = ctx.Output("Out"); + + auto x_data = x->data(); + auto out_data = out->mutable_data(ctx.GetPlace()); + int x_numel = x->numel(); + platform::GpuLaunchConfig config = + platform::getGpuLaunchConfig(x_numel, ctx); + + Tensor x_cpu; + framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu); + T* x_cpu_data = x_cpu.data(); + + for (int i = 0; i < x_numel; ++i) { + PADDLE_ENFORCE_GE( + x_cpu_data[i], static_cast(0), + platform::errors::InvalidArgument( + "Illegal input, input must be greater than or equal to 0")); + PADDLE_ENFORCE_LE( + x_cpu_data[i], static_cast(1), + platform::errors::InvalidArgument( + "Illegal input, input must be less than or equal to 1")); + } + + auto& dev_ctx = ctx.cuda_device_context(); + + GPUBCELossForward< + T><<>>( + x_data, labels->data(), out_data, x_numel); + } +}; + +template +class BCELossGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto dx_data = dx->mutable_data(ctx.GetPlace()); + + int x_numel = x->numel(); + platform::GpuLaunchConfig config = + platform::getGpuLaunchConfig(x_numel, ctx); + auto& dev_ctx = ctx.cuda_device_context(); + + GPUBCELossBackward< + T><<>>( + x->data(), labels->data(), dout->data(), dx_data, x_numel); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + bce_loss, + ops::BCELossCUDAKernel, + ops::BCELossCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + bce_loss_grad, + ops::BCELossGradCUDAKernel, + ops::BCELossGradCUDAKernel); diff --git a/paddle/fluid/operators/bce_loss_op.h b/paddle/fluid/operators/bce_loss_op.h new file mode 100644 index 00000000000..85e120e4642 --- /dev/null +++ b/paddle/fluid/operators/bce_loss_op.h @@ -0,0 +1,85 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include // for max +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class BCELossOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* out = ctx.Output("Out"); + + auto x_data = x->data(); + auto label_data = labels->data(); + auto out_data = out->mutable_data(ctx.GetPlace()); + int x_numel = x->numel(); + + // out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 - + // x) - label * ln(x) + for (int i = 0; i < x_numel; ++i) { + PADDLE_ENFORCE_GE( + x_data[i], static_cast(0), + platform::errors::InvalidArgument( + "Illegal input, input must be greater than or equal to 0")); + PADDLE_ENFORCE_LE( + x_data[i], static_cast(1), + platform::errors::InvalidArgument( + "Illegal input, input must be less than or equal to 1")); + out_data[i] = + (label_data[i] - static_cast(1)) * + std::max(real_log(static_cast(1) - x_data[i]), (T)(-100)) - + label_data[i] * std::max(real_log(x_data[i]), (T)(-100)); + } + } +}; + +template +class BCELossGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* labels = ctx.Input("Label"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + + auto dx_data = dx->mutable_data(ctx.GetPlace()); + auto dout_data = dout->data(); + auto x_data = x->data(); + auto label_data = labels->data(); + + int x_numel = x->numel(); + + // dx = dout * ((x - label)/(x - x^2)) + for (int i = 0; i < x_numel; ++i) { + dx_data[i] = + dout_data[i] * ((x_data[i] - label_data[i]) / + std::max((static_cast(1) - x_data[i]) * x_data[i], + static_cast(1e-12))); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py new file mode 100644 index 00000000000..f3351e36a69 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -0,0 +1,135 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import unittest +from op_test import OpTest + + +class TestBCELoss(unittest.TestCase): + def test_BCELoss(self): + input_np = np.random.random(size=(20, 30)).astype(np.float64) + label_np = np.random.random(size=(20, 30)).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + reductions = ['sum', 'mean', 'none'] + for place in places: + for red in reductions: + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[None, 30], dtype='float64') + label = fluid.data( + name='label', shape=[None, 30], dtype='float64') + bce_loss = paddle.nn.loss.BCELoss(reduction=red) + res = bce_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + + with fluid.dygraph.guard(): + bce_loss = paddle.nn.loss.BCELoss(reduction=red) + dy_res = bce_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = -1 * (label_np * np.log(input_np) + + (1. - label_np) * np.log(1. - input_np)) + if red == 'mean': + expected = np.mean(expected) + elif red == 'sum': + expected = np.sum(expected) + else: + expected = expected + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + def test_BCELoss_weight(self): + input_np = np.random.random(size=(20, 30)).astype(np.float64) + label_np = np.random.random(size=(20, 30)).astype(np.float64) + weight_np = np.random.random(size=(20, 30)).astype(np.float64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[None, 30], dtype='float64') + label = fluid.data(name='label', shape=[None, 30], dtype='float64') + weight = fluid.data( + name='weight', shape=[None, 30], dtype='float64') + bce_loss = paddle.nn.loss.BCELoss(weight=weight) + res = bce_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run(prog, + feed={ + "input": input_np, + "label": label_np, + "weight": weight_np + }, + fetch_list=[res]) + + with fluid.dygraph.guard(): + bce_loss = paddle.nn.loss.BCELoss( + weight=fluid.dygraph.to_variable(weight_np)) + dy_res = bce_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_result = dy_res.numpy() + + expected = np.mean(-1 * weight_np * + (label_np * np.log(input_np) + + (1. - label_np) * np.log(1. - input_np))) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + + +def bce_loss(input, label): + return -1 * (label * np.log(input) + (1. - label) * np.log(1. - input)) + + +class TestBceLossOp(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "bce_loss" + input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64") + label_np = np.random.randint(0, 2, self.shape).astype("float64") + output_np = bce_loss(input_np, label_np) + + self.inputs = {'X': input_np, 'Label': label_np} + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.shape = [10, 10] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index e5cfd360780..0a98150ef5a 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -66,7 +66,7 @@ from .layer import loss #DEFINE_ALIAS from .layer import conv #DEFINE_ALIAS from .layer.conv import Conv2D, Conv2DTranspose, Conv3D, Conv3DTranspose #DEFINE_ALIAS # from .layer.loss import NLLLoss #DEFINE_ALIAS -# from .layer.loss import BCELoss #DEFINE_ALIAS +from .layer.loss import BCELoss #DEFINE_ALIAS # from .layer.learning_rate import CosineDecay #DEFINE_ALIAS # from .layer.learning_rate import ExponentialDecay #DEFINE_ALIAS # from .layer.learning_rate import InverseTimeDecay #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index db1ff750ecf..5b8e819521e 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -20,7 +20,7 @@ __all__ = [ # 'MSELoss', 'L1Loss', # 'NLLLoss', - # 'BCELoss' + 'BCELoss' ] @@ -109,3 +109,112 @@ class L1Loss(fluid.dygraph.Layer): return fluid.layers.reduce_mean(unreduced) else: return unreduced + + +class BCELoss(fluid.dygraph.Layer): + """ + This op accepts input predictions and target label and returns binary + cross entropy error. + For predictions label, and target label, the loss is calculated as follows. + If :attr:`weight` is set, the loss is: + Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input)) + If :attr:`weight` is None, the loss is: + Out = -1 * (label * log(input) + (1 - label) * log(1 - input)) + + If :attr:`reduction` set to ``'none'``, the unreduced loss is: + .. math:: + Out = Out + If :attr:`reduction` set to ``'mean'``, the reduced mean loss is: + .. math:: + Out = MEAN(Out) + If :attr:`reduction` set to ``'sum'``, the reduced sum loss is: + .. math:: + Out = SUM(Out) + Parameters: + input (Variable): Input tensor, the data type is float32, + float64. Input must in (0, 1). + label (Variable): Label tensor, has the same shape with input, + the data type is float32, float64. + weight (Variable, optional): Weight tensor, a manual rescaling weight given + to each class. It has the same dimensions as class number and the data type + is float32, float64, int32, int64. Default is ``'None'``. + reduction (str, optional): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + Default is ``'mean'``. + Returns: + The tensor variable storing the bce_loss of input and label. + Return type: Variable. + Examples: + .. code-block:: python + # declarative mode + import paddle.fluid as fluid + import numpy as np + import paddle + input = fluid.data(name="input", shape=[3, 1], dtype='float32') + label = fluid.data(name="label", shape=[3, 1], dtype='float32') + bce_loss = paddle.nn.loss.BCELoss() + output = bce_loss(input, label) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + input_data = np.array([0.5, 0.6, 0.7]).astype("float32") + label_data = np.array([1.0, 0.0, 1.0]).astype("float32") + output_data = exe.run(fluid.default_main_program(), + feed={"input":input_data, "label":label_data}, + fetch_list=[output], + return_numpy=True) + + print(output_data) # [array([0.65537095], dtype=float32)] + + # imperative mode + import paddle.fluid.dygraph as dg + with dg.guard(place) as g: + input = dg.to_variable(input_data) + label = dg.to_variable(label_data) + output = bce_loss(input, label) + print(output.numpy()) # [0.65537095] + """ + + def __init__(self, weight=None, reduction='mean'): + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % reduction) + + super(BCELoss, self).__init__() + self.weight = weight + self.reduction = reduction + + def forward(self, input, label): + dtype = self._helper.input_dtype(input) + + fluid.data_feeder.check_variable_and_dtype( + input, 'input', ['float32', 'float64'], 'bce_loss') + fluid.data_feeder.check_variable_and_dtype( + label, 'label', ['float32', 'float64'], 'bce_loss') + + out = self._helper.create_variable_for_type_inference(dtype=input.dtype) + self._helper.append_op( + type='bce_loss', + inputs={ + 'X': [input], + 'Label': [label], + }, + outputs={'Out': [out]}) + + if self.weight is not None: + if isinstance(self.weight, fluid.framework.Variable): + w = self.weight + out = fluid.layers.elementwise_mul(out, w, axis=0) + else: + raise ValueError( + "The weight is not a Variable, please convert to Variable.") + + if self.reduction == 'sum': + return fluid.layers.reduce_sum(out) + elif self.reduction == 'mean': + return fluid.layers.reduce_mean(out) + else: + return out -- GitLab