From 46ec5b3e5eb784e5bb3701ab0472d26f16ec9e6b Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Wed, 15 Sep 2021 14:37:14 +0800 Subject: [PATCH] upgrade dice_loss (#35734) --- python/paddle/fluid/layers/nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 534ebf231a..3796585f08 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7192,7 +7192,8 @@ def dice_loss(input, label, epsilon=0.00001, name=None): assert input.numel() > 0 and label.numel() > 0, \ "Any dimension of input and label cannot be equal to 0." - label = one_hot(label, depth=input.shape[-1]) + label = squeeze(label, [-1]) + label = paddle.nn.functional.one_hot(label, input.shape[-1]) reduce_dim = list(range(1, len(input.shape))) inse = reduce_sum(input * label, dim=reduce_dim) dice_denominator = reduce_sum( -- GitLab