From ad6c3b9222b85c0aa69e6ee79b843ebf9b465a24 Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Mon, 16 Aug 2021 11:15:25 +0800 Subject: [PATCH] [dev] fix dice_loss bug (#34757) * fix dice_loss bug --- python/paddle/fluid/layers/nn.py | 25 ++++++-- .../tests/unittests/test_nn_dice_loss.py | 63 +++++++++++++++++++ 2 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_nn_dice_loss.py diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index dc1e56f13f3..656f1efe493 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7105,11 +7105,11 @@ def dice_loss(input, label, epsilon=0.00001, name=None): Parameters: - input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_D]`, where :math:`N_1` is - the batch_size, :math:`N_D` is 1. It is usually the output predictions of sigmoid activation. - The data type can be float32 or float64. - label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_D]`. - where :math:`N_1` is the batch_size, :math:`N_D` is 1. The data type can be float32 or float64. + input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_k, D]`, where :math:`N_1` is + the batch_size, :math:`D` is the number of categories. It is usually the output + predictions of sigmoid activation. The data type can be float32 or float64. + label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_k, 1]`. + where :math:`N_1` is the batch_size. The data type can be int32 or int64. epsilon (float): The epsilon will be added to the numerator and denominator. If both input and label are empty, it makes sure dice is 1. Default: 0.00001 @@ -7131,6 +7131,21 @@ def dice_loss(input, label, epsilon=0.00001, name=None): predictions = F.softmax(x) loss = F.dice_loss(input=predictions, label=label) """ + assert input.dtype in (paddle.float32, paddle.float64) + assert label.dtype in (paddle.int32, paddle.int64) + assert len(input.shape) >= 2, \ + "The rank of input should be greater than or equal to 2." + assert len(input.shape) == len(label.shape), ( + "The rank of input and label should be equal, " + "but received input: %d, label: %d." % + (len(input.shape), len(label.shape))) + assert label.shape[-1] == 1, ("The last dimension of label should be 1, " + "but received %d." % label.shape[-1]) + assert input.shape[:-1] == label.shape[:-1], ( + "All dimensions should be equal except the last one.") + assert input.numel() > 0 and label.numel() > 0, \ + "Any dimension of input and label cannot be equal to 0." + label = one_hot(label, depth=input.shape[-1]) reduce_dim = list(range(1, len(input.shape))) inse = reduce_sum(input * label, dim=reduce_dim) diff --git a/python/paddle/fluid/tests/unittests/test_nn_dice_loss.py b/python/paddle/fluid/tests/unittests/test_nn_dice_loss.py new file mode 100644 index 00000000000..31606376777 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nn_dice_loss.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.layers.nn as nn + +num_classes = 4 +eps = 1e-6 + + +class TestDiceLossValue(unittest.TestCase): + def test_dice_loss(self): + input_ = paddle.rand([2, 3, num_classes]) + label_ = paddle.randint(0, num_classes, [2, 3, 1], dtype=paddle.int64) + + input_np, label_np = input_.numpy(), label_.numpy() + eye_np = np.eye(num_classes) + label_np = np.float32(eye_np[np.squeeze(label_np)]) + input_np = np.reshape(input_np, [2, -1]) + label_np = np.reshape(label_np, [2, -1]) + intersection_np = np.sum(input_np * label_np, axis=-1) + union_np = input_np.sum(-1) + label_np.sum(-1) + dice_np = np.mean(1 - 2 * intersection_np / (union_np + eps)) + dice_paddle = nn.dice_loss(input_, label_, eps) + self.assertTrue(np.isclose(dice_np, dice_paddle.numpy()).all()) + + +class TestDiceLossInvalidInput(unittest.TestCase): + def test_error(self): + def test_invalid_dtype(): + input_ = paddle.rand([2, 3, num_classes], dtype=paddle.float32) + label_ = paddle.randint( + 0, num_classes, [2, 3, 1], dtype=paddle.int64) + nn.dice_loss(input_, label_.astype(paddle.float32)) + + self.assertRaises(AssertionError, test_invalid_dtype) + + def test_zero_shape_input(): + input_ = paddle.rand([0, 3, num_classes], dtype=paddle.float32) + label_ = paddle.randint( + 0, num_classes, [0, 3, 1], dtype=paddle.int64) + nn.dice_loss(input_, label_) + + self.assertRaises(AssertionError, test_zero_shape_input) + + +if __name__ == "__main__": + unittest.main() -- GitLab