Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3423f0b6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3423f0b6
编写于
11月 26, 2019
作者:
W
WangXi
提交者:
gongweibao
11月 26, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix INF bug of softmax_cross_entropy_op, test=release/1.6 (#21283)
上级
9a98d11e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
62 addition
and
28 deletion
+62
-28
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+7
-4
python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
...e/fluid/tests/unittests/test_fused_multihead_matmul_op.py
+3
-1
python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
...on/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
+8
-20
python/paddle/fluid/tests/unittests/test_softmax_op.py
python/paddle/fluid/tests/unittests/test_softmax_op.py
+3
-1
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
...uid/tests/unittests/test_softmax_with_cross_entropy_op.py
+41
-2
未找到文件。
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
3423f0b6
...
...
@@ -150,10 +150,7 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
cur_max
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
cur_max
,
cub
::
Max
());
if
(
threadIdx
.
x
==
0
)
{
max_data
[
blockIdx
.
x
]
=
cur_max
<
static_cast
<
T
>
(
-
64
)
?
static_cast
<
T
>
(
-
64
)
:
cur_max
;
}
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
cur_max
;
}
// Make sure that BlockDim <= axis_dim
...
...
@@ -175,6 +172,12 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
auto
block_max
=
max_data
[
blockIdx
.
x
];
int
step
=
BlockDim
*
remain
;
// In numeric stable mode softmax_with_loss, we calc loss with
// tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of
// log(exp(x_i_j - max_i)/DiffMaxSum_i). Therefore, log(0) will not occur.
// Also we calc softmax_i_j = e^{tmp_i_j}, the maximum and minimum value will
// be 1.0 and 0.0, represent prob is 1.0 and 0.0.
// So there is no need to clip on shift_softmax.
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
T
diff_max_sum
=
exp_on_device
(
softmax
[
beg_idx
]);
auto
idx
=
beg_idx
+
step
;
...
...
python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
浏览文件 @
3423f0b6
...
...
@@ -25,7 +25,9 @@ np.random.random(123)
def
stable_softmax
(
x
):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx
=
x
-
np
.
max
(
x
).
clip
(
-
64.
)
# clip to shiftx, otherwise, when calc loss with
# log(exp(shiftx)), may get log(0)=INF
shiftx
=
(
x
-
np
.
max
(
x
)).
clip
(
-
64.
)
exps
=
np
.
exp
(
shiftx
)
return
exps
/
np
.
sum
(
exps
)
...
...
python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py
浏览文件 @
3423f0b6
...
...
@@ -19,6 +19,14 @@ import copy
from
op_test
import
OpTest
def
softmax
(
x
):
# clip to shiftx, otherwise, when calc loss with
# log(exp(shiftx)), may get log(0)=INF
shiftx
=
(
x
-
np
.
max
(
x
)).
clip
(
-
64.
)
exps
=
np
.
exp
(
shiftx
)
return
exps
/
np
.
sum
(
exps
)
def
iou
(
box_a
,
box_b
,
norm
):
"""Apply intersection-over-union overlap between box_a and box_b
"""
...
...
@@ -254,11 +262,6 @@ class TestMulticlassNMSOp(OpTest):
scores
=
np
.
random
.
random
((
N
*
M
,
C
)).
astype
(
'float32'
)
def
softmax
(
x
):
shiftx
=
x
-
np
.
max
(
x
).
clip
(
-
64.
)
exps
=
np
.
exp
(
shiftx
)
return
exps
/
np
.
sum
(
exps
)
scores
=
np
.
apply_along_axis
(
softmax
,
1
,
scores
)
scores
=
np
.
reshape
(
scores
,
(
N
,
M
,
C
))
scores
=
np
.
transpose
(
scores
,
(
0
,
2
,
1
))
...
...
@@ -318,11 +321,6 @@ class TestMulticlassNMSLoDInput(OpTest):
scores
=
np
.
random
.
random
((
M
,
C
)).
astype
(
'float32'
)
def
softmax
(
x
):
shiftx
=
x
-
np
.
max
(
x
).
clip
(
-
64.
)
exps
=
np
.
exp
(
shiftx
)
return
exps
/
np
.
sum
(
exps
)
scores
=
np
.
apply_along_axis
(
softmax
,
1
,
scores
)
boxes
=
np
.
random
.
random
((
M
,
C
,
BOX_SIZE
)).
astype
(
'float32'
)
...
...
@@ -382,11 +380,6 @@ class TestMulticlassNMS2Op(TestMulticlassNMSOp):
scores
=
np
.
random
.
random
((
N
*
M
,
C
)).
astype
(
'float32'
)
def
softmax
(
x
):
shiftx
=
x
-
np
.
max
(
x
).
clip
(
-
64.
)
exps
=
np
.
exp
(
shiftx
)
return
exps
/
np
.
sum
(
exps
)
scores
=
np
.
apply_along_axis
(
softmax
,
1
,
scores
)
scores
=
np
.
reshape
(
scores
,
(
N
,
M
,
C
))
scores
=
np
.
transpose
(
scores
,
(
0
,
2
,
1
))
...
...
@@ -447,11 +440,6 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
scores
=
np
.
random
.
random
((
M
,
C
)).
astype
(
'float32'
)
def
softmax
(
x
):
shiftx
=
x
-
np
.
max
(
x
).
clip
(
-
64.
)
exps
=
np
.
exp
(
shiftx
)
return
exps
/
np
.
sum
(
exps
)
scores
=
np
.
apply_along_axis
(
softmax
,
1
,
scores
)
boxes
=
np
.
random
.
random
((
M
,
C
,
BOX_SIZE
)).
astype
(
'float32'
)
...
...
python/paddle/fluid/tests/unittests/test_softmax_op.py
浏览文件 @
3423f0b6
...
...
@@ -24,7 +24,9 @@ from paddle.fluid import compiler, Program, program_guard
def
stable_softmax
(
x
):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx
=
x
-
np
.
max
(
x
).
clip
(
-
64.
)
# clip to shiftx, otherwise, when calc loss with
# log(exp(shiftx)), may get log(0)=INF
shiftx
=
(
x
-
np
.
max
(
x
)).
clip
(
-
64.
)
exps
=
np
.
exp
(
shiftx
)
return
exps
/
np
.
sum
(
exps
)
...
...
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
浏览文件 @
3423f0b6
...
...
@@ -58,7 +58,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
def
setUp
(
self
):
self
.
initParams
()
logits
=
np
.
random
.
uniform
(
0.1
,
1.0
,
self
.
shape
).
astype
(
self
.
dtype
)
logits
=
getattr
(
self
,
"logits"
,
np
.
random
.
uniform
(
0.1
,
1.0
,
self
.
shape
).
astype
(
self
.
dtype
))
softmax
=
np
.
apply_along_axis
(
stable_softmax
,
self
.
axis
,
logits
)
if
self
.
soft_label
:
...
...
@@ -119,7 +121,9 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
self
.
op_type
=
"softmax_with_cross_entropy"
# NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
logits
=
np
.
random
.
uniform
(
0.1
,
1.0
,
self
.
shape
).
astype
(
np
.
float32
)
logits
=
getattr
(
self
,
"logits"
,
np
.
random
.
uniform
(
0.1
,
1.0
,
self
.
shape
).
astype
(
np
.
float32
))
softmax
=
np
.
apply_along_axis
(
stable_softmax
,
self
.
axis
,
logits
)
axis_dim
=
self
.
shape
[
self
.
axis
]
...
...
@@ -408,5 +412,40 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
self
.
dtype
=
np
.
float64
class
TestSoftmaxWithCrossEntropyOpBoundary0
(
TestSoftmaxWithCrossEntropyOp
):
"""
Test stable softmax with cross entropy operator will not product INF
with small logits value.
"""
def
initParams
(
self
):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
False
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float32
self
.
logits
=
np
.
full
(
self
.
shape
,
-
500.0
).
astype
(
self
.
dtype
)
class
TestSoftmaxWithCrossEntropyOpBoundary1
(
TestSoftmaxWithCrossEntropyOp
):
"""
Test stable softmax with cross entropy operator will not product INF
with small logits value.
"""
def
initParams
(
self
):
self
.
op_type
=
"softmax_with_cross_entropy"
self
.
numeric_stable_mode
=
True
self
.
soft_label
=
False
self
.
shape
=
[
3
,
5
,
7
,
11
]
self
.
axis
=
-
1
self
.
ignore_index
=
-
1
self
.
dtype
=
np
.
float32
self
.
logits
=
np
.
full
(
self
.
shape
,
1000.0
).
astype
(
self
.
dtype
)
self
.
logits
[:,
:,
0
,
:]
=
-
1000.0
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录