Unverified · Commit ad6c3b92 authored by shangliang Xu, committed by GitHub

[dev] fix dice_loss bug (#34757)

* fix dice_loss bug
Parent e84b2e9b
...
@@ -7105,11 +7105,11 @@ def dice_loss(input, label, epsilon=0.00001, name=None):
     Parameters:
-        input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_D]`, where :math:`N_1` is
-                          the batch_size, :math:`N_D` is 1. It is usually the output predictions of sigmoid activation.
-                          The data type can be float32 or float64.
-        label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_D]`.
-                          where :math:`N_1` is the batch_size, :math:`N_D` is 1. The data type can be float32 or float64.
+        input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_k, D]`, where :math:`N_1` is
+                          the batch_size, :math:`D` is the number of categories. It is usually the output
+                          predictions of sigmoid activation. The data type can be float32 or float64.
+        label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_k, 1]`.
+                          where :math:`N_1` is the batch_size. The data type can be int32 or int64.
         epsilon (float): The epsilon will be added to the numerator and denominator.
                          If both input and label are empty, it makes sure dice is 1.
                          Default: 0.00001
@@ -7131,6 +7131,21 @@ def dice_loss(input, label, epsilon=0.00001, name=None):
             predictions = F.softmax(x)
             loss = F.dice_loss(input=predictions, label=label)
     """
+    assert input.dtype in (paddle.float32, paddle.float64)
+    assert label.dtype in (paddle.int32, paddle.int64)
+    assert len(input.shape) >= 2, \
+        "The rank of input should be greater than or equal to 2."
+    assert len(input.shape) == len(label.shape), (
+        "The rank of input and label should be equal, "
+        "but received input: %d, label: %d." %
+        (len(input.shape), len(label.shape)))
+    assert label.shape[-1] == 1, ("The last dimension of label should be 1, "
+                                  "but received %d." % label.shape[-1])
+    assert input.shape[:-1] == label.shape[:-1], (
+        "All dimensions should be equal except the last one.")
+    assert input.numel() > 0 and label.numel() > 0, \
+        "Any dimension of input and label cannot be equal to 0."
     label = one_hot(label, depth=input.shape[-1])
     reduce_dim = list(range(1, len(input.shape)))
     inse = reduce_sum(input * label, dim=reduce_dim)
...
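To make the corrected contract concrete, here is a minimal usage sketch. It assumes a Paddle 2.x dynamic-graph environment; the sizes and variable names (x, predictions, label) are illustrative only, and F.dice_loss follows the docstring example above (the same function is also reachable as paddle.fluid.layers.nn.dice_loss, which the new test imports). input carries per-category probabilities of shape [N_1, ..., N_k, D] in float32/float64, while label carries integer class indices of shape [N_1, ..., N_k, 1] in int32/int64.

import paddle
import paddle.nn.functional as F

num_classes = 4                                       # illustrative category count
x = paddle.rand([2, 3, num_classes])                  # float32 scores, shape [2, 3, 4]
predictions = F.softmax(x)                            # per-category probabilities
label = paddle.randint(0, num_classes, [2, 3, 1],     # integer class indices, last dim is 1
                       dtype='int64')

loss = F.dice_loss(input=predictions, label=label)    # scalar loss tensor

# With the new checks, a float label or an empty tensor fails fast with an
# AssertionError instead of producing an undefined result, e.g.:
#   F.dice_loss(predictions, label.astype('float32'))   # AssertionError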
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.layers.nn as nn

num_classes = 4
eps = 1e-6


class TestDiceLossValue(unittest.TestCase):
def test_dice_loss(self):
input_ = paddle.rand([2, 3, num_classes])
label_ = paddle.randint(0, num_classes, [2, 3, 1], dtype=paddle.int64)
input_np, label_np = input_.numpy(), label_.numpy()
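        # Reference value: one-hot encode the labels and redo the dice computation in numpy.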
eye_np = np.eye(num_classes)
label_np = np.float32(eye_np[np.squeeze(label_np)])
input_np = np.reshape(input_np, [2, -1])
label_np = np.reshape(label_np, [2, -1])
intersection_np = np.sum(input_np * label_np, axis=-1)
union_np = input_np.sum(-1) + label_np.sum(-1)
dice_np = np.mean(1 - 2 * intersection_np / (union_np + eps))
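        # The layer result should match the numpy reference within float tolerance.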
dice_paddle = nn.dice_loss(input_, label_, eps)
        self.assertTrue(np.isclose(dice_np, dice_paddle.numpy()).all())


class TestDiceLossInvalidInput(unittest.TestCase):
def test_error(self):
def test_invalid_dtype():
input_ = paddle.rand([2, 3, num_classes], dtype=paddle.float32)
label_ = paddle.randint(
0, num_classes, [2, 3, 1], dtype=paddle.int64)
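            # Casting the label to float should trip the new dtype assertion.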
nn.dice_loss(input_, label_.astype(paddle.float32))
        self.assertRaises(AssertionError, test_invalid_dtype)

        def test_zero_shape_input():
input_ = paddle.rand([0, 3, num_classes], dtype=paddle.float32)
label_ = paddle.randint(
0, num_classes, [0, 3, 1], dtype=paddle.int64)
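            # Zero-sized input/label should be rejected by the numel assertion.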
nn.dice_loss(input_, label_)
        self.assertRaises(AssertionError, test_zero_shape_input)


if __name__ == "__main__":
unittest.main()