Commit 418edae5 (unverified)
Authored by xu98bin on Dec 29, 2022; committed by GitHub on Dec 29, 2022
Parent commit: 1078e064

auto parallel bf16 (#49079)

* auto parallel bf16

Showing 11 changed files with 943 additions and 19 deletions (+943 -19)
Changed files:

paddle/fluid/operators/collective/c_concat_op.cu.cc                   +3    -0
paddle/fluid/operators/collective/c_identity_op.cu.cc                 +3    -0
python/paddle/distributed/auto_parallel/constants.py                  +7    -0
python/paddle/distributed/auto_parallel/operators/common.py           +4    -0
python/paddle/distributed/auto_parallel/operators/dist_matmul.py      +36   -18
python/paddle/distributed/auto_parallel/parallelizer_v2.py            +9    -1
python/paddle/distributed/passes/__init__.py                          +1    -0
python/paddle/distributed/passes/auto_parallel_bf16.py                +661  -0
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt      +1    -0
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py   +211  -0
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py    +7    -0
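Taken together, the commit wires a bf16 variant of auto-parallel mixed precision from the user-facing strategy down to a new program-rewriting pass. A minimal usage sketch, assuming the `auto.Strategy` / `auto.Engine` API exercised by the new unit test further down (the model, optimizer, and loss here are placeholders, not part of this commit):

# Hedged sketch based on test_pass_bf16.py below; model/opt/loss_fn are placeholders.
import paddle
import paddle.nn as nn
from paddle.distributed.fleet import auto

paddle.enable_static()

strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.amp.enable = True        # turn on the AMP pass family
strategy.amp.enable_bf16 = True   # new flag introduced by this commit

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())
loss_fn = nn.CrossEntropyLoss()

# The engine hands strategy.amp to parallelizer_v2.py, which now builds an
# "auto_parallel_bf16" pass when enable_bf16 is set.
engine = auto.Engine(model, loss_fn, opt, strategy=strategy)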
paddle/fluid/operators/collective/c_concat_op.cu.cc

@@ -134,4 +134,7 @@ REGISTER_OP_CUDA_KERNEL(c_concat,
                        ops::CConcatOpCUDAKernel<double>,
                        ops::CConcatOpCUDAKernel<int>,
                        ops::CConcatOpCUDAKernel<int64_t>,
+#if NCCL_VERSION_CODE >= 21000
+                       ops::CConcatOpCUDAKernel<plat::bfloat16>,
+#endif
                        ops::CConcatOpCUDAKernel<plat::float16>);
paddle/fluid/operators/collective/c_identity_op.cu.cc

@@ -22,4 +22,7 @@ REGISTER_OP_CUDA_KERNEL(c_identity,
                        ops::CIdentityOpKernel<double>,
                        ops::CIdentityOpKernel<int>,
                        ops::CIdentityOpKernel<int64_t>,
+#if NCCL_VERSION_CODE >= 21000
+                       ops::CIdentityOpKernel<plat::bfloat16>,
+#endif
                        ops::CIdentityOpKernel<plat::float16>);
python/paddle/distributed/auto_parallel/constants.py

@@ -76,6 +76,13 @@ set_field_default_config(AMP, "use_pure_fp16", False)
 set_field_default_config(AMP, "use_fp16_guard", True)
 set_field_default_config(AMP, "use_optimizer_fp16", False)
+set_field_default_config(AMP, "enable_bf16", False)
+set_field_default_config(AMP, "custom_bf16_list", [])
+set_field_default_config(AMP, "custom_fp32_list", [])
+set_field_default_config(AMP, "custom_fp32_varnames", [])
+set_field_default_config(AMP, "use_pure_bf16", False)
+set_field_default_config(AMP, "use_bf16_guard", False)

 #########################################
 # sharding configuration
 #########################################
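For reference, the new AMP defaults are readable straight off a strategy object; a small sketch mirroring the assertions added to test_strategy.py at the end of this commit (nothing here beyond what that test already checks):

# Sketch only; attribute names come from the constants above and test_strategy.py below.
from paddle.distributed.fleet import auto

amp = auto.Strategy().amp
assert amp.enable_bf16 == False
assert amp.custom_bf16_list == []
assert amp.custom_fp32_list == []
assert amp.custom_fp32_varnames == []
assert amp.use_pure_bf16 == False
assert amp.use_bf16_guard == False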
python/paddle/distributed/auto_parallel/operators/common.py

@@ -266,8 +266,12 @@ def is_parameter_related(varname, block):
        varname = varname[: varname.index(".subprog_")]
    if ".cast_fp" in varname:
        varname = varname[: varname.index(".cast_fp")]
+   if ".cast_bf" in varname:
+       varname = varname[: varname.index(".cast_bf")]
    if ".quantized" in varname:
        varname = varname[: varname.index(".quantized")]
    # if "@RESHARD" in varname:
    #     varname = varname[: varname.index("@RESHARD")]
    assert block._find_var_recursive(varname)
    var = block._var_recursive(varname)
    return var.is_parameter
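The extra branch matters because the forward pass below names its cast outputs "<var>.cast_" plus the target dtype string, so a bf16 cast of a parameter still resolves back to the parameter it came from. A minimal illustration with a hypothetical variable name:

# Hypothetical name; the ".cast_bf16" suffix is the pattern produced by
# _insert_cast_op_forward in the new auto_parallel_bf16 pass.
varname = "linear_0.w_0.cast_bf16"
if ".cast_bf" in varname:
    varname = varname[: varname.index(".cast_bf")]
print(varname)  # -> "linear_0.w_0", the original parameter name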
python/paddle/distributed/auto_parallel/operators/dist_matmul.py

The change is uniform: 'uint16' (the VarDesc dtype Paddle uses for bfloat16) is added to the accepted dtype lists of the check_variable_and_dtype / check_dtype calls in the matmul, matmul_v2 and mul distributed implementations.

@@ -376,7 +376,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
        check_variable_and_dtype(
            Out_grad,
            'tensor',
-           ['float16', 'float32', 'float64', 'int32', 'int64'],
+           ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )

@@ -417,13 +417,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
        check_variable_and_dtype(
            intermediate_var_0,
            'x',
-           ['float16', 'float32', 'float64'],
+           ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
-           ['float16', 'float32', 'float64'],
+           ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        set_comm_op_dist_attr_for_program(

@@ -835,7 +835,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
            check_variable_and_dtype(
                X_var,
                'tensor',
-               ['float16', 'float32', 'float64', 'int32', 'int64'],
+               ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
                '_c_identity',
            )

@@ -854,12 +854,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                intermediate_var_0.desc.set_shape(ref_shape_x)
            check_variable_and_dtype(
-               intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+               intermediate_var_0,
+               'x',
+               ['float16', 'float32', 'float64', 'uint16'],
+               'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
-               ['float16', 'float32', 'float64'],
+               ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            attrs = {

@@ -1183,10 +1186,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
            group = new_process_group(group_ranks)

            check_variable_and_dtype(
-               X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
+               X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
            )
            check_dtype(
-               X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
+               X_var.dtype,
+               'dtype',
+               ['float16', 'float32', 'float64', 'uint16'],
+               'linear',
            )
            attrs = {
                'transpose_X': trans_x,

@@ -1731,7 +1737,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
            check_variable_and_dtype(
                X_var,
                'tensor',
-               ['float16', 'float32', 'float64', 'int32', 'int64'],
+               ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
                '_c_identity',
            )
            c_identity_op = main_block.append_op(

@@ -1749,12 +1755,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                intermediate_var_0.desc.set_shape(ref_shape_x)
            check_variable_and_dtype(
-               intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+               intermediate_var_0,
+               'x',
+               ['float16', 'float32', 'float64', 'uint16'],
+               'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
-               ['float16', 'float32', 'float64'],
+               ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            attrs = {

@@ -2077,10 +2086,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
            group = new_process_group(group_ranks)

            check_variable_and_dtype(
-               X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
+               X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
            )
            check_dtype(
-               X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
+               X_var.dtype,
+               'dtype',
+               ['float16', 'float32', 'float64', 'uint16'],
+               'linear',
            )
            attrs = {
                'trans_x': trans_x,

@@ -2610,7 +2622,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
            check_variable_and_dtype(
                X_var,
                'tensor',
-               ['float16', 'float32', 'float64', 'int32', 'int64'],
+               ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
                '_c_identity',
            )
            c_identity_op = main_block.append_op(

@@ -2628,12 +2640,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                intermediate_var_0.desc.set_shape(ref_shape_x)
            check_variable_and_dtype(
-               intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+               intermediate_var_0,
+               'x',
+               ['float16', 'float32', 'float64', 'uint16'],
+               'linear',
            )
            check_dtype(
                intermediate_var_0.dtype,
                'dtype',
-               ['float16', 'float32', 'float64'],
+               ['float16', 'float32', 'float64', 'uint16'],
                'linear',
            )
            # attrs = {'trans_x': False, 'trans_y': False}

@@ -2965,10 +2980,13 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            group = new_process_group(group_ranks)

            check_variable_and_dtype(
-               X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
+               X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
            )
            check_dtype(
-               X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
+               X_var.dtype,
+               'dtype',
+               ['float16', 'float32', 'float64', 'uint16'],
+               'linear',
            )
            # attrs = {'trans_x': False, 'trans_y': False}
            attrs = {
python/paddle/distributed/auto_parallel/parallelizer_v2.py

@@ -221,13 +221,21 @@ class Parallelizer:
                    self._dist_context.serial_feed_vars["inputs"]
                    + self._dist_context.serial_feed_vars["labels"]
                )
-               if config["use_pure_fp16"]:
+               if config["enable_bf16"]:
+                   auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config)
+                   auto_parallel_bf16_pass.apply(
+                       [main_program], [startup_program], self._pass_context
+                   )
+                   loss = auto_parallel_bf16_pass.get_loss()
+
+               elif config["use_pure_fp16"]:
                    config["base_opt"] = optimizer
                    auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                    auto_parallel_fp16_pass.apply(
                        [main_program], [startup_program], self._pass_context
                    )
                    loss = auto_parallel_fp16_pass.get_loss()
+
                else:
                    auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                    auto_parallel_amp_pass.apply(
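In other words, when the AMP config carries the new enable_bf16 flag the parallelizer now prefers the bf16 pass, falling back to pure fp16 and then to the generic AMP pass. A condensed, hedged sketch of that selection (the function name and its arguments are placeholders for the parallelizer's internal state; only `new_pass` and the pass keys come from the diff):

# Sketch of the dispatch added above; config is the AMP section of the strategy.
from paddle.distributed.passes import new_pass

def pick_precision_pass(config, main_program, startup_program, pass_context):
    # enable_bf16 wins; the fp16 branch additionally stores the optimizer
    # under config["base_opt"], which the bf16 branch does not need.
    if config["enable_bf16"]:
        key = "auto_parallel_bf16"
    elif config["use_pure_fp16"]:
        key = "auto_parallel_fp16"
    else:
        key = "auto_parallel_amp"
    precision_pass = new_pass(key, config)
    precision_pass.apply([main_program], [startup_program], pass_context)
    return precision_pass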
python/paddle/distributed/passes/__init__.py

@@ -18,6 +18,7 @@ from .auto_parallel_gradient_merge import *  # noqa: F403
 from .auto_parallel_sharding import *  # noqa: F403
 from .auto_parallel_amp import *  # noqa: F403
 from .auto_parallel_fp16 import *  # noqa: F403
+from .auto_parallel_bf16 import *  # noqa: F403
 from .auto_parallel_recompute import *  # noqa: F403
 from .auto_parallel_quantization import *  # noqa: F403
 from .auto_parallel_data_parallel_optimization import *  # noqa: F403
python/paddle/distributed/passes/auto_parallel_bf16.py
new file (mode 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import static
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
    get_loss_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_var_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.passes.pass_base import PassBase, register_pass
from paddle.fluid import unique_name
from paddle.fluid.contrib.mixed_precision.bf16 import (
    AutoMixedPrecisionListsBF16,
)
from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import (
    _dtype_to_str,
    _is_in_fp32_varnames,
    _valid_types,
    find_op_index,
    find_true_post_op,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _rename_arg,
    find_true_prev_op,
)
from paddle.fluid.framework import Block
from paddle.framework import core

from ..auto_parallel.utils import is_backward_op, is_forward_op, is_loss_op

world_process_group = get_world_process_group()


class BF16State(object):
    def __init__(self, block):
        self._block: Block = block
        self._op_bf16_dict = {}
        self._var_name_dict = {}

    def _is_bf16_op(self, op_id):
        return self._op_bf16_dict.get(op_id, None)

    def _build_state(self, amp_lists, dist_context):
        ops = self._block.ops
        dist_op_context = dist_context.dist_op_context
        training = False
        for op in ops:
            if int(op.attr("op_role")) == 257:
                training = True
            if int(op.attr("op_role")) == int(OpRole.Forward):
                self._mark_black_white_op(amp_lists, op, ops)
            elif int(op.attr("op_role")) == int(OpRole.Backward):
                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[
                        op.desc.original_id()
                    ]
                    if self._is_bf16_op(fwd_op_id) is True:
                        self._op_bf16_dict[op.desc.original_id()] = True
                    elif self._is_bf16_op(fwd_op_id) is False:
                        self._op_bf16_dict[op.desc.original_id()] = False
            elif int(op.attr("op_role")) == int(OpRole.Optimize):
                break
        return training

    def _mark_black_white_op(self, amp_lists, op, ops):
        if op.type == "create_py_reader" or op.type == "read":
            return
        if amp_lists.fp32_varnames is not None and _is_in_fp32_varnames(
            op, amp_lists
        ):
            self._op_bf16_dict[op.desc.original_id()] = False
            return
        if op.type in amp_lists.bf16_list:
            self._op_bf16_dict[op.desc.original_id()] = True
        elif op.type in amp_lists.gray_list:
            is_fp32_op = False
            is_bf16_op = False
            for in_name in op.input_names:
                if in_name:
                    for in_var_name in op.input(in_name):
                        in_var = self._block.var(in_var_name)
                        if in_var.op is None:
                            continue
                        elif in_var.op is op:
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        if (
                            self._op_bf16_dict.get(
                                prev_op.desc.original_id(), False
                            )
                            is False
                            or prev_op.type in amp_lists.fp32_list
                        ):
                            is_fp32_op = True
                        elif (
                            self._op_bf16_dict.get(
                                prev_op.desc.original_id(), False
                            )
                            is True
                            or prev_op.type in amp_lists.bf16_list
                        ):
                            is_bf16_op = True
            if is_fp32_op:
                self._op_bf16_dict[op.desc.original_id()] = False
            elif is_bf16_op:
                self._op_bf16_dict[op.desc.original_id()] = True
            else:
                pass
        else:
            self._op_bf16_dict[op.desc.original_id()] = False

    def cast_forward_program(self, dist_context):
        ops = self._block.ops
        idx = 0
        while idx < len(ops):
            num_cast_ops = 0
            op = ops[idx]
            if int(op.attr('op_role')) == int(OpRole.Backward):
                break
            if self._is_bf16_op(op.desc.original_id()) is False:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.BF16,
                    core.VarDesc.VarType.FP32,
                    dist_context,
                )
            elif self._is_bf16_op(op.desc.original_id()) is True:
                if op.has_attr('use_mkldnn'):
                    op._set_attr('use_mkldnn', True)
                    op._set_attr('mkldnn_data_type', 'bfloat16')
                elif (
                    op.has_attr('dtype')
                    and op.attr('dtype') == core.VarDesc.VarType.FP32
                ):
                    op._set_attr('dtype', core.VarDesc.VarType.BF16)
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP32,
                    core.VarDesc.VarType.BF16,
                    dist_context,
                )
            else:
                pass
            idx += num_cast_ops + 1
        self._block._sync_with_cpp()

    def _insert_cast_op_forward(
        self, op, idx, src_dtype, dst_dtype, dist_context: DistributedContext
    ):
        num_cast_ops = 0
        var_name_dict = {}
        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
                'batch_norm',
                'fused_bn_add_activation',
                'layer_norm',
            ]:
                if in_name not in {'X', 'Z'}:
                    continue
            for in_var_name in op.input(in_name):
                in_var = self._block.var(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype == src_dtype:
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
                    var_name_dict[in_var.name] = cast_name
                    out_var = self._block.vars.get(cast_name)
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op
                    )
                    assert consume_op_attr is not None
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var_name
                    )
                    if out_var is None or out_var.dtype != dst_dtype:
                        assert in_var_dist_attr is not None
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                        out_var = self._block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, out_var, ref_mapping, ref_mesh
                        )
                        cast_op = self._block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": out_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": out_var.dtype,
                            },
                        )
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context
                        )
                        num_cast_ops += 1
                    else:
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                    _rename_arg(op, in_var_name, out_var.name)
                else:
                    if op.has_attr('in_dtype'):
                        op._set_attr('in_dtype', dst_dtype)
        self._var_name_dict[op.desc.original_id()] = var_name_dict

        if (
            src_dtype == core.VarDesc.VarType.FP32
            and dst_dtype == core.VarDesc.VarType.BF16
        ):
            for out_name in op.output_names:
                if (
                    op.type
                    in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']
                    and out_name != 'Y'
                ):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = self._block.var(out_var_name)
                    if out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
                        if op.has_attr('out_dtype'):
                            op._set_attr(
                                'out_dtype', core.VarDesc.VarType.BF16
                            )
        return num_cast_ops

    def cast_backward_program(self, params_grads, dist_context):
        self._block._sync_with_cpp()
        ops = self._block.ops

        appended_grad_times = 0
        dist_op_context = dist_context.dist_op_context
        loss_op = get_loss_op(self._block)
        idx = find_op_index(self._block.desc, loss_op.desc) + 1

        while idx < len(ops):
            num_cast_ops = 0
            grad_op = ops[idx]

            op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
            if is_backward_op(grad_op) and (
                is_forward_op(ops[idx - 1]) or is_loss_op(ops[idx - 1])
            ):
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            if (
                grad_op.desc.original_id()
                in dist_op_context.grad_op_id_to_op_id
            ):
                if self._is_bf16_op(grad_op.desc.original_id()) is False:
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.BF16,
                        core.VarDesc.VarType.FP32,
                        dist_context,
                        appended_grad_times,
                    )
                elif self._is_bf16_op(grad_op.desc.original_id()) is True:
                    if grad_op.has_attr('use_mkldnn'):
                        grad_op._set_attr('use_mkldnn', True)
                        grad_op._set_attr('mkldnn_data_type', 'bfloat16')
                    elif (
                        grad_op.has_attr('dtype')
                        and grad_op.attr('dtype') == core.VarDesc.VarType.FP32
                    ):
                        grad_op._set_attr('dtype', core.VarDesc.VarType.BF16)
                    num_cast_ops = self._insert_cast_op_backward(
                        grad_op,
                        idx,
                        core.VarDesc.VarType.FP32,
                        core.VarDesc.VarType.BF16,
                        dist_context,
                        appended_grad_times,
                    )
            elif grad_op.type == "sum":
                in_var_name = grad_op.desc.input_arg_names()[0]
                src_dtype = self._block.var(in_var_name).dtype
                for in_var_name in grad_op.desc.input_arg_names():
                    assert src_dtype == self._block.var(in_var_name).dtype
                out_var_name = grad_op.desc.output_arg_names()[0]
                out_var = self._block.var(out_var_name)
                if out_var.dtype != src_dtype:
                    out_var.desc.set_dtype(src_dtype)
            elif int(grad_op.attr("op_role")) == 257:
                pass
            else:
                raise ValueError(
                    "'{}' op is not supported in the complete amp pass.".format(
                        grad_op.type
                    )
                )
            idx += num_cast_ops + 1

        self._block._sync_with_cpp()
        _update_backward_cast_ops(params_grads, dist_context)

    def _insert_cast_op_backward(
        self,
        grad_op,
        idx,
        src_dtype,
        dst_dtype,
        dist_context,
        appended_grad_times,
    ):
        def _keep_fp32_input(op, in_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return in_name not in {'X', 'Y@GRAD'}
            return False

        def _keep_fp32_output(op, out_name):
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return out_name != 'X@GRAD'
            return False

        num_cast_ops = 0
        original_id = grad_op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

        for in_name in grad_op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                grad_op, in_name
            ):
                for in_var_name in grad_op.input(in_name):
                    in_var = self._block._find_var_recursive(in_var_name)
                    assert in_var.dtype == core.VarDesc.VarType.FP32
                continue

            for in_var_name in grad_op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.dtype == src_dtype:
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        grad_op
                    )
                    if in_var_name in self._var_name_dict[fwd_op_id]:
                        cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                        grad_op.desc._rename_input(in_var_name, cast_name)
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var_name
                        )
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )
                else:
                    assert (
                        in_var.dtype == dst_dtype
                    ), "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
                        grad_op.type,
                        in_name,
                        dst_dtype,
                        in_var.dtype,
                        str(grad_op),
                    )

        for out_name in grad_op.output_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
                grad_op, out_name
            ):
                for out_var_name in grad_op.output(out_name):
                    out_var = self._block._find_var_recursive(out_var_name)
                    assert out_var.dtype == core.VarDesc.VarType.FP32
                continue

            for out_var_name in grad_op.output(out_name):
                out_var = self._block._find_var_recursive(out_var_name)
                out_var_name_prefix = out_var_name[: out_var_name.find('@')]
                fwd_var = self._block._find_var_recursive(out_var_name_prefix)
                if out_var.dtype != fwd_var.dtype:
                    out_var.desc.set_dtype(fwd_var.dtype)

                if out_var.dtype == src_dtype:
                    if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
                        consume_op_attr = (
                            dist_context.get_op_dist_attr_for_program(grad_op)
                        )
                        fwd_cast_name = self._var_name_dict[fwd_op_id][
                            out_var_name_prefix
                        ]
                        suffix = ''
                        if "@RENAME" in out_var_name:
                            suffix = out_var_name[
                                out_var_name.find("@RENAME") :
                            ]
                        cast_name = fwd_cast_name + "@GRAD" + suffix
                        cast_var = self._block.vars.get(cast_name)
                        if cast_var is None or cast_var.dtype != dst_dtype:
                            grad_op.desc._rename_output(out_var_name, cast_name)
                            out_var_dist_attr = (
                                consume_op_attr.get_output_dist_attr(
                                    out_var_name
                                )
                            )
                            ref_mesh = out_var_dist_attr.process_mesh
                            ref_mapping = out_var_dist_attr.dims_mapping
                            consume_op_attr.set_output_dist_attr(
                                cast_name, out_var_dist_attr
                            )
                            assert ref_mapping is not None
                            cast_var = self._block.create_var(
                                name=cast_name,
                                shape=out_var.shape,
                                dtype=dst_dtype,
                                persistable=False,
                                stop_gradient=out_var.stop_gradient,
                            )
                            set_var_dist_attr(
                                dist_context, cast_var, ref_mapping, ref_mesh
                            )
                            dist_op_context.grad_var_to_var[
                                appended_grad_times
                            ][cast_name] = fwd_cast_name

                            cast_op = self._block._insert_op(
                                idx + 1,
                                type="cast",
                                inputs={"X": cast_var},
                                outputs={"Out": out_var},
                                attrs={
                                    "in_dtype": cast_var.dtype,
                                    "out_dtype": out_var.dtype,
                                    "op_role": OpRole.Backward,
                                },
                            )
                            cast_op._remove_attr("op_role_var")
                            cast_op._remove_attr("op_namescope")
                            cast_op._remove_attr("with_quant_attr")
                            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                                cast_op, ref_mesh, ref_mapping, dist_context
                            )
                            num_cast_ops += 1
                else:
                    assert out_var.dtype == dst_dtype
        return num_cast_ops


def _update_backward_cast_ops(params_grads, dist_context):
    """
    move param grad cast to the end of backward segment
    in order to enabel fp16 allreduce
    """
    # TODO filter optimize ops in future
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    for p, g in params_grads:
        op = g.op
        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
            if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
                'op_role_var'
            ):
                op._remove_attr("op_role_var")

            post_ops = find_true_post_op(main_block.ops, op, g.name)
            if post_ops:
                raise ValueError(
                    "The cast op {0}'s output should not be"
                    "used by a non-optimize op, however, it"
                    "is used by {1}".format(op, post_ops[0])
                )

            if op == main_block.ops[-1]:
                continue

            # add new op in the python and cpp at the same time
            new_op_desc = main_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op = paddle.fluid.framework.Operator(
                block=main_block,
                desc=new_op_desc,
                type=None,
                inputs=None,
                outputs=None,
                attrs=None,
            )
            main_block.ops.append(new_op)

            # dist attr
            param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
            output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                main_block.var(op.output_arg_names[0])
            )
            assert param_dist_attr is not None
            assert output_dist_attr is not None
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                new_op,
                param_dist_attr.process_mesh,
                param_dist_attr.dims_mapping,
                dist_context,
            )

            output_dist_attr.process_mesh = param_dist_attr.process_mesh
            output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

            op_idx = find_op_index(main_block.desc, op.desc)
            if op_idx == -1:
                raise ValueError("The op {0} is not in program".format(op))
            main_block._remove_op(op_idx, sync=False)

    main_block._sync_with_cpp()


@register_pass("auto_parallel_bf16")
class BF16Pass(PassBase):
    def __init__(self):
        super().__init__()
        self.set_attr("dist_context", None)
        self.set_attr("custom_bf16_list", None)
        self.set_attr("custom_fp32_list", None)
        self.set_attr("custom_fp32_varnames", None)
        self.set_attr("input_data", [])
        self.set_attr("loss", None)
        self.set_attr("params_grads", [])
        self.set_attr("use_bf16_guard", False)
        self._loss = None

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        amp_lists = AutoMixedPrecisionListsBF16(
            self.get_attr("custom_bf16_list"),
            self.get_attr("custom_fp32_list"),
            self.get_attr("custom_fp32_varnames"),
        )

        with static.program_guard(main_program, startup_program):
            amp_state = BF16State(main_program.global_block())
            training = amp_state._build_state(amp_lists, self.dist_context)
            amp_state.cast_forward_program(self.dist_context)

        if training:
            with paddle.static.program_guard(main_program, startup_program):
                amp_state.cast_backward_program(params_grads, self.dist_context)
                self._scale_loss()

    def _scale_loss(self):
        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()
        OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op
        )

        if loss.dtype != core.VarDesc.VarType.FP32:
            tmp_name = unique_name.generate(loss.name + ".cast_fp32")
            cast_loss = main_block.create_var(
                name=tmp_name, dtype=core.VarDesc.VarType.FP32
            )
            loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
                loss
            )
            ref_mesh = loss_op_dist_attr.process_mesh
            self.dist_context.set_tensor_dist_attr_for_program(
                cast_loss, loss_dist_attr
            )

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
            cast_op = main_block._insert_op(
                loss_op_idx + 1,
                type='cast',
                inputs={"X": [loss]},
                outputs={"Out": [cast_loss]},
                attrs={
                    "in_dtype": loss.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                    "op_role": loss_op.all_attrs()[OP_ROLE_KEY],
                },
            )
            loss_op._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_mesh, [-1], self.dist_context
            )

            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert (
                first_backward_op.type == "fill_constant"
                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            )
            cast_loss_grad = main_block.create_var(
                name=unique_name.generate(tmp_name + "@GRAD"),
                shape=loss.shape,
                dtype=core.VarDesc.VarType.FP32,
                persistable=loss.persistable,
            )
            set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)

            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
            cast_grad_op = main_block._insert_op(
                loss_op_idx + 3,
                type='cast',
                inputs={'X': [cast_loss_grad]},
                outputs={'Out': [pre_grad_name]},
                attrs={
                    "in_dtype": core.VarDesc.VarType.FP32,
                    "out_dtype": core.VarDesc.VarType.FP16,
                    'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_grad_op, ref_mesh, [-1], self.dist_context
            )
            loss = cast_loss
        self._loss = loss
        main_block._sync_with_cpp()

    def get_loss(self):
        if self._loss:
            return self._loss
        else:
            return self.get_attr("loss")

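Because the pass is registered under the key "auto_parallel_bf16", it can in principle be built and applied outside the parallelizer as well. A hedged sketch of stand-alone use, assuming the caller already holds the objects the auto-parallel engine normally provides (dist_context, loss variable, params_grads, the two programs, and a pass context; all of those names below are placeholders):

# Sketch only; every lowercase variable here stands for engine-provided state.
from paddle.distributed.passes import new_pass

config = {
    "dist_context": dist_context,      # required, checked by _check_self()
    "custom_bf16_list": None,          # extra op types to force into bf16
    "custom_fp32_list": None,          # op types to keep in fp32
    "custom_fp32_varnames": None,      # variable names to keep in fp32
    "loss": loss_var,
    "params_grads": params_grads,
    "use_bf16_guard": False,
}
bf16_pass = new_pass("auto_parallel_bf16", config)
bf16_pass.apply([main_prog], [startup_prog], pass_context)
loss_after_pass = bf16_pass.get_loss()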
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt

@@ -128,5 +128,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_cluster_partition MODULES test_cluster_partition)
   py_test_modules(test_convert_to_process_meshes MODULES
                   test_convert_to_process_meshes)
+  py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
   py_test_modules(test_dist_saver MODULES test_dist_saver)
 endif()
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py
new file (mode 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np

import paddle
import paddle.fluid.core as core
import paddle.nn as nn
from paddle.distributed.fleet import auto
from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import _valid_types
from paddle.fluid.contrib.mixed_precision.fp16_utils import find_true_prev_op
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.static import InputSpec
from paddle.vision.datasets import MNIST

paddle.enable_static()


def apply_pass(use_bf16=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_bf16:
        amp = strategy.amp
        amp.enable = True
        amp.enable_bf16 = True
    return strategy


class MnistDataset(MNIST):
    def __init__(self, mode, return_label=True):
        super().__init__(mode=mode)
        self.return_label = return_label

    def __getitem__(self, idx):
        img = np.reshape(self.images[idx], [1, 28, 28])
        if self.return_label:
            return img, np.array(self.labels[idx]).astype('int64')
        return (img,)

    def __len__(self):
        return len(self.images)


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class Model(nn.Layer):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 120)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(120, 10)

    def forward(self, input):
        input.stop_gradient = True
        x = self.flatten(input)
        x = self.relu1(self.fc1(x))
        x = self.fc2(x)
        return x


class TestBF16Pass(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.batch_size = 256
        self.batch_num = 10
        self.dataset = MnistDataset("train")
        self.eval_dataset = MnistDataset("test")

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_bf16=False):
        reset_prog()

        strategy = apply_pass(use_bf16)
        model = Model()
        opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())
        loss = nn.CrossEntropyLoss()
        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_program(self, program):
        bf16_op_list = {
            "matmul_v2",
            "elementwise_add",
            "relu",
            "elementwise_add_grad",
            "matmul_v2_grad",
            "relu_grad",
        }

        fp32_op_list = {
            "flatten_contiguous_range",
            "reduce_mean",
            "softmax_with_cross_entropy",
            "fill_constant",
            "reduce_mean_grad",
            "softmax_with_cross_entropy_grad",
        }

        for block in program.blocks:
            for op in block.ops:
                if op not in bf16_op_list and op not in fp32_op_list:
                    continue

                for in_name in op.input_names:
                    for in_var_name in op.input(in_name):
                        var = None
                        try:
                            var = block.var(in_var_name)
                        except ValueError as e:
                            var = block._var_recursive(in_var_name)
                        if var is None or var.type not in _valid_types:
                            break
                        if op.type in bf16_op_list:
                            assert var.dtype == core.VarDesc.VarType.BF16
                            if "cast_bf16" in in_var_name:
                                if "@GRAD" in in_var_name:
                                    tmp_in_var_name = in_var_name[
                                        : in_var_name.find("@GRAD")
                                    ]
                                else:
                                    tmp_in_var_name = in_var_name
                                prev_op = find_true_prev_op(
                                    block.ops, op, tmp_in_var_name
                                )
                                assert prev_op is not None
                                assert prev_op.type == "cast"
                                for in_name in prev_op.input_names:
                                    for in_var_name in prev_op.input(in_name):
                                        var = block.var(in_var_name)
                                        assert (
                                            var.dtype
                                            == core.VarDesc.VarType.FP32
                                        )
                        elif op.type in fp32_op_list:
                            if (
                                op.type == "softmax_with_cross_entropy"
                                or op.type == "softmax_with_cross_entropy_grad"
                            ) and in_var_name == "label0":
                                continue
                            assert var.dtype == core.VarDesc.VarType.FP32
                            if "cast_fp32" in in_var_name:
                                prev_op = find_true_prev_op(
                                    block.ops, op, tmp_in_var_name
                                )
                                assert prev_op is not None
                                assert prev_op.type == "cast"
                                for in_name in prev_op.input_names:
                                    for in_var_name in prev_op.input(in_name):
                                        var = block.var(in_var_name)
                                        assert (
                                            var.dtype
                                            == core.VarDesc.VarType.BF16
                                        )

                for out_name in op.output_names:
                    for out_var_name in op.output(out_name):
                        var = None
                        try:
                            var = block.var(out_var_name)
                        except ValueError as e:
                            var = block._var_recursive(out_var_name)
                        if var is None or var.type not in _valid_types:
                            break
                        if op.type in bf16_op_list:
                            assert var.dtype == core.VarDesc.VarType.BF16
                        elif op.type in fp32_op_list:
                            assert var.dtype == core.VarDesc.VarType.FP32

    def test_bf16_pass(self):
        bf16_o1_engine = self.get_engine(True)
        inputs_spec = [InputSpec([None, 1, 28, 28], 'float32', 'input0')]
        labels_spec = [InputSpec([None, 1], 'int64', 'label0')]
        bf16_o1_engine.prepare(
            inputs_spec=inputs_spec, labels_spec=labels_spec, mode="train"
        )
        self.check_program(bf16_o1_engine._dist_main_progs["train"][0])
        print("BF16!check program successfully!")


if __name__ == "__main__":
    unittest.main()

python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py

@@ -41,6 +41,13 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.use_fp16_guard, True)
         self.assertEqual(amp.use_optimizer_fp16, False)
+        self.assertEqual(amp.enable_bf16, False)
+        self.assertEqual(amp.custom_bf16_list, [])
+        self.assertEqual(amp.custom_fp32_list, [])
+        self.assertEqual(amp.custom_fp32_varnames, [])
+        self.assertEqual(amp.use_pure_bf16, False)
+        self.assertEqual(amp.use_bf16_guard, False)

         sharding = strategy.sharding
         self.assertEqual(sharding.enable, False)
         self.assertEqual(sharding.stage, 1)