未验证 提交 46ec5b3e 编写于 作者: S shangliang Xu 提交者: GitHub

upgrade dice_loss (#35734)

上级 a74d7fb6
......@@ -7192,7 +7192,8 @@ def dice_loss(input, label, epsilon=0.00001, name=None):
assert input.numel() > 0 and label.numel() > 0, \
"Any dimension of input and label cannot be equal to 0."
label = one_hot(label, depth=input.shape[-1])
label = squeeze(label, [-1])
label = paddle.nn.functional.one_hot(label, input.shape[-1])
reduce_dim = list(range(1, len(input.shape)))
inse = reduce_sum(input * label, dim=reduce_dim)
dice_denominator = reduce_sum(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册