Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
289677e2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
289677e2
编写于
3月 22, 2023
作者:
B
Bo Zhang
提交者:
GitHub
3月 22, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【AMP OP&Test】unit test for test_logit_op (#51051)
* test_logit_op * add cudaKernel to replace eigen impl * bf16 unit test CI
上级
de2166c0
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
138 addition
and
21 deletion
+138
-21
paddle/phi/kernels/funcs/activation_functor.h
paddle/phi/kernels/funcs/activation_functor.h
+52
-0
paddle/phi/kernels/gpu/activation_grad_kernel.cu
paddle/phi/kernels/gpu/activation_grad_kernel.cu
+4
-9
paddle/phi/kernels/gpu/activation_kernel.cu
paddle/phi/kernels/gpu/activation_kernel.cu
+3
-8
python/paddle/fluid/tests/unittests/test_logit_op.py
python/paddle/fluid/tests/unittests/test_logit_op.py
+78
-4
python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
...luid/tests/unittests/white_list/op_accuracy_white_list.py
+1
-0
未找到文件。
paddle/phi/kernels/funcs/activation_functor.h
浏览文件 @
289677e2
...
@@ -2472,6 +2472,58 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
...
@@ -2472,6 +2472,58 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
};
};
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template
<
typename
T
>
struct
CudaLogitFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
MT
zero
=
static_cast
<
MT
>
(
0.0
f
);
MT
one
=
static_cast
<
MT
>
(
1.0
f
);
float
eps
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"eps"
,
&
eps
}};
}
// logit(x) = ln(x/(1-x))
__device__
__forceinline__
T
operator
()(
const
T
arg_x
)
const
{
MT
x
=
static_cast
<
MT
>
(
arg_x
);
MT
y
=
min
(
x
,
(
one
-
static_cast
<
MT
>
(
eps
)));
y
=
max
(
y
,
static_cast
<
MT
>
(
eps
));
if
(
!
eps
)
{
y
=
x
<
zero
||
x
>
one
?
static_cast
<
T
>
(
NAN
)
:
log
(
y
/
(
one
-
y
));
}
else
{
y
=
log
(
y
/
(
one
-
y
));
}
return
static_cast
<
T
>
(
y
);
}
};
template
<
typename
T
>
struct
CudaLogitGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
float
eps
;
MT
zero
=
static_cast
<
MT
>
(
0.0
f
);
MT
one
=
static_cast
<
MT
>
(
1.0
f
);
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"eps"
,
&
eps
}};
}
// logit(x)' = 1/(x*(1-x))
__device__
__forceinline__
T
operator
()(
const
T
dout
,
const
T
arg_x
)
const
{
MT
x
=
static_cast
<
MT
>
(
arg_x
);
MT
dx
=
(
x
<
static_cast
<
MT
>
(
eps
)
||
x
>
one
-
static_cast
<
MT
>
(
eps
))
?
zero
:
(
static_cast
<
MT
>
(
dout
)
/
(
x
*
(
one
-
x
)));
return
static_cast
<
T
>
(
dx
);
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
ActBwdOpFwdDeps
::
kDepOut
;
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
CudaReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
struct
CudaReluFunctor
:
public
BaseActivationFunctor
<
T
>
{
T
zero
=
static_cast
<
T
>
(
0.0
f
);
T
zero
=
static_cast
<
T
>
(
0.0
f
);
...
...
paddle/phi/kernels/gpu/activation_grad_kernel.cu
浏览文件 @
289677e2
...
@@ -228,6 +228,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu,
...
@@ -228,6 +228,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT
(
Relu6
,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT
(
Relu6
,
CudaRelu6GradFunctor
,
CudaRelu6GradFunctor
,
threshold
);
threshold
);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT
(
LogitCUDA
,
CudaLogitGradFunctor
,
eps
);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
HardTanh
,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
HardTanh
,
CudaHardTanhGradFunctor
,
CudaHardTanhGradFunctor
,
...
@@ -382,6 +385,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
...
@@ -382,6 +385,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
silu_grad
,
SiluGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
silu_grad
,
SiluGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
elu_grad
,
EluGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
elu_grad
,
EluGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
elu_double_grad
,
EluDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
elu_double_grad
,
EluDoubleGradKernel
)
PD_REGISTER_ACTIVATION_GRAD_KERNEL
(
logit_grad
,
LogitCUDAGradKernel
)
PD_REGISTER_KERNEL
(
expm1_grad
,
PD_REGISTER_KERNEL
(
expm1_grad
,
GPU
,
GPU
,
...
@@ -392,15 +396,6 @@ PD_REGISTER_KERNEL(expm1_grad,
...
@@ -392,15 +396,6 @@ PD_REGISTER_KERNEL(expm1_grad,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
logit_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
LogitGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
square_grad
,
PD_REGISTER_KERNEL
(
square_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
...
...
paddle/phi/kernels/gpu/activation_kernel.cu
浏览文件 @
289677e2
...
@@ -109,6 +109,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
...
@@ -109,6 +109,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
DEFINE_GPU_ACTIVATION_KERNEL
(
Ceil
,
CudaCeilFunctor
)
DEFINE_GPU_ACTIVATION_KERNEL
(
Ceil
,
CudaCeilFunctor
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
CudaLeakyReluFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
CudaLeakyReluFunctor
,
alpha
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LogitCUDA
,
CudaLogitFunctor
,
eps
)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
CudaThresholdedReluFunctor
,
CudaThresholdedReluFunctor
,
threshold
)
threshold
)
...
@@ -225,14 +226,6 @@ PD_REGISTER_KERNEL(expm1,
...
@@ -225,14 +226,6 @@ PD_REGISTER_KERNEL(expm1,
double
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
logit
,
GPU
,
ALL_LAYOUT
,
phi
::
LogitKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
PD_REGISTER_KERNEL
(
square
,
PD_REGISTER_KERNEL
(
square
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
...
@@ -263,6 +256,8 @@ PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
...
@@ -263,6 +256,8 @@ PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
celu
,
CeluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
celu
,
CeluKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
logit
,
LogitCUDAKernel
)
PD_REGISTER_KERNEL
(
pow
,
PD_REGISTER_KERNEL
(
pow
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
...
...
python/paddle/fluid/tests/unittests/test_logit_op.py
浏览文件 @
289677e2
...
@@ -18,6 +18,8 @@ import numpy as np
...
@@ -18,6 +18,8 @@ import numpy as np
from
op_test
import
OpTest
from
op_test
import
OpTest
import
paddle
import
paddle
from
paddle.fluid
import
core
from
paddle.fluid.tests.unittests.op_test
import
convert_float_to_uint16
np
.
random
.
seed
(
10
)
np
.
random
.
seed
(
10
)
...
@@ -43,9 +45,6 @@ class TestLogitOp(OpTest):
...
@@ -43,9 +45,6 @@ class TestLogitOp(OpTest):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
'logit'
self
.
op_type
=
'logit'
self
.
python_api
=
paddle
.
logit
self
.
python_api
=
paddle
.
logit
self
.
dtype
=
np
.
float64
self
.
shape
=
[
120
]
self
.
eps
=
1e-8
self
.
set_attrs
()
self
.
set_attrs
()
x
=
np
.
random
.
uniform
(
-
1.0
,
1.0
,
self
.
shape
).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
-
1.0
,
1.0
,
self
.
shape
).
astype
(
self
.
dtype
)
out
=
logit
(
x
,
self
.
eps
)
out
=
logit
(
x
,
self
.
eps
)
...
@@ -55,7 +54,9 @@ class TestLogitOp(OpTest):
...
@@ -55,7 +54,9 @@ class TestLogitOp(OpTest):
self
.
attrs
=
{
'eps'
:
self
.
eps
}
self
.
attrs
=
{
'eps'
:
self
.
eps
}
def
set_attrs
(
self
):
def
set_attrs
(
self
):
pass
self
.
dtype
=
np
.
float64
self
.
shape
=
[
120
]
self
.
eps
=
1e-8
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
(
check_eager
=
True
)
self
.
check_output
(
check_eager
=
True
)
...
@@ -66,13 +67,86 @@ class TestLogitOp(OpTest):
...
@@ -66,13 +67,86 @@ class TestLogitOp(OpTest):
)
)
class
TestLogitOpFp32
(
TestLogitOp
):
def
set_attrs
(
self
):
self
.
dtype
=
np
.
float32
self
.
shape
=
[
120
]
self
.
eps
=
1e-8
def
test_check_output
(
self
):
self
.
check_output
(
check_eager
=
True
)
def
test_check_grad
(
self
):
self
.
check_grad
(
[
'X'
],
[
'Out'
],
user_defined_grads
=
[
self
.
x_grad
],
check_eager
=
True
)
class
TestLogitOpFp16
(
TestLogitOp
):
def
set_attrs
(
self
):
self
.
dtype
=
np
.
float16
self
.
shape
=
[
120
]
self
.
eps
=
1e-8
def
test_check_output
(
self
):
self
.
check_output
(
check_eager
=
True
)
def
test_check_grad
(
self
):
self
.
check_grad
(
[
'X'
],
[
'Out'
],
user_defined_grads
=
[
self
.
x_grad
],
check_eager
=
True
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
not
core
.
is_bfloat16_supported
(
core
.
CUDAPlace
(
0
)),
"core is not compiled with CUDA and not support the bfloat16"
,
)
class
TestLogitOpBf16
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
'logit'
self
.
python_api
=
paddle
.
logit
self
.
set_attrs
()
x
=
np
.
random
.
uniform
(
-
0.5
,
0.5
,
self
.
shape
).
astype
(
np
.
float32
)
out
=
logit
(
x
,
self
.
eps
)
self
.
x_grad
=
logit_grad
(
x
,
self
.
eps
)
self
.
inputs
=
{
'X'
:
convert_float_to_uint16
(
x
)}
self
.
outputs
=
{
'Out'
:
convert_float_to_uint16
(
out
)}
self
.
attrs
=
{
'eps'
:
self
.
eps
}
def
set_attrs
(
self
):
self
.
dtype
=
np
.
uint16
self
.
shape
=
[
120
]
self
.
eps
=
1e-8
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_output_with_place
(
place
,
check_eager
=
True
)
def
test_check_grad
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'X'
],
[
'Out'
],
user_defined_grads
=
[
self
.
x_grad
],
check_eager
=
True
,
)
class
TestLogitShape
(
TestLogitOp
):
class
TestLogitShape
(
TestLogitOp
):
def
set_attrs
(
self
):
def
set_attrs
(
self
):
self
.
dtype
=
np
.
float64
self
.
shape
=
[
2
,
60
]
self
.
shape
=
[
2
,
60
]
self
.
eps
=
1e-8
class
TestLogitEps
(
TestLogitOp
):
class
TestLogitEps
(
TestLogitOp
):
def
set_attrs
(
self
):
def
set_attrs
(
self
):
self
.
dtype
=
np
.
float32
self
.
shape
=
[
120
]
self
.
eps
=
1e-8
self
.
eps
=
1e-8
...
...
python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
浏览文件 @
289677e2
...
@@ -38,6 +38,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
...
@@ -38,6 +38,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'increment'
,
'increment'
,
'l1_norm'
,
'l1_norm'
,
'log_loss'
,
'log_loss'
,
'logit'
,
'lrn'
,
'lrn'
,
'margin_rank_loss'
,
'margin_rank_loss'
,
'match_matrix_tensor'
,
'match_matrix_tensor'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录