Commit 7ccf6b60 in Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit, authored Mar 22, 2021 by arlesniak; committed via GitHub on Mar 22, 2021.
[oneDNN] Initial bf16 amp integration (#31093)
Parent: a501a7b0

Showing 18 changed files with 777 additions and 38 deletions (+777 -38)
paddle/fluid/operators/cast_op.cc                              +1   -0
paddle/fluid/operators/scale_op.cc                             +2   -0
python/paddle/fluid/contrib/mixed_precision/__init__.py        +3   -0
python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py   +24  -0
python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py  +97  -0
python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py  +296 -0
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py      +1   -1
python/paddle/fluid/contrib/tests/test_bf16_utils.py           +144 -0
python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py   +138 -0
python/paddle/fluid/data_feeder.py                             +16  -7
python/paddle/fluid/layers/nn.py                               +9   -7
python/paddle/fluid/tests/book/test_fit_a_line.py              +14  -3
python/paddle/fluid/tests/book/test_word2vec_book.py           +20  -9
python/paddle/fluid/tests/unittests/op_test.py                 +6   -11
python/paddle/static/amp/__init__.py                           +3   -0
python/setup.py.in                                             +1   -0
tools/parallel_UT_rule.py                                      +1   -0
tools/static_mode_white_list.py                                +1   -0
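Taken together, the change set adds a bfloat16 (bf16) flavour of the static-graph AMP tooling and exposes it through paddle.static.amp. Before the per-file diffs, here is a minimal usage sketch in the spirit of the updated test_fit_a_line.py below; the toy fp32 network itself is illustrative and not part of this commit, and it assumes a CPU build with oneDNN bfloat16 support:

import paddle
import paddle.fluid as fluid

paddle.enable_static()

# Illustrative fp32 network (not taken from this commit).
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
avg_cost = fluid.layers.mean(
    fluid.layers.square_error_cost(input=y_predict, label=y))

sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
# Rewrite the default main program in place before minimize(), as the updated
# book tests do: ops on the bf16 list get use_mkldnn=True and
# mkldnn_data_type='bfloat16', everything else stays fp32, and cast ops are
# inserted at the fp32/bf16 boundaries.
paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
sgd_optimizer.minimize(avg_cost)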
paddle/fluid/operators/cast_op.cc
@@ -97,5 +97,6 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
                        ops::CastOpKernel<CPU, bool>,
                        ops::CastOpKernel<CPU, uint8_t>,
                        ops::CastOpKernel<CPU, paddle::platform::float16>,
+                       ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
                        ops::CastOpKernel<CPU, paddle::platform::complex64>,
                        ops::CastOpKernel<CPU, paddle::platform::complex128>);
paddle/fluid/operators/scale_op.cc
@@ -128,6 +128,8 @@ REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
 REGISTER_OP_CPU_KERNEL(
     scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ScaleKernel<paddle::platform::CPUDeviceContext,
+                     paddle::platform::bfloat16>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, int16_t>,
python/paddle/fluid/contrib/mixed_precision/__init__.py
@@ -20,7 +20,10 @@ from . import fp16_lists
 from .fp16_lists import *
 from . import fp16_utils
 from .fp16_utils import *
+from . import bf16
+from .bf16 import *
 
 __all__ = decorator.__all__
 __all__ += fp16_lists.__all__
 __all__ += fp16_utils.__all__
+__all__ += bf16.__all__
python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

from . import amp_lists
from .amp_lists import *
from . import amp_utils
from .amp_utils import *

__all__ = []
__all__ += amp_lists.__all__
__all__ += amp_utils.__all__
python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ..fp16_lists import white_list as white_list_fp16, black_list as black_list_fp16, \
    gray_list as gray_list_fp16, unsupported_fp16_list

__all__ = ["AutoMixedPrecisionListsBF16"]


class AutoMixedPrecisionListsBF16(object):
    """
    AutoMixedPrecisionListsBF16 is a class for fp32/bf16 op types list. The lists are used for an
    algorithm which determines op's execution mode (fp32 or bf16). It can update pre-defined
    fp32 list and bf16 list according to users' custom fp32 bf16 lists.

    Args:
        custom_bf16_list (set): Users' custom bf16 list.
        custom_fp32_list (set): Users' custom fp32 list.
        custom_fp32_varnames (set): Users' custom fp32 variables' names.

    Examples:
        .. code-block:: python

            import paddle
            paddle.enable_static()
            with paddle.static.amp.bf16_guard():
                paddle.static.amp.AutoMixedPrecisionListsBF16(custom_fp32_list={'lstm'})
    """

    def __init__(self,
                 custom_bf16_list=None,
                 custom_fp32_list=None,
                 custom_fp32_varnames=None):
        self._custom_bf16_list = custom_bf16_list
        self._custom_fp32_list = custom_fp32_list
        self.bf16_list = copy.copy(bf16_list)
        self.fp32_list = copy.copy(fp32_list)
        self.gray_list = copy.copy(gray_list)
        self.unsupported_list = copy.copy(unsupported_list)
        self.fp32_varnames = copy.copy(custom_fp32_varnames)
        self._update_list()

    def _update_list(self):
        """
        Update fp32 and bf16 list according to users' custom list.
        """
        if self._custom_bf16_list and self._custom_fp32_list:
            for op_name in self._custom_bf16_list:
                if op_name in self._custom_fp32_list:
                    raise ValueError("Custom bf16 list overlap "
                                     "custom fp32 list")
        if self._custom_bf16_list:
            for op_name in self._custom_bf16_list:
                if op_name in self.fp32_list:
                    self.fp32_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.bf16_list.add(op_name)
        if self._custom_fp32_list:
            for op_name in self._custom_fp32_list:
                if op_name in self.bf16_list:
                    self.bf16_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.fp32_list.add(op_name)
                self.unsupported_list.add(op_name)


# always bf16
bf16_list = {'elementwise_add', }

# depends on the prev_op type
gray_list = {
    'reshape2',
    'lookup_table',
}

unsupported_list = unsupported_fp16_list.copy().copy()
fp32_list = black_list_fp16.copy().copy()
fp32_list |= white_list_fp16
fp32_list |= gray_list_fp16

fp32_list -= bf16_list
fp32_list -= gray_list
unsupported_list -= bf16_list
unsupported_list -= gray_list
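A small sketch of how the custom lists interact with the pre-defined sets above; the assertions restate what _update_list() does and mirror the cases exercised in test_bf16_utils.py later in this diff:

import paddle
import paddle.fluid.contrib.mixed_precision as amp

paddle.enable_static()

# 'reshape2' moves from the gray list to the bf16 list; 'lstm' is forced to
# fp32 and additionally marked unsupported for bf16.
lists = amp.AutoMixedPrecisionListsBF16(
    custom_bf16_list={'reshape2'}, custom_fp32_list={'lstm'})

assert 'reshape2' in lists.bf16_list and 'reshape2' not in lists.gray_list
assert 'lstm' in lists.fp32_list and 'lstm' in lists.unsupported_list

# Overlapping custom lists are rejected:
# amp.AutoMixedPrecisionListsBF16({'lstm'}, {'lstm'})  # raises ValueError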
python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import struct

from .... import core
from .... import framework
from ....log_helper import get_logger
from ....wrapped_decorator import signature_safe_contextmanager
from .amp_lists import AutoMixedPrecisionListsBF16
from ..fp16_utils import find_true_prev_op, find_true_post_op, _rename_arg, find_op_index
import logging
import numpy as np

__all__ = ["bf16_guard", "rewrite_program_bf16", "convert_float_to_uint16"]

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')

_valid_types = [
    core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
    core.VarDesc.VarType.LOD_TENSOR_ARRAY
]

_bf16_guard_pattern = "__use_bf16__"


def convert_float_to_uint16(in_list):
    in_list = np.asarray(in_list)
    out = np.vectorize(
        lambda x: struct.unpack('<I', struct.pack('<f', x))[0] >> 16,
        otypes=[np.uint16])(in_list.flat)
    return np.reshape(out, in_list.shape)


def _dtype_to_str(dtype):
    """
    Convert specific variable type to its corresponding string.

    Args:
        dtype (VarType): Variable type.
    """
    if dtype == core.VarDesc.VarType.BF16:
        return 'bf16'
    else:
        return 'fp32'


def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    """
    Insert cast op and rename args of input and output.

    Args:
        block (Program): The block in which the operator is.
        op (Operator): The operator to insert cast op.
        idx (int): The index of current operator.
        src_dtype (VarType): The input variable dtype of cast op.
        dest_dtype (VarType): The output variable dtype of cast op.

    Returns:
        num_cast_op (int): The number of cast ops that have been inserted.
    """
    num_cast_ops = 0

    for in_name in op.input_names:
        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
                'batch_norm', 'fused_bn_add_activation', 'layer_norm'
        ]:
            if in_name not in {'X', 'Z'}:
                continue
        for in_var_name in op.input(in_name):
            in_var = block.var(in_var_name)
            if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
                continue
            if in_var.dtype == src_dtype:
                cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
                out_var = block.vars.get(cast_name)
                if out_var is None or out_var.dtype != dest_dtype:
                    out_var = block.create_var(
                        name=cast_name,
                        dtype=dest_dtype,
                        persistable=False,
                        stop_gradient=in_var.stop_gradient)

                    block._insert_op(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype
                        })
                    num_cast_ops += 1
                _rename_arg(op, in_var.name, out_var.name)
            else:
                if op.has_attr('in_dtype'):
                    op._set_attr('in_dtype', dest_dtype)
    if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.BF16:
        for out_name in op.output_names:
            if op.type in [
                    'batch_norm', 'fused_bn_add_activation', 'layer_norm'
            ] and out_name != 'Y':
                continue
            for out_var_name in op.output(out_name):
                out_var = block.var(out_var_name)
                if out_var.type not in _valid_types:
                    continue
                if out_var.dtype == core.VarDesc.VarType.FP32:
                    out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                    if op.has_attr('out_dtype'):
                        op._set_attr('out_dtype', core.VarDesc.VarType.BF16)
    return num_cast_ops


def _is_in_fp32_varnames(op, amp_lists):
    for in_name in op.input_arg_names:
        if in_name in amp_lists.fp32_varnames:
            return True

    for out_name in op.output_arg_names:
        if out_name in amp_lists.fp32_varnames:
            return True

    return False


def _need_keep_fp32(op, unsupported_op_list, use_bf16_guard):
    if op.type in unsupported_op_list:
        # the highest priority condition: If ops don't have bf16 computing kernels,
        # they must be executed in fp32 calculation pattern.
        return True

    # process ops about learning rate
    in_out_arg_names = []
    in_out_arg_names.extend(list(op.input_arg_names))
    in_out_arg_names.extend(list(op.output_arg_names))
    for name in in_out_arg_names:
        if "learning_rate" in name:
            return True

    if use_bf16_guard:
        if op.has_attr("op_namescope") and \
                (_bf16_guard_pattern in op.attr("op_namescope")):
            # op in bf16 guard
            return False
        else:
            # op not in bf16 guard
            return True
    else:
        return False


@signature_safe_contextmanager
def bf16_guard():
    """
    As for the pure bf16 training, if users set `use_bf16_guard` to True,
    only those ops created in the context manager `bf16_guard` will be
    transformed as float16 type.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn.functional as F
            paddle.enable_static()
            data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
            conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)

            with paddle.static.amp.bf16_guard():
                bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                hidden = paddle.static.nn.fc(pool, size=10)
                loss = paddle.mean(hidden)
    """
    with framework.name_scope(prefix=_bf16_guard_pattern):
        yield


def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False):
    """
    Traverse all ops in current block and insert cast op according to
    which set current op belongs to.

    1. When an op belongs to the fp32 list, add it to fp32 set
    2. When an op belongs to the bf16 list, add it to bf16 set
    3. When an op belongs to the gray list. If one
       of its inputs is the output of fp32 set op or fp32 list op,
       add it to fp32 set. If all of its previous ops are not fp32
       op and one of its inputs is the output of bf16 set op or
       bf16 list op, add it to bf16 set.
    4. When an op isn't in the lists, add it to fp32 op set.
    5. Add necessary cast ops to make sure that fp32 set op will be
       computed in fp32 mode, while bf16 set op will be computed in
       bf16 mode.

    Args:
        main_prog (Program): The main program for training.
    """
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionListsBF16()
    block = main_prog.global_block()
    ops = block.ops
    bf16_op_set = set()
    fp32_op_set = set()
    for op in ops:

        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
        # we don't need to handle reader op and the input of 'create_py_reader' is not
        # in block, which may result in errors.
        # See GeneratorLoader._init_non_iterable() for details.
        if op.type == 'create_py_reader' or op.type == 'read':
            continue

        if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
                op, amp_lists):
            fp32_op_set.add(op)
            continue

        if op.type in amp_lists.fp32_list or _need_keep_fp32(
                op, amp_lists.unsupported_list, use_bf16_guard):
            fp32_op_set.add(op)
        elif op.type in amp_lists.bf16_list:
            bf16_op_set.add(op)
        elif op.type in amp_lists.gray_list:
            is_fp32_op = False
            is_bf16_op = False
            for in_name in op.input_names:
                # if this op has inputs
                if in_name:
                    for in_var_name in op.input(in_name):
                        in_var = block.var(in_var_name)
                        # this in_var isn't the output of other op
                        if in_var.op is None:
                            continue
                        elif in_var.op is op:
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        # if it's one of inputs
                        if prev_op in fp32_op_set or \
                                prev_op.type in amp_lists.fp32_list:
                            is_fp32_op = True
                        elif prev_op in bf16_op_set or \
                                prev_op.type in amp_lists.bf16_list:
                            is_bf16_op = True
            if is_fp32_op:
                fp32_op_set.add(op)
            elif is_bf16_op:
                bf16_op_set.add(op)
            else:
                pass
        else:
            # For numerical safe, we apply fp32 computation on ops that
            # are not determined which list they should stay.
            fp32_op_set.add(op)

    idx = 0
    while idx < len(ops):
        op = ops[idx]
        num_cast_ops = 0
        if op in fp32_op_set:
            num_cast_ops = _insert_cast_op(block, op, idx,
                                           core.VarDesc.VarType.BF16,
                                           core.VarDesc.VarType.FP32)
        elif op in bf16_op_set:
            if op.has_attr('use_mkldnn'):
                op._set_attr('use_mkldnn', True)
                op._set_attr('mkldnn_data_type', 'bfloat16')
            elif op.has_attr('dtype') and op.attr(
                    'dtype') == core.VarDesc.VarType.FP32:
                op._set_attr('dtype', core.VarDesc.VarType.BF16)

            num_cast_ops = _insert_cast_op(block, op, idx,
                                           core.VarDesc.VarType.FP32,
                                           core.VarDesc.VarType.BF16)
        else:
            pass

        idx += num_cast_ops + 1
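convert_float_to_uint16() above simply truncates each IEEE-754 float32 to its upper 16 bits, i.e. it stores a bfloat16 bit pattern in a uint16 array (the inverse helper, convert_uint16_to_float, is added to op_test.py further down). A standalone sketch of the same bit trick, with hypothetical helper names, which also shows why the tests later compare with a 1e-2 tolerance:

import struct
import numpy as np

def float_to_bf16_bits(x):
    # Keep only the upper 16 bits of the float32 pattern: sign, 8 exponent
    # bits and the top 7 mantissa bits -- the same shift used above.
    return np.uint16(
        struct.unpack('<I', struct.pack('<f', np.float32(x)))[0] >> 16)

def bf16_bits_to_float(u):
    # Shift the stored 16 bits back into the high half of a float32 pattern.
    return struct.unpack('<f', struct.pack('<I', np.uint32(u) << 16))[0]

val = 3.2
approx = bf16_bits_to_float(float_to_bf16_bits(val))
# bf16 keeps ~7 mantissa bits, so truncation loses at most about 2**-7
# relative error; hence np.allclose(..., 1e-2) in test_model_cast_to_bf16.py.
assert abs(approx - val) / val < 1e-2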
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
@@ -69,7 +69,7 @@ class AutoMixedPrecisionLists(object):
                     self.unsupported_list.add(op_name)
 
-# The three sets listed below are changed dynamiclly. They don't contain all
+# The three sets listed below are changed dynamiclly. They don't contain all
 # paddle ops currently.
 
 # The set of ops that support fp16 calculation and are considered numerically-
python/paddle/fluid/contrib/tests/test_bf16_utils.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision as amp
from paddle.fluid import core
import paddle

paddle.enable_static()


class AMPTest(unittest.TestCase):
    def setUp(self):
        self.bf16_list = copy.copy(amp.bf16.amp_lists.bf16_list)
        self.fp32_list = copy.copy(amp.bf16.amp_lists.fp32_list)
        self.gray_list = copy.copy(amp.bf16.amp_lists.gray_list)
        self.amp_lists_ = None

    def tearDown(self):
        self.assertEqual(self.amp_lists_.bf16_list, self.bf16_list)
        self.assertEqual(self.amp_lists_.fp32_list, self.fp32_list)
        self.assertEqual(self.amp_lists_.gray_list, self.gray_list)

    def test_amp_lists(self):
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16()

    def test_amp_lists_1(self):
        # 1. w={'exp'}, b=None
        self.bf16_list.add('exp')
        self.fp32_list.remove('exp')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'exp'})

    def test_amp_lists_2(self):
        # 2. w={'tanh'}, b=None
        self.fp32_list.remove('tanh')
        self.bf16_list.add('tanh')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'tanh'})

    def test_amp_lists_3(self):
        # 3. w={'lstm'}, b=None
        self.bf16_list.add('lstm')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16({'lstm'})

    def test_amp_lists_4(self):
        # 4. w=None, b={'elementwise_add'}
        self.bf16_list.remove('elementwise_add')
        self.fp32_list.add('elementwise_add')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'elementwise_add'})

    def test_amp_lists_5(self):
        # 5. w=None, b={'elementwise_add'}
        self.fp32_list.add('elementwise_add')
        self.bf16_list.remove('elementwise_add')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'elementwise_add'})

    def test_amp_lists_6(self):
        # 6. w=None, b={'lstm'}
        self.fp32_list.add('lstm')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'lstm'})

    def test_amp_lists_7(self):
        self.fp32_list.add('reshape2')
        self.gray_list.remove('reshape2')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_list={'reshape2'})

    def test_amp_list_8(self):
        self.bf16_list.add('reshape2')
        self.gray_list.remove('reshape2')
        self.amp_lists_ = amp.AutoMixedPrecisionListsBF16(
            custom_bf16_list={'reshape2'})


class AMPTest2(unittest.TestCase):
    def test_amp_lists_(self):
        # 7. w={'lstm'} b={'lstm'}
        # raise ValueError
        self.assertRaises(ValueError, amp.AutoMixedPrecisionListsBF16,
                          {'lstm'}, {'lstm'})

    def test_find_op_index(self):
        block = fluid.default_main_program().global_block()
        op_desc = core.OpDesc()
        idx = amp.bf16.amp_utils.find_op_index(block.desc, op_desc)
        assert (idx == -1)

    def test_is_in_fp32_varnames(self):
        block = fluid.default_main_program().global_block()
        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
        op1 = block.append_op(
            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
        op2 = block.append_op(
            type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
        amp_lists_1 = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'X'})
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_1)
        amp_lists_2 = amp.AutoMixedPrecisionListsBF16(
            custom_fp32_varnames={'Y'})
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op2, amp_lists_2)
        assert amp.bf16.amp_utils._is_in_fp32_varnames(op1, amp_lists_2)

    def test_find_true_post_op(self):
        block = fluid.default_main_program().global_block()
        var1 = block.create_var(name="X", shape=[3], dtype='float32')
        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
        var3 = block.create_var(name="Z", shape=[3], dtype='float32')
        op1 = block.append_op(
            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
        op2 = block.append_op(
            type="abs", inputs={"X": [var2]}, outputs={"Out": [var3]})
        res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
        assert (res == [op2])


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import paddle
import paddle.fluid as fluid
import contextlib
import unittest
import numpy as np
import paddle.fluid.layers as layers
import paddle.static.amp as amp
from paddle.fluid import core

paddle.enable_static()


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestModelCastBF16(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.seed = 111

    @classmethod
    def tearDownClass(cls):
        pass

    @contextlib.contextmanager
    def static_graph(self):
        with self.scope_prog_guard():
            paddle.seed(self.seed)
            paddle.framework.random._manual_program_seed(self.seed)
            yield

    @contextlib.contextmanager
    def scope_prog_guard(self):
        prog = fluid.Program()
        startup_prog = fluid.Program()
        scope = fluid.core.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(prog, startup_prog):
                yield

    def get_static_graph_result(self, feed, fetch_list, amp_fun,
                                with_lod=False):
        exe = fluid.Executor(core.CPUPlace())
        exe.run(fluid.default_startup_program())
        prog = fluid.default_main_program()
        if amp_fun is not None:
            amp_fun(prog)
        return exe.run(prog,
                       feed=feed,
                       fetch_list=fetch_list,
                       return_numpy=(not with_lod))

    def test_graph_rewrite(self):
        size = 3
        n = np.ones([size, size], dtype='float32') * 3.2
        nn = np.ones([size, size], dtype='float32') * -2.7

        n_bf16 = amp.convert_float_to_uint16(n)
        nn_bf16 = amp.convert_float_to_uint16(nn)

        with self.static_graph():
            t_bf16 = layers.data(
                name='t_bf16', shape=[size, size], dtype=np.uint16)
            tt_bf16 = layers.data(
                name='tt_bf16', shape=[size, size], dtype=np.uint16)
            t = layers.data(name='t', shape=[size, size], dtype='float32')
            tt = layers.data(name='tt', shape=[size, size], dtype='float32')

            ret = layers.elementwise_add(t, tt)
            ret = layers.elementwise_mul(ret, t)
            ret = layers.reshape(ret, [0, 0])

            with amp.bf16_guard():
                ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16)
                ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16)
                ret_bf16 = layers.reshape(ret_bf16, [0, 0])

            with amp.bf16_guard():
                ret_fp32bf16 = layers.elementwise_add(t, tt)
                ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t)
                ret_fp32bf16 = layers.reshape(ret_fp32bf16, [0, 0])

            static_ret_bf16, static_ret, ret_fp32bf16 = \
                self.get_static_graph_result(
                    feed={
                        't': n,
                        'tt': nn,
                        't_bf16': n_bf16,
                        'tt_bf16': nn_bf16,
                    },
                    fetch_list=[ret_bf16, ret, ret_fp32bf16],
                    amp_fun=lambda prog: amp.rewrite_program_bf16(
                        prog, use_bf16_guard=True))

        self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
        self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))

        with self.static_graph():
            t = layers.data(name='t', shape=[size, size], dtype='float32')
            tt = layers.data(name='tt', shape=[size, size], dtype='float32')

            with amp.bf16_guard():
                ret = layers.elementwise_add(t, tt)
                ret = layers.reshape(ret, [0, 0], act='elu')
                ret = layers.elementwise_mul(ret, t)
            ret = layers.elementwise_add(ret, tt)

            static_ret_bf16 = \
                self.get_static_graph_result(
                    feed={'t': n, 'tt': nn},
                    fetch_list=[ret],
                    amp_fun=lambda prog: amp.rewrite_program_bf16(
                        prog,
                        amp.AutoMixedPrecisionListsBF16(
                            custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
                        use_bf16_guard=True
                    )
                )
        self.assertTrue(
            static_ret_bf16, np.ones([size, size], dtype='float32') * -1.1)


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/data_feeder.py
@@ -29,6 +29,7 @@ __all__ = ['DataFeeder']
 _PADDLE_DTYPE_2_NUMPY_DTYPE = {
     core.VarDesc.VarType.BOOL: 'bool',
     core.VarDesc.VarType.FP16: 'float16',
+    core.VarDesc.VarType.BF16: 'uint16',
     core.VarDesc.VarType.FP32: 'float32',
     core.VarDesc.VarType.FP64: 'float64',
     core.VarDesc.VarType.INT8: 'int8',
@@ -47,16 +48,18 @@ def convert_dtype(dtype):
         return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
     elif isinstance(dtype, type):
         if dtype in [
-                np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
-                np.int32, np.int64, np.uint8, np.complex64, np.complex128
+                np.bool, np.float16, np.uint16, np.float32, np.float64,
+                np.int8, np.int16, np.int32, np.int64, np.uint8, np.complex64,
+                np.complex128
         ]:
             return dtype.__name__
     else:
         if dtype in [
-                'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
-                'int32', 'int64', 'uint8', 'complex64', 'complex128', u'bool',
-                u'float16', u'float32', u'float64', u'int8', u'int16', u'int32',
-                u'int64', u'uint8', u'complex64', u'complex128'
+                'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
+                'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
+                u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
+                u'int16', u'int32', u'int64', u'uint8', u'complex64',
+                u'complex128'
         ]:
             # this code is a little bit dangerous, since error could happen
             # when casting no-ascii code to str in python2.
@@ -66,7 +69,7 @@ def convert_dtype(dtype):
         return str(dtype)
 
     raise TypeError(
-        "dtype must be any of [bool, float16, float32, float64, int8, int16, "
+        "dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
         "int32, int64, uint8, complex64, complex128], but received %s" % dtype)
@@ -123,6 +126,12 @@ def check_dtype(input_dtype,
         warnings.warn(
             "The data type of '%s' in %s only support float16 in GPU now. %s" %
             (input_name, op_name, extra_message))
+    if convert_dtype(input_dtype) in ['uint16'] and op_name not in [
+            'reshape', 'lookup_table', 'scale'
+    ]:
+        warnings.warn(
+            "The data type of '%s' in %s only support bfloat16 in OneDNN now. %s"
+            % (input_name, op_name, extra_message))
     if convert_dtype(input_dtype) not in expected_dtype:
         raise TypeError(
             "The data type of '%s' in %s must be %s, but received %s. %s" %
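In short, the data_feeder.py hunks register 'uint16' as the numpy-facing spelling of Paddle's BF16 variable type and make check_dtype() warn when a uint16 input reaches an op without an oneDNN bf16 kernel. A sketch of the resulting convert_dtype() behaviour, assuming a build that includes this patch:

import numpy as np
from paddle.fluid.data_feeder import convert_dtype

# Both the numpy scalar type and its string spelling now map to 'uint16',
# the storage dtype Paddle uses for bfloat16 tensors.
assert convert_dtype(np.uint16) == 'uint16'
assert convert_dtype('uint16') == 'uint16'

# check_dtype() only stays silent for uint16 inputs of 'reshape',
# 'lookup_table' and 'scale'; any other op triggers the
# "only support bfloat16 in OneDNN" warning added above.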
python/paddle/fluid/layers/nn.py
@@ -6137,9 +6137,9 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
             return dygraph_utils._append_activation_in_dygraph(out, act)
 
-    check_variable_and_dtype(
-        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
-        'reshape')
+    check_variable_and_dtype(
+        x, 'x', [
+            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'
+        ], 'reshape')
     check_type(shape, 'shape', (list, tuple, Variable), 'reshape')
     check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape')
@@ -11354,9 +11354,11 @@ def _elementwise_op(helper):
     assert x is not None, 'x cannot be None in {}'.format(op_type)
     assert y is not None, 'y cannot be None in {}'.format(op_type)
     check_variable_and_dtype(
-        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type)
+        x, 'x', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
+        op_type)
     check_variable_and_dtype(
-        y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'], op_type)
+        y, 'y', ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
+        op_type)
 
     axis = helper.kwargs.get('axis', -1)
     use_mkldnn = helper.kwargs.get('use_mkldnn', False)
@@ -11428,8 +11430,8 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
             return dygraph_utils._append_activation_in_dygraph(out)
 
     check_variable_and_dtype(x, "x", [
-        'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
-        'uint8'
+        'float16', 'uint16', 'float32', 'float64', 'int8', 'int16', 'int32',
+        'int64', 'uint8'
     ], "scale")
     inputs = {'X': [x]}
     attrs = {
python/paddle/fluid/tests/book/test_fit_a_line.py
@@ -26,7 +26,7 @@ import os
 paddle.enable_static()
 
 
-def train(use_cuda, save_dirname, is_local):
+def train(use_cuda, save_dirname, is_local, use_bf16):
     x = fluid.layers.data(name='x', shape=[13], dtype='float32')
 
     y_predict = fluid.layers.fc(input=x, size=1, act=None)
@@ -37,6 +37,8 @@ def train(use_cuda, save_dirname, is_local):
     avg_cost = fluid.layers.mean(cost)
 
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+    if use_bf16:
+        paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
     sgd_optimizer.minimize(avg_cost)
 
     BATCH_SIZE = 20
@@ -133,14 +135,17 @@ def infer(use_cuda, save_dirname=None):
         print("ground truth: ", test_label)
 
 
-def main(use_cuda, is_local=True):
+def main(use_cuda, is_local=True, use_bf16=False):
     if use_cuda and not fluid.core.is_compiled_with_cuda():
         return
 
+    if use_bf16 and not fluid.core.is_compiled_with_mkldnn():
+        return
+
     # Directory for saving the trained model
     save_dirname = "fit_a_line.inference.model"
 
-    train(use_cuda, save_dirname, is_local)
+    train(use_cuda, save_dirname, is_local, use_bf16)
     infer(use_cuda, save_dirname)
@@ -153,6 +158,12 @@ class TestFitALine(unittest.TestCase):
         with self.program_scope_guard():
             main(use_cuda=True)
 
+    @unittest.skipIf(not fluid.core.supports_bfloat16(),
+                     "place does not support BF16 evaluation")
+    def test_bf16(self):
+        with self.program_scope_guard():
+            main(use_cuda=False, use_bf16=True)
+
     @contextlib.contextmanager
     def program_scope_guard(self):
         prog = fluid.Program()
python/paddle/fluid/tests/book/test_word2vec_book.py
@@ -39,7 +39,12 @@ def get_place(target):
         format(target))
 
 
-def train(target, is_sparse, is_parallel, save_dirname, is_local=True):
+def train(target,
+          is_sparse,
+          is_parallel,
+          save_dirname,
+          is_local=True,
+          use_bf16=False):
     PASS_NUM = 100
     EMBED_SIZE = 32
     HIDDEN_SIZE = 256
@@ -101,6 +106,8 @@ def train(target, is_sparse, is_parallel, save_dirname, is_local=True):
         raise NotImplementedError()
 
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+    if use_bf16:
+        paddle.static.amp.rewrite_program_bf16(fluid.default_main_program())
     sgd_optimizer.minimize(avg_cost)
 
     train_reader = paddle.batch(
@@ -239,12 +246,15 @@ def infer(target, save_dirname=None):
             assert np.isclose(a, b, rtol=5e-5), "a: {}, b: {}".format(a, b)
 
 
-def main(target, is_sparse, is_parallel):
+def main(target, is_sparse, is_parallel, use_bf16):
     if target == "cuda" and not fluid.core.is_compiled_with_cuda():
         return
     if target == "xpu" and not fluid.core.is_compiled_with_xpu():
         return
 
+    if use_bf16 and not fluid.core.is_compiled_with_mkldnn():
+        return
+
     if not is_parallel:
         save_dirname = "word2vec.inference.model"
     else:
@@ -255,7 +265,7 @@ def main(target, is_sparse, is_parallel):
         # so only inference is turned on.
         train("cpu", is_sparse, is_parallel, save_dirname)
     else:
-        train(target, is_sparse, is_parallel, save_dirname)
+        train(target, is_sparse, is_parallel, save_dirname, use_bf16=use_bf16)
     infer(target, save_dirname)
@@ -268,10 +278,11 @@ class W2VTest(unittest.TestCase):
     pass
 
 
-def inject_test_method(target, is_sparse, is_parallel):
-    fn_name = "test_{0}_{1}_{2}".format(target, "sparse"
-                                        if is_sparse else "dense", "parallel"
-                                        if is_parallel else "normal")
+def inject_test_method(target, is_sparse, is_parallel, use_bf16=False):
+    fn_name = "test_{0}_{1}_{2}{3}".format(target, "sparse"
+                                           if is_sparse else "dense", "parallel"
+                                           if is_parallel else "normal", "_bf16"
+                                           if use_bf16 else "")
 
     def __impl__(*args, **kwargs):
         prog = fluid.Program()
@@ -279,8 +290,7 @@ def inject_test_method(target, is_sparse, is_parallel):
         scope = fluid.core.Scope()
         with fluid.scope_guard(scope):
             with fluid.program_guard(prog, startup_prog):
-                main(
-                    target=target, is_sparse=is_sparse, is_parallel=is_parallel)
+                main(target, is_sparse, is_parallel, use_bf16)
 
     if (not fluid.core.is_compiled_with_cuda() or
             target == "cuda") and is_sparse:
@@ -297,6 +307,7 @@ for target in ("cuda", "cpu", "xpu"):
     for is_sparse in (False, True):
         for is_parallel in (False, ):
             inject_test_method(target, is_sparse, is_parallel)
+inject_test_method("cpu", False, False, use_bf16=True)
 
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/tests/unittests/op_test.py
@@ -244,17 +244,12 @@ def convert_float_to_uint16(float_list, data_format="NCHW"):
     return new_output
 
 
-def copy_bits_from_uint16_to_float(i):
-    i = np.uint32(i) << 16
-    return struct.unpack('<f', struct.pack('<I', i))[0]
-
-
-def convert_uint16_to_float(uint16_list):
-    new_output = []
-    for x in np.nditer(uint16_list):
-        new_output.append(np.float32(copy_bits_from_uint16_to_float(x)))
-
-    return np.reshape(new_output, uint16_list.shape).view(np.float32)
+def convert_uint16_to_float(in_list):
+    in_list = np.asarray(in_list)
+    out = np.vectorize(
+        lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
+        otypes=[np.float32])(in_list.flat)
+    return np.reshape(out, in_list.shape)
 
 
 class OpTest(unittest.TestCase):
python/paddle/static/amp/__init__.py
@@ -14,5 +14,8 @@
 from ...fluid.contrib import mixed_precision
 from ...fluid.contrib.mixed_precision import *
+from ...fluid.contrib.mixed_precision import bf16
+from ...fluid.contrib.mixed_precision.bf16 import *
 
 __all__ = mixed_precision.__all__
+__all__ += bf16.__all__
python/setup.py.in
@@ -179,6 +179,7 @@ packages=['paddle',
           'paddle.fluid.contrib.utils',
           'paddle.fluid.contrib.extend_optimizer',
           'paddle.fluid.contrib.mixed_precision',
+          'paddle.fluid.contrib.mixed_precision.bf16',
           'paddle.fluid.contrib.layers',
           'paddle.fluid.transpiler',
           'paddle.fluid.transpiler.details',
tools/parallel_UT_rule.py
@@ -219,6 +219,7 @@ CPU_PARALLEL_JOB = [
     'test_full_op',
     'test_framework_debug_str',
     'test_fp16_utils',
+    'test_bf16_utils',
     'test_fleet_rolemaker_4',
     'test_flags_use_mkldnn',
     'test_filter_by_instag_op',
tools/static_mode_white_list.py
@@ -699,4 +699,5 @@ STATIC_MODE_TESTING_LIST = [
     'test_slice_op_xpu',
     'test_generate_proposals_v2_op',
     'test_lamb_op_xpu',
+    'test_model_cast_to_bf16',
 ]