Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
bcd40f21
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bcd40f21
编写于
5月 18, 2021
作者:
W
wuhuanzhou
提交者:
GitHub
5月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
relu supports bfloat16 data type (#32542)
上级
b5882c6e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
147 addition
and
10 deletion
+147
-10
paddle/fluid/operators/activation_op.cu
paddle/fluid/operators/activation_op.cu
+32
-1
paddle/fluid/operators/cast_op.cu
paddle/fluid/operators/cast_op.cu
+18
-0
python/paddle/fluid/tests/unittests/op_test.py
python/paddle/fluid/tests/unittests/op_test.py
+57
-3
python/paddle/fluid/tests/unittests/test_activation_op.py
python/paddle/fluid/tests/unittests/test_activation_op.py
+40
-6
未找到文件。
paddle/fluid/operators/activation_op.cu
浏览文件 @
bcd40f21
...
...
@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace
paddle
{
...
...
@@ -1437,9 +1438,9 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== relu register ============================ */
#ifdef PADDLE_WITH_HIP
REGISTER_ACTIVATION_CUDA_KERNEL
(
relu
,
Relu
,
CudaReluFunctor
,
CudaReluGradFunctor
);
REGISTER_OP_CUDA_KERNEL
(
relu_grad_grad
,
ops
::
ActivationDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
...
...
@@ -1448,6 +1449,36 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
ReluGradGradFunctor
<
double
>>
,
ops
::
ActivationDoubleGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ReluGradGradFunctor
<
plat
::
float16
>>
);
#else
REGISTER_OP_CUDA_KERNEL
(
relu
,
ops
::
ActivationCudaKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
CudaReluFunctor
<
float
>>
,
ops
::
ActivationCudaKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
CudaReluFunctor
<
double
>>
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaReluFunctor
<
plat
::
float16
>>
,
ops
::
ActivationCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaReluFunctor
<
plat
::
bfloat16
>>
);
REGISTER_OP_CUDA_KERNEL
(
relu_grad
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaReluGradFunctor
<
float
>>
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaReluGradFunctor
<
double
>>
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaReluGradFunctor
<
plat
::
float16
>>
,
ops
::
ActivationGradCudaKernel
<
plat
::
CUDADeviceContext
,
ops
::
CudaReluGradFunctor
<
plat
::
bfloat16
>>
);
REGISTER_OP_CUDA_KERNEL
(
relu_grad_grad
,
ops
::
ActivationDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
ReluGradGradFunctor
<
float
>>
,
ops
::
ActivationDoubleGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
ReluGradGradFunctor
<
double
>>
,
ops
::
ActivationDoubleGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ReluGradGradFunctor
<
plat
::
float16
>>
,
ops
::
ActivationDoubleGradKernel
<
plat
::
CUDADeviceContext
,
ops
::
ReluGradGradFunctor
<
plat
::
bfloat16
>>
);
#endif
/* ========================================================================== */
/* =========================== tanh register ============================ */
...
...
paddle/fluid/operators/cast_op.cu
浏览文件 @
bcd40f21
...
...
@@ -95,6 +95,7 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
namespace
ops
=
paddle
::
operators
;
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL
(
cast
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
...
...
@@ -108,3 +109,20 @@ REGISTER_OP_CUDA_KERNEL(
paddle
::
platform
::
complex64
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex128
>
);
#else
REGISTER_OP_CUDA_KERNEL
(
cast
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
bfloat16
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex64
>
,
ops
::
CastOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex128
>
);
#endif
python/paddle/fluid/tests/unittests/op_test.py
浏览文件 @
bcd40f21
...
...
@@ -132,6 +132,8 @@ def get_numeric_gradient(place,
tensor_to_check_dtype
=
np
.
float16
# set delta as np.float16, will automatic convert to float32, float64
delta
=
np
.
array
(
delta
).
astype
(
np
.
float16
)
elif
tensor_to_check_dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
tensor_to_check_dtype
=
np
.
float32
else
:
raise
ValueError
(
"Not supported data type "
+
str
(
tensor_to_check_dtype
))
...
...
@@ -140,9 +142,10 @@ def get_numeric_gradient(place,
sum
=
[]
op
.
run
(
scope
,
place
)
for
output_name
in
output_names
:
sum
.
append
(
np
.
array
(
scope
.
find_var
(
output_name
).
get_tensor
()).
astype
(
tensor_to_check_dtype
).
mean
())
output_numpy
=
np
.
array
(
scope
.
find_var
(
output_name
).
get_tensor
())
if
tensor_to_check
.
_dtype
()
==
core
.
VarDesc
.
VarType
.
BF16
:
output_numpy
=
convert_uint16_to_float
(
output_numpy
)
sum
.
append
(
output_numpy
.
astype
(
tensor_to_check_dtype
).
mean
())
return
tensor_to_check_dtype
(
np
.
array
(
sum
).
sum
()
/
len
(
output_names
))
gradient_flat
=
np
.
zeros
(
shape
=
(
tensor_size
,
),
dtype
=
tensor_to_check_dtype
)
...
...
@@ -152,6 +155,11 @@ def get_numeric_gradient(place,
numpy_tensor
=
np
.
array
(
tensor
).
astype
(
np
.
float16
)
numpy_tensor
=
numpy_tensor
.
flatten
()
return
numpy_tensor
[
i
]
elif
tensor_to_check
.
_dtype
()
==
core
.
VarDesc
.
VarType
.
BF16
:
numpy_tensor
=
np
.
array
(
tensor
).
astype
(
np
.
uint16
)
numpy_tensor
=
numpy_tensor
.
flatten
()
return
struct
.
unpack
(
'<f'
,
struct
.
pack
(
'<I'
,
numpy_tensor
[
i
]
<<
16
))[
0
]
elif
tensor_to_check_dtype
==
np
.
float32
:
return
tensor
.
_get_float_element
(
i
)
elif
tensor_to_check_dtype
==
np
.
float64
:
...
...
@@ -168,6 +176,13 @@ def get_numeric_gradient(place,
numpy_tensor
[
i
]
=
e
numpy_tensor
=
numpy_tensor
.
reshape
(
shape
)
tensor
.
set
(
numpy_tensor
,
place
)
elif
tensor_to_check
.
_dtype
()
==
core
.
VarDesc
.
VarType
.
BF16
:
numpy_tensor
=
np
.
array
(
tensor
).
astype
(
np
.
uint16
)
shape
=
numpy_tensor
.
shape
numpy_tensor
=
numpy_tensor
.
flatten
()
numpy_tensor
[
i
]
=
np
.
uint16
(
copy_bits_from_float_to_uint16
(
e
))
numpy_tensor
=
numpy_tensor
.
reshape
(
shape
)
tensor
.
set
(
numpy_tensor
,
place
)
elif
tensor_to_check_dtype
==
np
.
float32
:
tensor
.
_set_float_element
(
i
,
e
)
elif
tensor_to_check_dtype
==
np
.
float64
:
...
...
@@ -1353,6 +1368,8 @@ class OpTest(unittest.TestCase):
abs_a
[
abs_a
<
1e-10
]
=
1e-3
abs_a
[
np
.
logical_and
(
abs_a
>
1e-10
,
abs_a
<=
1e-8
)]
*=
1e4
abs_a
[
np
.
logical_and
(
abs_a
>
1e-8
,
abs_a
<=
1e-6
)]
*=
1e2
elif
self
.
is_bfloat16_op
():
abs_a
[
abs_a
<
1e-2
]
=
1
else
:
abs_a
[
abs_a
<
1e-3
]
=
1
...
...
@@ -1500,6 +1517,13 @@ class OpTest(unittest.TestCase):
dygraph_grad
=
self
.
_get_dygraph_grad
(
inputs_to_check
,
place
,
output_names
,
user_defined_grad_outputs
,
no_grad_set
)
fp32_grads
=
[]
for
grad
in
dygraph_grad
:
if
grad
.
dtype
==
np
.
uint16
:
grad
=
convert_uint16_to_float
(
grad
)
max_relative_error
=
0.03
fp32_grads
.
append
(
grad
)
dygraph_grad
=
fp32_grads
self
.
_assert_is_close
(
numeric_grads
,
dygraph_grad
,
inputs_to_check
,
max_relative_error
,
"Gradient Check On %s"
%
str
(
place
))
...
...
@@ -1544,6 +1568,21 @@ class OpTest(unittest.TestCase):
outputs
=
outputs
,
attrs
=
attrs_outputs
if
hasattr
(
self
,
"attrs"
)
else
None
)
if
self
.
dtype
==
np
.
uint16
:
cast_inputs
=
self
.
_find_var_in_dygraph
(
outputs
,
output_names
[
0
])
cast_outputs
=
block
.
create_var
(
dtype
=
"float32"
,
shape
=
cast_inputs
[
0
].
shape
)
cast_op
=
block
.
append_op
(
inputs
=
{
"X"
:
cast_inputs
},
outputs
=
{
"Out"
:
cast_outputs
},
type
=
"cast"
,
attrs
=
{
"in_dtype"
:
core
.
VarDesc
.
VarType
.
BF16
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
})
outputs
=
{
output_names
[
0
]:
cast_outputs
}
outputs_valid
=
{}
for
output_name
in
output_names
:
outputs_valid
[
output_name
]
=
self
.
_find_var_in_dygraph
(
...
...
@@ -1659,6 +1698,21 @@ class OpTest(unittest.TestCase):
feed_dict
=
self
.
feed_var
(
inputs
,
place
)
if
user_defined_grad_outputs
is
None
:
if
self
.
dtype
==
np
.
uint16
:
cast_inputs
=
list
(
map
(
block
.
var
,
output_names
))
cast_outputs
=
block
.
create_var
(
dtype
=
"float32"
,
shape
=
cast_inputs
[
0
].
shape
)
cast_op
=
block
.
append_op
(
inputs
=
{
"X"
:
cast_inputs
},
outputs
=
{
"Out"
:
cast_outputs
},
type
=
"cast"
,
attrs
=
{
"in_dtype"
:
core
.
VarDesc
.
VarType
.
BF16
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
})
cast_op
.
desc
.
infer_var_type
(
block
.
desc
)
cast_op
.
desc
.
infer_shape
(
block
.
desc
)
output_names
=
[
cast_outputs
.
name
]
loss
=
append_loss_ops
(
block
,
output_names
)
param_grad_list
=
append_backward
(
loss
=
loss
,
...
...
python/paddle/fluid/tests/unittests/test_activation_op.py
浏览文件 @
bcd40f21
...
...
@@ -18,7 +18,7 @@ import unittest
import
numpy
as
np
from
scipy.special
import
expit
,
erf
from
op_test
import
OpTest
from
op_test
import
OpTest
,
convert_float_to_uint16
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
...
...
@@ -1103,12 +1103,19 @@ class TestRelu(TestActivation):
self
.
init_dtype
()
np
.
random
.
seed
(
1024
)
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
# The same reason with TestAbs
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.02
out
=
np
.
maximum
(
x
,
0
)
if
self
.
dtype
==
np
.
uint16
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
np
.
float32
)
# The same reason with TestAbs
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.02
out
=
convert_float_to_uint16
(
np
.
maximum
(
x
,
0
))
self
.
inputs
=
{
'X'
:
convert_float_to_uint16
(
x
)}
else
:
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
# The same reason with TestAbs
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.02
out
=
np
.
maximum
(
x
,
0
)
self
.
inputs
=
{
'X'
:
x
}
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_grad
(
self
):
...
...
@@ -2739,5 +2746,32 @@ create_test_act_fp16_class(TestHardSigmoid)
create_test_act_fp16_class
(
TestSwish
,
grad_atol
=
0.85
)
create_test_act_fp16_class
(
TestHardSwish
)
def
create_test_act_bf16_class
(
parent
,
atol
=
1e-2
,
grad_check
=
True
,
grad_atol
=
0.80
):
@
unittest
.
skipIf
(
not
paddle
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestActBF16
(
parent
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
uint16
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
self
.
check_output_with_place
(
place
,
atol
=
atol
)
def
test_check_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
,
max_relative_error
=
grad_atol
)
cls_name
=
"{0}_{1}"
.
format
(
parent
.
__name__
,
"bf16"
)
TestActBF16
.
__name__
=
cls_name
globals
()[
cls_name
]
=
TestActBF16
create_test_act_bf16_class
(
TestRelu
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录