From bf4a4636f16f7ed2e870485533da9522e88117f9 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Mon, 24 Aug 2020 11:32:47 +0800 Subject: [PATCH] change to use bce_loss op, add shape check for bce_loss change to use bce_loss op, add numel check for bce_loss. --- paddle/fluid/operators/bce_loss_op.cc | 64 +++++++++++------- paddle/fluid/operators/bce_loss_op.cu | 5 +- paddle/fluid/operators/bce_loss_op.h | 4 +- .../fluid/tests/unittests/test_bce_loss.py | 14 ---- python/paddle/nn/functional/loss.py | 67 ++++++++----------- 5 files changed, 73 insertions(+), 81 deletions(-) diff --git a/paddle/fluid/operators/bce_loss_op.cc b/paddle/fluid/operators/bce_loss_op.cc index 50797a100b..f56789b889 100644 --- a/paddle/fluid/operators/bce_loss_op.cc +++ b/paddle/fluid/operators/bce_loss_op.cc @@ -32,22 +32,29 @@ class BCELossOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss"); auto x_dims = ctx->GetInputDim("X"); - auto label_dims = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ( - x_dims.size(), label_dims.size(), - platform::errors::InvalidArgument( - "Input(X) and Input(Label) shall have the same shape.")); - bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || - framework::contain_unknown_dim(label_dims); - bool check = ctx->IsRuntime() || !contain_unknown_dim; + auto labels_dims = ctx->GetInputDim("Label"); + + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, labels_dims.size(), + platform::errors::InvalidArgument( + "Input(X) and Input(Label) shall have the same rank." + "But received: the rank of Input(X) is [%d], " + "the rank of Input(Label) is [%d].", + rank, labels_dims.size())); + + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(labels_dims) <= 0)) { + check = false; + } + if (check) { - PADDLE_ENFORCE_EQ( - x_dims.size(), label_dims.size(), - platform::errors::InvalidArgument( - "ShapeError: Input(X) and Input(Label) shall have the same shape " - "But received: the shape of Input(X) is [%s], the shape of " - "Input(Label) is [%s].", - x_dims, label_dims)); + PADDLE_ENFORCE_EQ(x_dims, labels_dims, + platform::errors::InvalidArgument( + "Input(X) and Input(Label) shall have the same " + "shape. But received: the shape of Input(X) is " + "[%s], the shape of Input(Label) is [%s].", + x_dims, labels_dims)); } ctx->ShareDim("X", "Out"); @@ -76,20 +83,31 @@ class BCELossGradOp : public framework::OperatorWithKernel { framework::GradVarName("X"), "BCELossGrad"); auto x_dims = ctx->GetInputDim("X"); + auto labels_dims = ctx->GetInputDim("Label"); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); - bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || - framework::contain_unknown_dim(dout_dims); - bool check = ctx->IsRuntime() || !contain_unknown_dim; + + bool check = true; + if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 || + framework::product(labels_dims) <= 0)) { + check = false; + } + if (check) { + PADDLE_ENFORCE_EQ(x_dims, labels_dims, + platform::errors::InvalidArgument( + "Input(X) and Input(Label) shall have the same " + "shape. But received: the shape of Input(X) is " + "[%s], the shape of Input(Label) is [%s].", + x_dims, labels_dims)); + PADDLE_ENFORCE_EQ(x_dims, dout_dims, platform::errors::InvalidArgument( - "ShapeError:The Input(X) and Input(Out@Grad) " - "should have the same " - "shape, But received: the shape of Input(X) is " - "[%s], the shape of " - "Input(Out@GRAD) is [%s].", + "Input(X) and Input(Out@Grad) shall have the same " + "shape. But received: the shape of Input(X) is " + "[%s], the shape of Input(Out@Grad) is [%s].", x_dims, dout_dims)); } + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->ShareLoD("X", framework::GradVarName("X")); } diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu index 8e30f4eb15..16db4f05e3 100644 --- a/paddle/fluid/operators/bce_loss_op.cu +++ b/paddle/fluid/operators/bce_loss_op.cu @@ -67,7 +67,8 @@ class BCELossCUDAKernel : public framework::OpKernel { auto x_data = x->data(); auto out_data = out->mutable_data(ctx.GetPlace()); - int x_numel = x->numel(); + auto x_numel = x->numel(); + platform::GpuLaunchConfig config = platform::getGpuLaunchConfig(x_numel, ctx); @@ -75,7 +76,7 @@ class BCELossCUDAKernel : public framework::OpKernel { framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu); T* x_cpu_data = x_cpu.data(); - for (int i = 0; i < x_numel; ++i) { + for (int64_t i = 0; i < x_numel; ++i) { PADDLE_ENFORCE_GE( x_cpu_data[i], static_cast(0), platform::errors::InvalidArgument( diff --git a/paddle/fluid/operators/bce_loss_op.h b/paddle/fluid/operators/bce_loss_op.h index 85e120e464..dd87b69efe 100644 --- a/paddle/fluid/operators/bce_loss_op.h +++ b/paddle/fluid/operators/bce_loss_op.h @@ -34,11 +34,11 @@ class BCELossOpKernel : public framework::OpKernel { auto x_data = x->data(); auto label_data = labels->data(); auto out_data = out->mutable_data(ctx.GetPlace()); - int x_numel = x->numel(); + auto x_numel = x->numel(); // out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 - // x) - label * ln(x) - for (int i = 0; i < x_numel; ++i) { + for (int64_t i = 0; i < x_numel; ++i) { PADDLE_ENFORCE_GE( x_data[i], static_cast(0), platform::errors::InvalidArgument( diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py index b04aa49c60..a8054295b4 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -189,20 +189,6 @@ class TestBCELoss(unittest.TestCase): self.assertTrue(np.allclose(static_functional, dy_functional)) self.assertTrue(np.allclose(dy_functional, expected)) - def test_BCELoss_boardcast(self): - input_np = np.random.uniform( - 0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64) - label_np = np.random.randint(0, 2, size=(3, 4, 10)).astype(np.float64) - place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( - ) else fluid.CPUPlace() - - static_result = test_static_layer(place, input_np, label_np) - dy_result = test_dygraph_layer(place, input_np, label_np) - expected = calc_bceloss(input_np, label_np) - self.assertTrue(np.allclose(static_result, expected)) - self.assertTrue(np.allclose(static_result, dy_result)) - self.assertTrue(np.allclose(dy_result, expected)) - def test_BCELoss_error(self): paddle.disable_static() self.assertRaises( diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 9069630d86..55bb36d136 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -157,19 +157,7 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean', reduction) if in_dygraph_mode(): - one = _varbase_creator(dtype=input.dtype) - core.ops.fill_constant(one, 'value', - float(1.0), 'force_cpu', False, 'dtype', - one.dtype, 'str_value', '1.0', 'shape', [1]) - one.stop_gradient = True - label_minus = core.ops.elementwise_sub(label, one) - input_minus = core.ops.elementwise_sub(one, input) - input_minus_log = core.ops.log(input_minus) - input_log = core.ops.log(input) - loss_1 = core.ops.elementwise_mul(label_minus, input_minus_log) - loss_2 = core.ops.elementwise_mul(label, input_log) - out = core.ops.elementwise_sub(loss_1, loss_2) - + out = core.ops.bce_loss(input, label) if weight is not None: out = core.ops.elementwise_mul(out, weight, 'axis', -1) @@ -187,17 +175,16 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean', fluid.data_feeder.check_variable_and_dtype( label, 'label', ['float32', 'float64'], 'binary_cross_entropy') - one = paddle.fill_constant(shape=[1], value=1.0, dtype=input.dtype) - one.stop_gradient = True - label_minus = paddle.elementwise_sub(label, one) - input_minus = paddle.elementwise_sub(one, input) - input_minus_log = paddle.log(input_minus) - input_log = paddle.log(input) - loss_1 = paddle.multiply(label_minus, input_minus_log) - loss_2 = paddle.multiply(label, input_log) - sub_name = name if weight is None and reduction is 'none' else None - out = paddle.elementwise_sub(loss_1, loss_2, name=sub_name) + helper = LayerHelper("binary_cross_entropy", name=sub_name) + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='bce_loss', + inputs={ + 'X': [input], + 'Label': [label], + }, + outputs={'Out': [out]}) if weight is not None: if isinstance(weight, paddle.framework.Variable): @@ -952,9 +939,9 @@ def ctc_loss(log_probs, reduction='mean'): """ - An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc) - to compute Connectionist Temporal Classification (CTC) loss. - It can be aliased as softmax with CTC, since a native softmax activation + An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc) + to compute Connectionist Temporal Classification (CTC) loss. + It can be aliased as softmax with CTC, since a native softmax activation is interated to the Warp-CTC library to normalize values for each row of the input tensor. Parameters: @@ -967,7 +954,7 @@ def ctc_loss(log_probs, Returns: Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``. - + Examples: .. code-block:: python @@ -1012,18 +999,18 @@ def ctc_loss(log_probs, input_lengths = paddle.to_tensor(input_lengths) label_lengths = paddle.to_tensor(label_lengths) - loss = F.ctc_loss(log_probs, labels, - input_lengths, - label_lengths, - blank=0, + loss = F.ctc_loss(log_probs, labels, + input_lengths, + label_lengths, + blank=0, reduction='none') print(loss.numpy()) #[3.9179852 2.9076521] - loss = F.ctc_loss(log_probs, labels, - input_lengths, - label_lengths, - blank=0, - reduction='mean') + loss = F.ctc_loss(log_probs, labels, + input_lengths, + label_lengths, + blank=0, + reduction='mean') print(loss.numpy()) #[1.1376063] """ @@ -1071,8 +1058,8 @@ def cross_entropy(input, Parameters: input (Tensor): Input tensor, the data type is float32, float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this - is (N, C, D1, D2,..., Dk), k >= 1. - label (Tensor): Label tensor, the data type is int64. Shape is (N), where each + is (N, C, D1, D2,..., Dk), k >= 1. + label (Tensor): Label tensor, the data type is int64. Shape is (N), where each value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is (N, D1, D2,..., Dk), k >= 1. weight (Tensor, optional): Weight tensor, a manual rescaling weight given @@ -1105,7 +1092,7 @@ def cross_entropy(input, weight = paddle.to_tensor(weight_data) loss = paddle.nn.functional.cross_entropy(input=input, label=label, weight=weight) print(loss.numpy()) - + """ if not in_dygraph_mode(): fluid.data_feeder.check_variable_and_dtype( @@ -1124,7 +1111,7 @@ def cross_entropy(input, raise ValueError( "The weight' is not a Variable, please convert to Variable.") - #step 2. nll_loss + #step 2. nll_loss input = log_softmax_out helper = LayerHelper('nll_loss', **locals()) dtype = helper.input_dtype(input) -- GitLab