Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ad6c3b92
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ad6c3b92
编写于
8月 16, 2021
作者:
S
shangliang Xu
提交者:
GitHub
8月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dev] fix dice_loss bug (#34757)
* fix dice_loss bug
上级
e84b2e9b
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
83 addition
and
5 deletion
+83
-5
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+20
-5
python/paddle/fluid/tests/unittests/test_nn_dice_loss.py
python/paddle/fluid/tests/unittests/test_nn_dice_loss.py
+63
-0
未找到文件。
python/paddle/fluid/layers/nn.py
浏览文件 @
ad6c3b92
...
...
@@ -7105,11 +7105,11 @@ def dice_loss(input, label, epsilon=0.00001, name=None):
Parameters:
input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_D]`, where :math:`N_1` is
the batch_size, :math:`
N_D` is 1. It is usually the output predictions of sigmoid activation.
The data type can be float32 or float64.
label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_
D
]`.
where :math:`N_1` is the batch_size
, :math:`N_D` is 1. The data type can be float32 or floa
t64.
input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_
k,
D]`, where :math:`N_1` is
the batch_size, :math:`
D` is the number of categories. It is usually the output
predictions of sigmoid activation.
The data type can be float32 or float64.
label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_
k, 1
]`.
where :math:`N_1` is the batch_size
. The data type can be int32 or in
t64.
epsilon (float): The epsilon will be added to the numerator and denominator.
If both input and label are empty, it makes sure dice is 1.
Default: 0.00001
...
...
@@ -7131,6 +7131,21 @@ def dice_loss(input, label, epsilon=0.00001, name=None):
predictions = F.softmax(x)
loss = F.dice_loss(input=predictions, label=label)
"""
assert input.dtype in (paddle.float32, paddle.float64)
assert label.dtype in (paddle.int32, paddle.int64)
assert len(input.shape) >= 2, \
"The rank of input should be greater than or equal to 2."
assert len(input.shape) == len(label.shape), (
"The rank of input and label should be equal, "
"but received input: %d, label: %d." %
(len(input.shape), len(label.shape)))
assert label.shape[-1] == 1, ("The last dimension of label should be 1, "
"but received %d." % label.shape[-1])
assert input.shape[:-1] == label.shape[:-1], (
"All dimensions should be equal except the last one.")
assert input.numel() > 0 and label.numel() > 0, \
"Any dimension of input and label cannot be equal to 0."
label = one_hot(label, depth=input.shape[-1])
reduce_dim = list(range(1, len(input.shape)))
inse = reduce_sum(input * label, dim=reduce_dim)
...
...
python/paddle/fluid/tests/unittests/test_nn_dice_loss.py
0 → 100644
浏览文件 @
ad6c3b92
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid.layers.nn
as
nn
num_classes
=
4
eps
=
1e-6
class
TestDiceLossValue
(
unittest
.
TestCase
):
def
test_dice_loss
(
self
):
input_
=
paddle
.
rand
([
2
,
3
,
num_classes
])
label_
=
paddle
.
randint
(
0
,
num_classes
,
[
2
,
3
,
1
],
dtype
=
paddle
.
int64
)
input_np
,
label_np
=
input_
.
numpy
(),
label_
.
numpy
()
eye_np
=
np
.
eye
(
num_classes
)
label_np
=
np
.
float32
(
eye_np
[
np
.
squeeze
(
label_np
)])
input_np
=
np
.
reshape
(
input_np
,
[
2
,
-
1
])
label_np
=
np
.
reshape
(
label_np
,
[
2
,
-
1
])
intersection_np
=
np
.
sum
(
input_np
*
label_np
,
axis
=-
1
)
union_np
=
input_np
.
sum
(
-
1
)
+
label_np
.
sum
(
-
1
)
dice_np
=
np
.
mean
(
1
-
2
*
intersection_np
/
(
union_np
+
eps
))
dice_paddle
=
nn
.
dice_loss
(
input_
,
label_
,
eps
)
self
.
assertTrue
(
np
.
isclose
(
dice_np
,
dice_paddle
.
numpy
()).
all
())
class
TestDiceLossInvalidInput
(
unittest
.
TestCase
):
def
test_error
(
self
):
def
test_invalid_dtype
():
input_
=
paddle
.
rand
([
2
,
3
,
num_classes
],
dtype
=
paddle
.
float32
)
label_
=
paddle
.
randint
(
0
,
num_classes
,
[
2
,
3
,
1
],
dtype
=
paddle
.
int64
)
nn
.
dice_loss
(
input_
,
label_
.
astype
(
paddle
.
float32
))
self
.
assertRaises
(
AssertionError
,
test_invalid_dtype
)
def
test_zero_shape_input
():
input_
=
paddle
.
rand
([
0
,
3
,
num_classes
],
dtype
=
paddle
.
float32
)
label_
=
paddle
.
randint
(
0
,
num_classes
,
[
0
,
3
,
1
],
dtype
=
paddle
.
int64
)
nn
.
dice_loss
(
input_
,
label_
)
self
.
assertRaises
(
AssertionError
,
test_zero_shape_input
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录