Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
418edae5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
418edae5
编写于
12月 29, 2022
作者:
X
xu98bin
提交者:
GitHub
12月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
auto parallel bf16 (#49079)
* auto parallel bf16
上级
1078e064
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
943 addition
and
19 deletion
+943
-19
paddle/fluid/operators/collective/c_concat_op.cu.cc
paddle/fluid/operators/collective/c_concat_op.cu.cc
+3
-0
paddle/fluid/operators/collective/c_identity_op.cu.cc
paddle/fluid/operators/collective/c_identity_op.cu.cc
+3
-0
python/paddle/distributed/auto_parallel/constants.py
python/paddle/distributed/auto_parallel/constants.py
+7
-0
python/paddle/distributed/auto_parallel/operators/common.py
python/paddle/distributed/auto_parallel/operators/common.py
+4
-0
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+36
-18
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+9
-1
python/paddle/distributed/passes/__init__.py
python/paddle/distributed/passes/__init__.py
+1
-0
python/paddle/distributed/passes/auto_parallel_bf16.py
python/paddle/distributed/passes/auto_parallel_bf16.py
+661
-0
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py
...dle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py
+211
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
...ddle/fluid/tests/unittests/auto_parallel/test_strategy.py
+7
-0
未找到文件。
paddle/fluid/operators/collective/c_concat_op.cu.cc
浏览文件 @
418edae5
...
@@ -134,4 +134,7 @@ REGISTER_OP_CUDA_KERNEL(c_concat,
...
@@ -134,4 +134,7 @@ REGISTER_OP_CUDA_KERNEL(c_concat,
ops
::
CConcatOpCUDAKernel
<
double
>
,
ops
::
CConcatOpCUDAKernel
<
double
>
,
ops
::
CConcatOpCUDAKernel
<
int
>
,
ops
::
CConcatOpCUDAKernel
<
int
>
,
ops
::
CConcatOpCUDAKernel
<
int64_t
>
,
ops
::
CConcatOpCUDAKernel
<
int64_t
>
,
#if NCCL_VERSION_CODE >= 21000
ops
::
CConcatOpCUDAKernel
<
plat
::
bfloat16
>
,
#endif
ops
::
CConcatOpCUDAKernel
<
plat
::
float16
>
);
ops
::
CConcatOpCUDAKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/c_identity_op.cu.cc
浏览文件 @
418edae5
...
@@ -22,4 +22,7 @@ REGISTER_OP_CUDA_KERNEL(c_identity,
...
@@ -22,4 +22,7 @@ REGISTER_OP_CUDA_KERNEL(c_identity,
ops
::
CIdentityOpKernel
<
double
>
,
ops
::
CIdentityOpKernel
<
double
>
,
ops
::
CIdentityOpKernel
<
int
>
,
ops
::
CIdentityOpKernel
<
int
>
,
ops
::
CIdentityOpKernel
<
int64_t
>
,
ops
::
CIdentityOpKernel
<
int64_t
>
,
#if NCCL_VERSION_CODE >= 21000
ops
::
CIdentityOpKernel
<
plat
::
bfloat16
>
,
#endif
ops
::
CIdentityOpKernel
<
plat
::
float16
>
);
ops
::
CIdentityOpKernel
<
plat
::
float16
>
);
python/paddle/distributed/auto_parallel/constants.py
浏览文件 @
418edae5
...
@@ -76,6 +76,13 @@ set_field_default_config(AMP, "use_pure_fp16", False)
...
@@ -76,6 +76,13 @@ set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config
(
AMP
,
"use_fp16_guard"
,
True
)
set_field_default_config
(
AMP
,
"use_fp16_guard"
,
True
)
set_field_default_config
(
AMP
,
"use_optimizer_fp16"
,
False
)
set_field_default_config
(
AMP
,
"use_optimizer_fp16"
,
False
)
set_field_default_config
(
AMP
,
"enable_bf16"
,
False
)
set_field_default_config
(
AMP
,
"custom_bf16_list"
,
[])
set_field_default_config
(
AMP
,
"custom_fp32_list"
,
[])
set_field_default_config
(
AMP
,
"custom_fp32_varnames"
,
[])
set_field_default_config
(
AMP
,
"use_pure_bf16"
,
False
)
set_field_default_config
(
AMP
,
"use_bf16_guard"
,
False
)
#########################################
#########################################
# sharding configuration
# sharding configuration
#########################################
#########################################
...
...
python/paddle/distributed/auto_parallel/operators/common.py
浏览文件 @
418edae5
...
@@ -266,8 +266,12 @@ def is_parameter_related(varname, block):
...
@@ -266,8 +266,12 @@ def is_parameter_related(varname, block):
varname
=
varname
[:
varname
.
index
(
".subprog_"
)]
varname
=
varname
[:
varname
.
index
(
".subprog_"
)]
if
".cast_fp"
in
varname
:
if
".cast_fp"
in
varname
:
varname
=
varname
[:
varname
.
index
(
".cast_fp"
)]
varname
=
varname
[:
varname
.
index
(
".cast_fp"
)]
if
".cast_bf"
in
varname
:
varname
=
varname
[:
varname
.
index
(
".cast_bf"
)]
if
".quantized"
in
varname
:
if
".quantized"
in
varname
:
varname
=
varname
[:
varname
.
index
(
".quantized"
)]
varname
=
varname
[:
varname
.
index
(
".quantized"
)]
# if "@RESHARD" in varname:
# varname = varname[: varname.index("@RESHARD")]
assert
block
.
_find_var_recursive
(
varname
)
assert
block
.
_find_var_recursive
(
varname
)
var
=
block
.
_var_recursive
(
varname
)
var
=
block
.
_var_recursive
(
varname
)
return
var
.
is_parameter
return
var
.
is_parameter
...
...
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
418edae5
...
@@ -376,7 +376,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -376,7 +376,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
check_variable_and_dtype
(
check_variable_and_dtype
(
Out_grad
,
Out_grad
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
'_c_identity'
,
)
)
...
@@ -417,13 +417,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
...
@@ -417,13 +417,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
check_variable_and_dtype
(
check_variable_and_dtype
(
intermediate_var_0
,
intermediate_var_0
,
'x'
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
'linear'
,
)
)
check_dtype
(
check_dtype
(
intermediate_var_0
.
dtype
,
intermediate_var_0
.
dtype
,
'dtype'
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
'linear'
,
)
)
set_comm_op_dist_attr_for_program
(
set_comm_op_dist_attr_for_program
(
...
@@ -835,7 +835,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -835,7 +835,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
X_var
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
'_c_identity'
,
)
)
...
@@ -854,12 +854,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -854,12 +854,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
check_variable_and_dtype
(
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
check_dtype
(
check_dtype
(
intermediate_var_0
.
dtype
,
intermediate_var_0
.
dtype
,
'dtype'
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
'linear'
,
)
)
attrs
=
{
attrs
=
{
...
@@ -1183,10 +1186,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -1183,10 +1186,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
group
=
new_process_group
(
group_ranks
)
group
=
new_process_group
(
group_ranks
)
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
)
)
check_dtype
(
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
attrs
=
{
attrs
=
{
'transpose_X'
:
trans_x
,
'transpose_X'
:
trans_x
,
...
@@ -1731,7 +1737,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1731,7 +1737,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
X_var
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
'_c_identity'
,
)
)
c_identity_op
=
main_block
.
append_op
(
c_identity_op
=
main_block
.
append_op
(
...
@@ -1749,12 +1755,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
...
@@ -1749,12 +1755,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
check_variable_and_dtype
(
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
check_dtype
(
check_dtype
(
intermediate_var_0
.
dtype
,
intermediate_var_0
.
dtype
,
'dtype'
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
'linear'
,
)
)
attrs
=
{
attrs
=
{
...
@@ -2077,10 +2086,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
...
@@ -2077,10 +2086,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
group
=
new_process_group
(
group_ranks
)
group
=
new_process_group
(
group_ranks
)
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
)
)
check_dtype
(
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
attrs
=
{
attrs
=
{
'trans_x'
:
trans_x
,
'trans_x'
:
trans_x
,
...
@@ -2610,7 +2622,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2610,7 +2622,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
X_var
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
'_c_identity'
,
)
)
c_identity_op
=
main_block
.
append_op
(
c_identity_op
=
main_block
.
append_op
(
...
@@ -2628,12 +2640,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2628,12 +2640,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
check_variable_and_dtype
(
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
check_dtype
(
check_dtype
(
intermediate_var_0
.
dtype
,
intermediate_var_0
.
dtype
,
'dtype'
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
'linear'
,
)
)
# attrs = {'trans_x': False, 'trans_y': False}
# attrs = {'trans_x': False, 'trans_y': False}
...
@@ -2965,10 +2980,13 @@ class DistributedMulImpl1(DistributedOperatorImpl):
...
@@ -2965,10 +2980,13 @@ class DistributedMulImpl1(DistributedOperatorImpl):
group
=
new_process_group
(
group_ranks
)
group
=
new_process_group
(
group_ranks
)
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
)
)
check_dtype
(
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
# attrs = {'trans_x': False, 'trans_y': False}
# attrs = {'trans_x': False, 'trans_y': False}
attrs
=
{
attrs
=
{
...
...
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
418edae5
...
@@ -221,13 +221,21 @@ class Parallelizer:
...
@@ -221,13 +221,21 @@ class Parallelizer:
self
.
_dist_context
.
serial_feed_vars
[
"inputs"
]
self
.
_dist_context
.
serial_feed_vars
[
"inputs"
]
+
self
.
_dist_context
.
serial_feed_vars
[
"labels"
]
+
self
.
_dist_context
.
serial_feed_vars
[
"labels"
]
)
)
if
config
[
"use_pure_fp16"
]:
if
config
[
"enable_bf16"
]:
auto_parallel_bf16_pass
=
new_pass
(
"auto_parallel_bf16"
,
config
)
auto_parallel_bf16_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
loss
=
auto_parallel_bf16_pass
.
get_loss
()
elif
config
[
"use_pure_fp16"
]:
config
[
"base_opt"
]
=
optimizer
config
[
"base_opt"
]
=
optimizer
auto_parallel_fp16_pass
=
new_pass
(
"auto_parallel_fp16"
,
config
)
auto_parallel_fp16_pass
=
new_pass
(
"auto_parallel_fp16"
,
config
)
auto_parallel_fp16_pass
.
apply
(
auto_parallel_fp16_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
)
loss
=
auto_parallel_fp16_pass
.
get_loss
()
loss
=
auto_parallel_fp16_pass
.
get_loss
()
else
:
else
:
auto_parallel_amp_pass
=
new_pass
(
"auto_parallel_amp"
,
config
)
auto_parallel_amp_pass
=
new_pass
(
"auto_parallel_amp"
,
config
)
auto_parallel_amp_pass
.
apply
(
auto_parallel_amp_pass
.
apply
(
...
...
python/paddle/distributed/passes/__init__.py
浏览文件 @
418edae5
...
@@ -18,6 +18,7 @@ from .auto_parallel_gradient_merge import * # noqa: F403
...
@@ -18,6 +18,7 @@ from .auto_parallel_gradient_merge import * # noqa: F403
from
.auto_parallel_sharding
import
*
# noqa: F403
from
.auto_parallel_sharding
import
*
# noqa: F403
from
.auto_parallel_amp
import
*
# noqa: F403
from
.auto_parallel_amp
import
*
# noqa: F403
from
.auto_parallel_fp16
import
*
# noqa: F403
from
.auto_parallel_fp16
import
*
# noqa: F403
from
.auto_parallel_bf16
import
*
# noqa: F403
from
.auto_parallel_recompute
import
*
# noqa: F403
from
.auto_parallel_recompute
import
*
# noqa: F403
from
.auto_parallel_quantization
import
*
# noqa: F403
from
.auto_parallel_quantization
import
*
# noqa: F403
from
.auto_parallel_data_parallel_optimization
import
*
# noqa: F403
from
.auto_parallel_data_parallel_optimization
import
*
# noqa: F403
...
...
python/paddle/distributed/passes/auto_parallel_bf16.py
0 → 100644
浏览文件 @
418edae5
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
418edae5
...
@@ -128,5 +128,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
...
@@ -128,5 +128,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_cluster_partition MODULES test_cluster_partition
)
py_test_modules
(
test_cluster_partition MODULES test_cluster_partition
)
py_test_modules
(
test_convert_to_process_meshes MODULES
py_test_modules
(
test_convert_to_process_meshes MODULES
test_convert_to_process_meshes
)
test_convert_to_process_meshes
)
py_test_modules
(
test_pass_bf16 MODULES test_pass_bf16
)
py_test_modules
(
test_dist_saver MODULES test_dist_saver
)
py_test_modules
(
test_dist_saver MODULES test_dist_saver
)
endif
()
endif
()
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_bf16.py
0 → 100644
浏览文件 @
418edae5
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
random
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid.core
as
core
import
paddle.nn
as
nn
from
paddle.distributed.fleet
import
auto
from
paddle.fluid.contrib.mixed_precision.bf16.amp_utils
import
_valid_types
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
find_true_prev_op
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.static
import
InputSpec
from
paddle.vision.datasets
import
MNIST
paddle
.
enable_static
()
def
apply_pass
(
use_bf16
=
False
):
strategy
=
auto
.
Strategy
()
strategy
.
auto_mode
=
"semi"
strategy
.
reinit
=
True
if
use_bf16
:
amp
=
strategy
.
amp
amp
.
enable
=
True
amp
.
enable_bf16
=
True
return
strategy
class
MnistDataset
(
MNIST
):
def
__init__
(
self
,
mode
,
return_label
=
True
):
super
().
__init__
(
mode
=
mode
)
self
.
return_label
=
return_label
def
__getitem__
(
self
,
idx
):
img
=
np
.
reshape
(
self
.
images
[
idx
],
[
1
,
28
,
28
])
if
self
.
return_label
:
return
img
,
np
.
array
(
self
.
labels
[
idx
]).
astype
(
'int64'
)
return
(
img
,)
def
__len__
(
self
):
return
len
(
self
.
images
)
def
reset_prog
():
paddle
.
fluid
.
framework
.
switch_main_program
(
paddle
.
static
.
Program
())
paddle
.
fluid
.
framework
.
switch_startup_program
(
paddle
.
static
.
Program
())
class
Model
(
nn
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
self
.
flatten
=
nn
.
Flatten
()
self
.
fc1
=
nn
.
Linear
(
784
,
120
)
self
.
relu1
=
nn
.
ReLU
()
self
.
fc2
=
nn
.
Linear
(
120
,
10
)
def
forward
(
self
,
input
):
input
.
stop_gradient
=
True
x
=
self
.
flatten
(
input
)
x
=
self
.
relu1
(
self
.
fc1
(
x
))
x
=
self
.
fc2
(
x
)
return
x
class
TestBF16Pass
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
rtol
=
1e-5
self
.
atol
=
1e-8
self
.
batch_size
=
256
self
.
batch_num
=
10
self
.
dataset
=
MnistDataset
(
"train"
)
self
.
eval_dataset
=
MnistDataset
(
"test"
)
def
init
(
self
,
engine
):
paddle
.
seed
(
2021
)
np
.
random
.
seed
(
2021
)
random
.
seed
(
2021
)
place
=
paddle
.
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
engine
.
_executor
=
paddle
.
static
.
Executor
(
place
)
def
get_engine
(
self
,
use_bf16
=
False
):
reset_prog
()
strategy
=
apply_pass
(
use_bf16
)
model
=
Model
()
opt
=
paddle
.
optimizer
.
SGD
(
0.001
,
parameters
=
model
.
parameters
())
loss
=
nn
.
CrossEntropyLoss
()
engine
=
auto
.
Engine
(
model
,
loss
,
opt
,
strategy
=
strategy
)
self
.
init
(
engine
)
return
engine
def
check_program
(
self
,
program
):
bf16_op_list
=
{
"matmul_v2"
,
"elementwise_add"
,
"relu"
,
"elementwise_add_grad"
,
"matmul_v2_grad"
,
"relu_grad"
,
}
fp32_op_list
=
{
"flatten_contiguous_range"
,
"reduce_mean"
,
"softmax_with_cross_entropy"
,
"fill_constant"
,
"reduce_mean_grad"
,
"softmax_with_cross_entropy_grad"
,
}
for
block
in
program
.
blocks
:
for
op
in
block
.
ops
:
if
op
not
in
bf16_op_list
and
op
not
in
fp32_op_list
:
continue
for
in_name
in
op
.
input_names
:
for
in_var_name
in
op
.
input
(
in_name
):
var
=
None
try
:
var
=
block
.
var
(
in_var_name
)
except
ValueError
as
e
:
var
=
block
.
_var_recursive
(
in_var_name
)
if
var
is
None
or
var
.
type
not
in
_valid_types
:
break
if
op
.
type
in
bf16_op_list
:
assert
var
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
if
"cast_bf16"
in
in_var_name
:
if
"@GRAD"
in
in_var_name
:
tmp_in_var_name
=
in_var_name
[
:
in_var_name
.
find
(
"@GRAD"
)
]
else
:
tmp_in_var_name
=
in_var_name
prev_op
=
find_true_prev_op
(
block
.
ops
,
op
,
tmp_in_var_name
)
assert
prev_op
is
not
None
assert
prev_op
.
type
==
"cast"
for
in_name
in
prev_op
.
input_names
:
for
in_var_name
in
prev_op
.
input
(
in_name
):
var
=
block
.
var
(
in_var_name
)
assert
(
var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
)
elif
op
.
type
in
fp32_op_list
:
if
(
op
.
type
==
"softmax_with_cross_entropy"
or
op
.
type
==
"softmax_with_cross_entropy_grad"
)
and
in_var_name
==
"label0"
:
continue
assert
var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
if
"cast_fp32"
in
in_var_name
:
prev_op
=
find_true_prev_op
(
block
.
ops
,
op
,
tmp_in_var_name
)
assert
prev_op
is
not
None
assert
prev_op
.
type
==
"cast"
for
in_name
in
prev_op
.
input_names
:
for
in_var_name
in
prev_op
.
input
(
in_name
):
var
=
block
.
var
(
in_var_name
)
assert
(
var
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
)
for
out_name
in
op
.
output_names
:
for
out_var_name
in
op
.
output
(
out_name
):
var
=
None
try
:
var
=
block
.
var
(
out_var_name
)
except
ValueError
as
e
:
var
=
block
.
_var_recursive
(
out_var_name
)
if
var
is
None
or
var
.
type
not
in
_valid_types
:
break
if
op
.
type
in
bf16_op_list
:
assert
var
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
elif
op
.
type
in
fp32_op_list
:
assert
var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
def
test_bf16_pass
(
self
):
bf16_o1_engine
=
self
.
get_engine
(
True
)
inputs_spec
=
[
InputSpec
([
None
,
1
,
28
,
28
],
'float32'
,
'input0'
)]
labels_spec
=
[
InputSpec
([
None
,
1
],
'int64'
,
'label0'
)]
bf16_o1_engine
.
prepare
(
inputs_spec
=
inputs_spec
,
labels_spec
=
labels_spec
,
mode
=
"train"
)
self
.
check_program
(
bf16_o1_engine
.
_dist_main_progs
[
"train"
][
0
])
print
(
"BF16!check program successfully!"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
浏览文件 @
418edae5
...
@@ -41,6 +41,13 @@ class TestStrategy(unittest.TestCase):
...
@@ -41,6 +41,13 @@ class TestStrategy(unittest.TestCase):
self
.
assertEqual
(
amp
.
use_fp16_guard
,
True
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
True
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
False
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
False
)
self
.
assertEqual
(
amp
.
enable_bf16
,
False
)
self
.
assertEqual
(
amp
.
custom_bf16_list
,
[])
self
.
assertEqual
(
amp
.
custom_fp32_list
,
[])
self
.
assertEqual
(
amp
.
custom_fp32_varnames
,
[])
self
.
assertEqual
(
amp
.
use_pure_bf16
,
False
)
self
.
assertEqual
(
amp
.
use_bf16_guard
,
False
)
sharding
=
strategy
.
sharding
sharding
=
strategy
.
sharding
self
.
assertEqual
(
sharding
.
enable
,
False
)
self
.
assertEqual
(
sharding
.
enable
,
False
)
self
.
assertEqual
(
sharding
.
stage
,
1
)
self
.
assertEqual
(
sharding
.
stage
,
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录