PaddlePaddle / Paddle — commit 6f3c9643 (unverified)

Eb118 BF16 Adoption (#52827)
* pr1 * pr2 * pr3 * fixed unitest * adopt for scale

Authored by JZ-LIANG, committed via GitHub on Apr 14, 2023
Parent: 8cbc75ca

Showing 11 changed files with 1,878 additions and 970 deletions.
python/paddle/distributed/auto_parallel/constants.py (+4, -2)
python/paddle/distributed/auto_parallel/operators/dist_embedding.py (+7, -4)
python/paddle/distributed/auto_parallel/operators/dist_matmul.py (+1049, -631)
python/paddle/distributed/auto_parallel/parallelizer_v2.py (+14, -5)
python/paddle/distributed/passes/auto_parallel_amp.py (+347, -166)
python/paddle/distributed/passes/auto_parallel_fp16.py (+252, -156)
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt (+3, -0)
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py (+142, -0)
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py (+1, -1)
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py (+55, -0)
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py (+4, -5)
python/paddle/distributed/auto_parallel/constants.py

@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
 #########################################
 AMP = "amp"
 set_field_default_config(AMP, "enable", False)
+set_field_default_config(AMP, "dtype", "float16")
+set_field_default_config(AMP, "level", "o1")
 set_field_default_config(AMP, "init_loss_scaling", 32768.0)
 set_field_default_config(AMP, "incr_every_n_steps", 1000)
 set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
 set_field_default_config(AMP, "custom_white_list", [])
 set_field_default_config(AMP, "custom_black_list", [])
 set_field_default_config(AMP, "custom_black_varnames", [])
-set_field_default_config(AMP, "use_pure_fp16", False)
-set_field_default_config(AMP, "use_fp16_guard", True)
+set_field_default_config(AMP, "use_fp16_guard", False)
+set_field_default_config(AMP, "use_bf16_guard", False)
 set_field_default_config(AMP, "use_optimizer_fp16", False)
 #########################################
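The new "dtype", "level", and "use_bf16_guard" fields above are consumed through the auto-parallel Strategy object. A minimal sketch of how a user might turn on BF16 O2 mixed precision with these options; the model, loss, and dataset names are hypothetical placeholders, and the Engine wiring is assumed rather than taken from this commit:

    # Hedged sketch: enabling the bf16 options that this commit adds defaults for.
    # `MyNet` and `my_loss` are stand-ins, not part of the commit.
    import paddle
    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    amp = strategy.amp
    amp.enable = True
    amp.dtype = "bfloat16"      # new field, default "float16"
    amp.level = "o2"            # new field, default "o1"
    amp.use_bf16_guard = False  # new field, default False

    model = MyNet()
    optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
    engine = auto.Engine(model, loss=my_loss, optimizer=optimizer, strategy=strategy)
    # engine.fit(train_dataset, epochs=1, batch_size=8)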
python/paddle/distributed/auto_parallel/operators/dist_embedding.py

@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         check_variable_and_dtype(
             Out_var,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             'c_allreduce_sum',
         )
@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         check_variable_and_dtype(
             Out_grad,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )
@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
             },
         )
         check_variable_and_dtype(
-            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
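The recurring change in this file (and in dist_matmul.py below) is appending 'uint16' to dtype whitelists: Paddle's static-graph dtype checks report bfloat16 variables as 'uint16' (bf16 shares uint16 storage), so a check that should accept bf16 tensors has to list that string. A minimal sketch of the pattern; check_bf16_capable_input is a hypothetical wrapper written for illustration, not code from this commit:

    # Hedged sketch: illustrates why 'uint16' is appended to the whitelists in this diff.
    from paddle.fluid.data_feeder import check_variable_and_dtype


    def check_bf16_capable_input(var, op_name):
        # 'uint16' is how bfloat16 variables show up in these dtype checks, so adding it
        # lets bf16 inputs pass where previously only fp16/fp32/fp64 did.
        check_variable_and_dtype(
            var,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            op_name,
        )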
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
6f3c9643
...
...
@@ -20,7 +20,11 @@ from .common import DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl
from
.common
import
gradient_synchronization
from
.common
import
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
from
.common
import
(
set_comm_op_dist_attr_for_program
,
naive_copy_op_dist_attr_for_program
,
is_parameter_related
,
)
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_valid_list_index
...
...
@@ -33,24 +37,39 @@ from paddle.fluid import core, unique_name
from
paddle.fluid.framework
import
_non_static_mode
from
paddle.fluid.framework
import
Program
,
Parameter
,
Variable
,
program_guard
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
,
check_dtype
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
from
paddle.distributed.fleet.meta_optimizers.common
import
(
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
,
)
from
..process_group
import
new_process_group
from
..utils
import
_get_comm_group
,
_get_corresponding_rank
from
.dist_default
import
DistributedDefaultImpl0
from
..cost
import
build_comp_desc_from_dist_op
,
build_comm_desc_from_dist_op
,
build_dp_costs
from
..cost
import
(
build_comp_desc_from_dist_op
,
build_comm_desc_from_dist_op
,
build_dp_costs
,
)
from
..cost
import
build_comm_costs_from_descs
,
build_comp_costs_from_descs
from
..cost
import
MatmulV2OpCost
,
MatmulOpCost
,
MulOpCost
from
..cost
import
MatmulV2GradOpCost
,
MatmulGradOpCost
,
MulGradOpCost
from
paddle.distributed.auto_parallel.cost.comm_op_cost
import
AllreduceSumOpCost
,
IdentityOpCost
from
paddle.distributed.auto_parallel.cost.comm_op_cost
import
(
AllreduceSumOpCost
,
IdentityOpCost
,
)
def
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
):
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
(
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
],
)
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
(
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
],
)
def
copy_op_with_new_input_output
(
ctx
,
block
,
src_op
,
**
kwargs
):
...
...
@@ -123,13 +142,17 @@ def _update_dims_mapping_for_matmul(dist_op):
for
i
in
range
(
new_out_dims_mapping_len
-
2
):
broadcast_out_dims_mapping
.
append
(
out_dims_mapping
[
i
])
compatible_dims_mapping
=
compute_compatible_dims_mapping
([
broadcast_x_dims_mapping
,
broadcast_y_dims_mapping
,
broadcast_out_dims_mapping
])
compatible_dims_mapping
=
compute_compatible_dims_mapping
(
[
broadcast_x_dims_mapping
,
broadcast_y_dims_mapping
,
broadcast_out_dims_mapping
,
]
)
if
compatible_dims_mapping
is
None
:
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
return
False
for
i
in
range
(
new_x_dims_mapping_len
-
2
):
...
...
@@ -152,17 +175,20 @@ def _update_dims_mapping_for_matmul(dist_op):
# The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
dim_changed
=
compute_compatible_and_update_dim_mapping
(
[
x_dims_mapping
,
y_dims_mapping
],
[
-
1
,
-
2
])
[
x_dims_mapping
,
y_dims_mapping
],
[
-
1
,
-
2
]
)
if
dim_changed
:
changed
=
True
dim_changed
=
compute_compatible_and_update_dim_mapping
(
[
x_dims_mapping
,
out_dims_mapping
],
[
-
2
,
-
2
])
[
x_dims_mapping
,
out_dims_mapping
],
[
-
2
,
-
2
]
)
if
dim_changed
:
changed
=
True
dim_changed
=
compute_compatible_and_update_dim_mapping
(
[
y_dims_mapping
,
out_dims_mapping
],
[
-
1
,
-
1
])
[
y_dims_mapping
,
out_dims_mapping
],
[
-
1
,
-
1
]
)
if
dim_changed
:
changed
=
True
...
...
@@ -202,7 +228,8 @@ def _is_auto_compatible_for_matmul(dist_op):
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
out_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_output_dims_mapping
(
out_name
))
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
)
x_dims_mapping_len
=
len
(
x_dims_mapping
)
y_dims_mapping_len
=
len
(
y_dims_mapping
)
out_dims_mapping_len
=
len
(
out_dims_mapping
)
...
...
@@ -234,22 +261,23 @@ def _is_auto_compatible_for_matmul(dist_op):
for
i
in
range
(
out_dims_mapping_len
-
2
):
broadcast_out_dims_mapping
.
append
(
out_dims_mapping
[
i
])
is_same
=
((
broadcast_x_dims_mapping
==
broadcast_y_dims_mapping
)
and
(
broadcast_x_dims_mapping
==
broadcast_out_dims_mapping
))
is_same
=
(
broadcast_x_dims_mapping
==
broadcast_y_dims_mapping
)
and
(
broadcast_x_dims_mapping
==
broadcast_out_dims_mapping
)
if
not
is_same
:
return
False
# The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
is_same
=
(
x_dims_mapping
[
-
1
]
==
y_dims_mapping
[
-
2
])
is_same
=
x_dims_mapping
[
-
1
]
==
y_dims_mapping
[
-
2
]
if
not
is_same
:
return
False
is_same
=
(
x_dims_mapping
[
-
2
]
==
out_dims_mapping
[
-
2
])
is_same
=
x_dims_mapping
[
-
2
]
==
out_dims_mapping
[
-
2
]
if
not
is_same
:
return
False
is_same
=
(
y_dims_mapping
[
-
1
]
==
out_dims_mapping
[
-
1
])
is_same
=
y_dims_mapping
[
-
1
]
==
out_dims_mapping
[
-
1
]
if
not
is_same
:
return
False
...
...
@@ -265,8 +293,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
backward_op
=
dist_op_context
.
cur_src_op
rank_id
=
dist_op_context
.
rank_id
dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
backward_op
)
assert
dist_attr
is
not
None
,
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
backward_op
))
assert
(
dist_attr
is
not
None
),
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
backward_op
))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if
rank_id
not
in
dist_attr
.
process_mesh
.
processes
:
...
...
@@ -277,22 +306,26 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert
'Out@GRAD'
in
kwargs
,
"input [{}] is not given"
.
format
(
'Out@GRAD'
)
assert
'Y@GRAD'
in
kwargs
,
"output [{}] is not given"
.
format
(
'Y@GRAD'
)
assert
'X@GRAD'
in
kwargs
,
"output [{}] is not given"
.
format
(
'X@GRAD'
)
assert
len
(
assert
(
len
(
kwargs
[
'Y'
])
==
1
),
"row_parallel_embedding input Ids take 1 variable but got {}"
.
format
(
kwargs
[
'Y'
]
)
==
1
,
"row_parallel_embedding input Ids take 1 variable but got {}"
.
format
(
kwargs
[
'Y'
])
assert
len
(
)
assert
(
len
(
kwargs
[
'X'
])
==
1
),
"row_parallel_embedding input Ids take 1 variable but got {}"
.
format
(
kwargs
[
'X'
]
)
==
1
,
"row_parallel_embedding input Ids take 1 variable but got {}"
.
format
(
kwargs
[
'X'
])
assert
len
(
kwargs
[
'Out@GRAD'
]
)
==
1
,
"row_parallel_embedding input Ids take 1 variable but got {}"
.
format
(
kwargs
[
'Out'
])
assert
len
(
)
assert
(
len
(
kwargs
[
'Out@GRAD'
])
==
1
),
"row_parallel_embedding input Ids take 1 variable but got {}"
.
format
(
kwargs
[
'Out'
]
)
assert
(
len
(
kwargs
[
'Y@GRAD'
])
==
1
),
"row_parallel_embedding output Ids take 1 variable but got {}"
.
format
(
kwargs
[
'Y@GRAD'
]
)
==
1
,
"row_parallel_embedding output Ids take 1 variable but got {}"
.
format
(
kwargs
[
'Y@GRAD'
])
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Y_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
...
...
@@ -302,7 +335,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert
not
is_parameter_related
(
X_var
.
name
,
main_block
),
"left operand(X) [{}] of dist matmul should not be parameter"
.
format
(
X_var
.
name
)
X_var
.
name
)
X_var_dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
X_var
.
name
)
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var
.
name
)
...
...
@@ -339,28 +373,34 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
parallel_axis
=
Y_var_dim_mapping
[
0
]
check_variable_and_dtype
(
Out_grad
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'_c_identity'
)
Out_grad
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
)
intermediate_var_0
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
[
"c_identity"
,
'tmp'
]))
+
"@GRAD"
,
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
([
"c_identity"
,
'tmp'
])
)
+
"@GRAD"
,
dtype
=
Out_grad
.
dtype
,
shape
=
Out_grad
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
Out_grad
.
stop_gradient
)
stop_gradient
=
Out_grad
.
stop_gradient
,
)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
out_grad_dist_attr
=
dist_attr
.
get_input_dist_attr
(
Out_grad
.
name
)
assert
out_grad_dist_attr
is
not
None
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
out_grad_dist_attr
)
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
out_grad_dist_attr
)
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group
=
new_process_group
(
group_ranks
)
c_identity_op
=
main_block
.
append_op
(
type
=
'c_identity'
,
...
...
@@ -371,20 +411,29 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'use_calc_stream'
:
True
,
'use_model_parallel'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
,
})
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
set_comm_op_dist_attr_for_program
(
c_identity_op
,
dist_attr
.
process_mesh
,
out_grad_dist_attr
,
ctx
)
},
)
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
set_comm_op_dist_attr_for_program
(
c_identity_op
,
dist_attr
.
process_mesh
,
out_grad_dist_attr
,
ctx
)
new_kwargs
=
copy
.
deepcopy
(
kwargs
)
new_kwargs
[
'Out@GRAD'
]
=
[
intermediate_var_0
.
name
]
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
backward_op
,
**
new_kwargs
)
ctx
,
main_block
,
backward_op
,
**
new_kwargs
)
else
:
# col parallel: matmul + allreduce
assert
Y_var_dim_mapping
[
0
]
<
0
...
...
@@ -397,28 +446,36 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert
len
(
kwargs
[
'X@GRAD'
])
==
1
X_grad
=
main_block
.
var
(
kwargs
[
'X@GRAD'
][
0
])
intermediate_var_0
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
[
"c_identity"
,
'tmp'
]))
+
"@GRAD"
,
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
([
"c_identity"
,
'tmp'
])
)
+
"@GRAD"
,
dtype
=
X_grad
.
dtype
,
shape
=
X_grad
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
X_grad
.
stop_gradient
)
stop_gradient
=
X_grad
.
stop_gradient
,
)
X_grad_dist_attr
=
dist_attr
.
get_output_dist_attr
(
X_grad
.
name
)
assert
X_grad_dist_attr
is
not
None
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
X_grad_dist_attr
)
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
X_grad_dist_attr
)
new_kwargs
[
'X@GRAD'
]
=
[
intermediate_var_0
.
name
]
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
backward_op
,
**
new_kwargs
)
ctx
,
main_block
,
backward_op
,
**
new_kwargs
)
# NOTE (JZ-LIANG) trick to skip one allreduce if left operand has not grad
if
has_x_grad
:
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
,
)
group
=
new_process_group
(
group_ranks
)
c_allreduce_sum_op
=
main_block
.
append_op
(
type
=
'c_allreduce_sum'
,
...
...
@@ -428,15 +485,20 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'ring_id'
:
group
.
id
,
'use_calc_stream'
:
True
,
'use_model_parallel'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
set_comm_op_dist_attr_for_program
(
c_allreduce_sum_op
,
OP_ROLE_KEY
:
OpRole
.
Backward
,
},
)
set_comm_op_dist_attr_for_program
(
c_allreduce_sum_op
,
dist_attr
.
process_mesh
,
X_grad_dist_attr
,
ctx
)
X_grad_dist_attr
,
ctx
,
)
else
:
# replicate
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
backward_op
,
**
kwargs
)
matmul_op_desc
=
copy_op_with_new_input_output
(
ctx
,
main_block
,
backward_op
,
**
kwargs
)
# data parallel gradient synchronization
act_grad_names
=
[
X_var
.
name
]
...
...
@@ -448,8 +510,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
if
trans_x
:
trans_x_y_dims_mapping
(
True
,
False
,
X_var_dims_mapping
,
None
)
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
out_grad_names
,
rank_id
)
gradient_synchronization
(
ctx
,
backward_op
,
act_grad_names
,
out_grad_names
,
rank_id
)
if
trans_x
:
trans_x_y_dims_mapping
(
True
,
False
,
X_var_dims_mapping
,
None
)
...
...
@@ -472,23 +535,25 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
if
size
<=
1
or
axis
in
dim_mapping
:
pass
else
:
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
axis
,
rank_id
)
group_ranks
=
_get_comm_group
(
process_mesh
.
processes
,
process_mesh
.
topology
,
axis
,
rank_id
)
sync_group
=
new_process_group
(
group_ranks
)
startup_block
.
append_op
(
type
=
'c_broadcast'
,
startup_block
.
append_op
(
type
=
'c_broadcast'
,
inputs
=
{
'X'
:
param
},
outputs
=
{
'Out'
:
param
},
attrs
=
{
'ring_id'
:
sync_group
.
id
,
'root'
:
0
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
OP_ROLE_KEY
:
OpRole
.
Forward
,
},
)
class
DistributedMatmul
(
DistributedOperatorImplContainer
):
def
__init__
(
self
,
op_type
):
super
(
DistributedMatmul
,
self
).
__init__
(
op_type
)
...
...
@@ -498,7 +563,6 @@ register_distributed_operator_impl_container(DistributedMatmul("matmul"))
# ColumnParallel
class
DistributedMatmulImpl0
(
DistributedOperatorImpl
):
def
__init__
(
self
,
name
):
super
(
DistributedMatmulImpl0
,
self
).
__init__
(
name
)
self
.
_forward_implemented
=
True
...
...
@@ -521,7 +585,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
])
backward_op
.
input
(
"Y"
)[
0
]
)
# col parallel: matmul + allreduce
assert
Y_var_dim_mapping
[
0
]
<
0
parallel_axis
=
Y_var_dim_mapping
[
1
]
...
...
@@ -531,13 +596,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert
len
(
backward_op
.
output
(
"X@GRAD"
))
==
1
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# calc comm op cost
...
...
@@ -550,40 +616,52 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
)
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
,
)
res
.
append
(
comm_op_cost_list
)
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
backward_op
.
input
(
"X"
)[
0
]
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
if
(
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
)
):
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
# calc comm op cost
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
])[
-
1
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
1
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
input
(
"X"
)
c_identity_desc_mapping
=
build_comm_desc_from_dist_op
(
...
...
@@ -592,10 +670,12 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res_cost
=
[
comm_op_cost_list
,
cost_mapping
]
return
res_cost
...
...
@@ -606,16 +686,19 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
)
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
)
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
y_dims_mapping
[
-
1
]):
y_dims_mapping
[
-
1
]
):
return
False
for
mapping
in
x_dims_mapping
[
1
:
-
1
]:
if
is_dim_shard
(
mapping
):
...
...
@@ -635,8 +718,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return
True
def
is_auto_compatible
(
self
,
dist_op
):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
(
not
self
.
is_output_compatible
(
dist_op
)
):
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
...
...
@@ -661,28 +745,33 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
src_op
=
dist_op_context
.
cur_src_op
rank_id
=
dist_op_context
.
rank_id
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
assert
op_dist_attr
is
not
None
,
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
src_op
))
assert
(
op_dist_attr
is
not
None
),
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
src_op
))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if
rank_id
not
in
op_dist_attr
.
process_mesh
.
processes
:
rank_id
=
_get_corresponding_rank
(
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
)
rank_id
=
_get_corresponding_rank
(
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
)
# check validation of inputs / outputs
for
input_name
in
src_op
.
desc
.
input_names
():
assert
input_name
in
kwargs
,
"input [{}] is not given"
.
format
(
input_name
)
input_name
)
assert
len
(
kwargs
[
input_name
])
==
len
(
src_op
.
desc
.
input
(
input_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
input_name
)
for
output_name
in
src_op
.
desc
.
output_names
():
assert
output_name
in
kwargs
,
"input [{}] is not given"
.
format
(
output_name
)
output_name
)
assert
len
(
kwargs
[
output_name
])
==
len
(
src_op
.
desc
.
output
(
output_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
output_name
)
output_name
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
...
...
@@ -692,18 +781,24 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
Weight_var
.
name
)[
-
1
]
if
trans_y
:
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
Weight_var
.
name
)[
-
2
]
assert
(
matmul_col_dim_mapping
>=
0
),
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_group
=
op_dist_attr
.
process_mesh
.
processes
parallel_axis
=
matmul_col_dim_mapping
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group
=
new_process_group
(
group_ranks
)
# infer new var shape with op dist attr
...
...
@@ -711,31 +806,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert
x_tensor_dist_attr
is
not
None
identity_var_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
X_var
.
name
)
assert
identity_var_dist_attr
is
not
None
ref_shape_x
=
infer_shape
(
main_block
,
X_var
,
x_tensor_dist_attr
,
identity_var_dist_attr
)
ref_shape_x
=
infer_shape
(
main_block
,
X_var
,
x_tensor_dist_attr
,
identity_var_dist_attr
)
# infer out var shape with op dist attr
out_tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
Out_var
)
assert
out_tensor_dist_attr
is
not
None
out_var_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
Out_var
.
name
)
assert
out_var_dist_attr
is
not
None
ref_shape_out
=
infer_shape
(
main_block
,
Out_var
,
out_tensor_dist_attr
,
out_var_dist_attr
)
ref_shape_out
=
infer_shape
(
main_block
,
Out_var
,
out_tensor_dist_attr
,
out_var_dist_attr
)
intermediate_var_0
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
[
"c_identity"
,
'tmp'
])),
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
([
"c_identity"
,
'tmp'
])
),
dtype
=
X_var
.
dtype
,
shape
=
X_var
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
X_var
.
stop_gradient
)
stop_gradient
=
X_var
.
stop_gradient
,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
identity_var_dist_attr
)
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
identity_var_dist_attr
)
check_variable_and_dtype
(
X_var
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'_c_identity'
)
X_var
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
)
c_identity_op
=
main_block
.
append_op
(
type
=
'c_identity'
,
...
...
@@ -745,26 +848,34 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'ring_id'
:
group
.
id
,
'use_calc_stream'
:
True
,
'use_model_parallel'
:
True
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
})
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
),
},
)
if
intermediate_var_0
.
shape
!=
ref_shape_x
:
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
attrs
=
{
'transpose_X'
:
trans_x
,
'transpose_Y'
:
trans_y
,
'alpha'
:
1
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
,
}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
Out_var
},
attrs
=
attrs
)
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
Out_var
},
attrs
=
attrs
)
if
Out_var
.
shape
!=
ref_shape_out
:
Out_var
.
desc
.
set_shape
(
ref_shape_out
)
...
...
@@ -778,13 +889,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
input_varname
=
c_identity_op
.
desc
.
input_arg_names
()[
0
]
input_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
input_varname
)
assert
input_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
identity_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
op_dist_attr
)
identity_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
# output
output_varname
=
c_identity_op
.
desc
.
output_arg_names
()[
0
]
identity_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
input_dist_attr
)
identity_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
input_dist_attr
)
# set op dist attr
ctx
.
set_op_dist_attr_for_program
(
c_identity_op
,
identity_op_dist_attr
)
...
...
@@ -797,31 +911,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
for
input_varname
in
matmul_op
.
desc
.
input_arg_names
():
if
input_varname
in
src_op
.
desc
.
input_arg_names
():
input_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
input_varname
)
input_varname
)
assert
input_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
op_dist_attr
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
else
:
input_var
=
main_block
.
var
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
tensor_dist_attr
)
input_var
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
tensor_dist_attr
)
# output
output_varname
=
matmul_op
.
desc
.
output_arg_names
()[
0
]
output_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
output_varname
)
assert
output_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
matmul_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
output_dist_attr
)
op_dist_attr
)
matmul_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
output_dist_attr
)
# set op dist attr
ctx
.
set_op_dist_attr_for_program
(
matmul_op
,
matmul_op_dist_attr
)
# init param sync
if
Weight_var
.
is_parameter
and
not
op_dist_attr
.
is_recompute
:
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
)
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
)
@
staticmethod
def
backward
(
ctx
,
*
args
,
**
kwargs
):
...
...
@@ -830,7 +952,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# RowParallel
class
DistributedMatmulImpl1
(
DistributedOperatorImpl
):
def
__init__
(
self
,
name
):
super
(
DistributedMatmulImpl1
,
self
).
__init__
(
name
)
self
.
_forward_implemented
=
True
...
...
@@ -853,7 +974,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
])
backward_op
.
input
(
"Y"
)[
0
]
)
assert
Y_var_dim_mapping
[
1
]
<
0
parallel_axis
=
Y_var_dim_mapping
[
0
]
...
...
@@ -866,50 +988,60 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res
.
append
(
comm_op_cost_list
)
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
backward_op
.
input
(
"X"
)[
0
]
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
if
(
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
)
):
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
# calc comm op cost
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
])[
-
2
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
2
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
output
(
"Out"
)
...
...
@@ -919,11 +1051,16 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
)
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
,
)
res_cost
=
[
cost_mapping
,
comm_op_cost_list
]
...
...
@@ -935,16 +1072,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
)
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
)
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_replicate
(
x_dims_mapping
[
-
1
]):
return
False
if
is_dim_replicate
(
y_dims_mapping
[
-
2
])
or
is_dim_shard
(
y_dims_mapping
[
-
1
]):
y_dims_mapping
[
-
1
]
):
return
False
# Other dimensions must be replicate except the batch dimension
for
mapping
in
x_dims_mapping
[
1
:
-
1
]:
...
...
@@ -966,8 +1106,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return
True
def
is_auto_compatible
(
self
,
dist_op
):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
(
not
self
.
is_output_compatible
(
dist_op
)
):
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
...
...
@@ -992,28 +1133,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
src_op
=
dist_op_context
.
cur_src_op
rank_id
=
dist_op_context
.
rank_id
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
assert
op_dist_attr
is
not
None
,
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
src_op
))
assert
(
op_dist_attr
is
not
None
),
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
src_op
))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if
rank_id
not
in
op_dist_attr
.
process_mesh
.
processes
:
rank_id
=
_get_corresponding_rank
(
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
)
rank_id
=
_get_corresponding_rank
(
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
)
# check validation of inputs / outputs
for
input_name
in
src_op
.
desc
.
input_names
():
assert
input_name
in
kwargs
,
"input [{}] is not given"
.
format
(
input_name
)
input_name
)
assert
len
(
kwargs
[
input_name
])
==
len
(
src_op
.
desc
.
input
(
input_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
input_name
)
for
output_name
in
src_op
.
desc
.
output_names
():
assert
output_name
in
kwargs
,
"input [{}] is not given"
.
format
(
output_name
)
output_name
)
assert
len
(
kwargs
[
output_name
])
==
len
(
src_op
.
desc
.
output
(
output_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
output_name
)
output_name
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
...
...
@@ -1023,29 +1169,40 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# TODO infer logic comm presentation
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
Weight_var
.
name
)[
-
2
]
if
trans_y
:
matmul_row_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
assert
matmul_row_dim_mapping
>=
0
,
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_row_dim_mapping
)
Weight_var
.
name
)[
-
1
]
assert
(
matmul_row_dim_mapping
>=
0
),
"row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_row_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_group
=
op_dist_attr
.
process_mesh
.
processes
parallel_axis
=
matmul_row_dim_mapping
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group
=
new_process_group
(
group_ranks
)
check_variable_and_dtype
(
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
check_variable_and_dtype
(
X_var
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
)
check_dtype
(
X_var
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
attrs
=
{
'transpose_X'
:
trans_x
,
'transpose_Y'
:
trans_y
,
'alpha'
:
1
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
,
}
inputs
=
{
'X'
:
X_var
,
'Y'
:
Weight_var
}
...
...
@@ -1054,27 +1211,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
assert
out_tensor_dist_attr
is
not
None
out_var_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
Out_var
.
name
)
assert
out_var_dist_attr
is
not
None
ref_shape
=
infer_shape
(
main_block
,
Out_var
,
out_tensor_dist_attr
,
out_var_dist_attr
)
ref_shape
=
infer_shape
(
main_block
,
Out_var
,
out_tensor_dist_attr
,
out_var_dist_attr
)
intermediate_var_0
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
[
"c_allreduce_sum"
,
'tmp'
])),
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
([
"c_allreduce_sum"
,
'tmp'
])
),
shape
=
Out_var
.
shape
,
dtype
=
Out_var
.
dtype
,
type
=
Out_var
.
type
,
lod_level
=
Out_var
.
lod_level
,
persistable
=
False
,
is_data
=
False
,
need_check_feed
=
Out_var
.
desc
.
need_check_feed
())
need_check_feed
=
Out_var
.
desc
.
need_check_feed
(),
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
out_var_dist_attr
)
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
out_var_dist_attr
)
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
intermediate_var_0
},
attrs
=
attrs
)
attrs
=
attrs
,
)
if
intermediate_var_0
.
shape
!=
ref_shape
:
intermediate_var_0
.
desc
.
set_shape
(
ref_shape
)
...
...
@@ -1086,8 +1249,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
'ring_id'
:
group
.
id
,
'use_calc_stream'
:
True
,
'use_model_parallel'
:
True
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
})
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
),
},
)
if
Out_var
.
shape
!=
ref_shape
:
Out_var
.
desc
.
set_shape
(
ref_shape
)
...
...
@@ -1100,15 +1264,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
for
input_varname
in
matmul_op
.
desc
.
input_arg_names
():
input_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
input_varname
)
assert
input_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
op_dist_attr
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
output_varname
=
matmul_op
.
desc
.
output_arg_names
()[
0
]
output_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
Out_var
.
name
)
assert
output_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
matmul_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
output_dist_attr
)
op_dist_attr
)
matmul_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
output_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
matmul_op
,
matmul_op_dist_attr
)
# allreduce
...
...
@@ -1120,21 +1288,26 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
input_var
=
main_block
.
var
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
assert
tensor_dist_attr
is
not
None
allreduce_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
tensor_dist_attr
)
allreduce_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
tensor_dist_attr
)
for
output_varname
in
c_allreduce_sum_op
.
desc
.
output_arg_names
():
output_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
output_varname
)
assert
output_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
allreduce_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
output_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
c_allreduce_sum_op
,
allreduce_op_dist_attr
)
op_dist_attr
)
allreduce_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
output_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
c_allreduce_sum_op
,
allreduce_op_dist_attr
)
# init param sync
if
Weight_var
.
is_parameter
and
not
op_dist_attr
.
is_recompute
:
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
)
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
)
@
staticmethod
def
backward
(
ctx
,
*
args
,
**
kwargs
):
...
...
@@ -1143,7 +1316,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# ReplicateParallel
class
DistributedMatmulImpl2
(
DistributedOperatorImpl
):
def
__init__
(
self
,
name
):
super
(
DistributedMatmulImpl2
,
self
).
__init__
(
name
)
...
...
@@ -1164,38 +1336,45 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
vars
=
main_block
.
vars
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
backward_op
.
input
(
"X"
)[
0
]
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
if
(
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
)
):
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res_cost
=
[
cost_mapping
]
return
res_cost
...
...
@@ -1211,13 +1390,15 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
if
is_valid_list_index
(
x_dims_mapping
,
-
2
)
and
is_dim_shard
(
x_dims_mapping
[
-
2
]):
x_dims_mapping
[
-
2
]
):
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
1
]):
return
False
if
is_valid_list_index
(
y_dims_mapping
,
-
2
)
and
is_dim_shard
(
y_dims_mapping
[
-
2
]):
y_dims_mapping
[
-
2
]
):
return
False
return
True
...
...
@@ -1231,14 +1412,16 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
if
is_dim_shard
(
out_dims_mapping
[
-
1
]):
return
False
if
is_valid_list_index
(
out_dims_mapping
,
-
2
)
and
is_dim_shard
(
out_dims_mapping
[
-
2
]):
out_dims_mapping
[
-
2
]
):
return
False
return
True
def
is_auto_compatible
(
self
,
dist_op
):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
(
not
self
.
is_output_compatible
(
dist_op
)):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
(
not
self
.
is_output_compatible
(
dist_op
)
):
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
...
...
@@ -1262,16 +1445,18 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
_right_operand_parameter_matmul_backward
(
ctx
,
*
args
,
**
kwargs
)
register_distributed_operator_impl
(
"matmul"
,
DistributedMatmulImpl0
(
"column_parallel"
))
register_distributed_operator_impl
(
"matmul"
,
DistributedMatmulImpl1
(
"row_parallel"
))
register_distributed_operator_impl
(
"matmul"
,
DistributedMatmulImpl2
(
"replicate_parallel"
))
register_distributed_operator_impl
(
"matmul"
,
DistributedMatmulImpl0
(
"column_parallel"
)
)
register_distributed_operator_impl
(
"matmul"
,
DistributedMatmulImpl1
(
"row_parallel"
)
)
register_distributed_operator_impl
(
"matmul"
,
DistributedMatmulImpl2
(
"replicate_parallel"
)
)
class
DistributedMatmulV2
(
DistributedOperatorImplContainer
):
def
__init__
(
self
,
op_type
):
super
(
DistributedMatmulV2
,
self
).
__init__
(
op_type
)
...
...
@@ -1281,7 +1466,6 @@ register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))
# ColumnParallel
class
DistributedMatmulV2Impl0
(
DistributedOperatorImpl
):
def
__init__
(
self
,
name
):
super
(
DistributedMatmulV2Impl0
,
self
).
__init__
(
name
)
self
.
_forward_implemented
=
True
...
...
@@ -1304,7 +1488,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
])
backward_op
.
input
(
"Y"
)[
0
]
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
# col parallel: matmul + allreduce
...
...
@@ -1318,12 +1503,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
assert
len
(
backward_op
.
output
(
"X@GRAD"
))
==
1
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulV2GradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulV2GradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# calc comm op cost
...
...
@@ -1336,45 +1522,55 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
)
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
,
)
res
.
append
(
comm_op_cost_list
)
# need gradient allreduce
process_mesh
=
dist_attr
.
process_mesh
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
backward_op
.
input
(
"X"
)[
0
]
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
if
(
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
)
):
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
# TODO: trans shape if trans_x or trans_y is True
comp_desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
comp_desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
comp_cost_mapping
=
build_comp_costs_from_descs
(
MatmulV2OpCost
,
ctx
,
processes
,
comp_desc_mapping
,
cluster
)
comp_cost_mapping
=
build_comp_costs_from_descs
(
MatmulV2OpCost
,
ctx
,
processes
,
comp_desc_mapping
,
cluster
)
# calc comm op cost
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
])[
-
1
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
1
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
input
(
"X"
)
...
...
@@ -1384,9 +1580,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res_cost
=
[
comm_op_cost_list
,
comp_cost_mapping
]
return
res_cost
...
...
@@ -1397,16 +1595,19 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
)
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
)
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
y_dims_mapping
[
-
1
]):
y_dims_mapping
[
-
1
]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
...
...
@@ -1426,8 +1627,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
...
...
@@ -1452,28 +1654,33 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
...
...
@@ -1483,18 +1690,24 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
...
...
@@ -1502,31 +1715,39 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

-        check_variable_and_dtype(X_var, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
+        check_variable_and_dtype(
+            X_var,
+            'tensor',
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
+            '_c_identity',
+        )

        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
...
...
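A note on the 'uint16' entries that this commit adds to the dtype whitelists above: in the static-graph checks, a bfloat16 tensor is reported with the dtype string 'uint16', because a bf16 value is just the upper 16 bits of a float32 word. The following is a minimal NumPy illustration of that bit relationship only; it is not Paddle code, and the helper names are invented for this sketch.

    import numpy as np

    def f32_to_bf16_bits(x):
        # keep the upper 16 bits of each float32 word (plain truncation;
        # production kernels usually apply round-to-nearest-even)
        assert x.dtype == np.float32
        return (x.view(np.uint32) >> 16).astype(np.uint16)

    def bf16_bits_to_f32(bits):
        # re-expand the stored 16 bits into a float32 with a zeroed mantissa tail
        assert bits.dtype == np.uint16
        return (bits.astype(np.uint32) << 16).view(np.float32)

    x = np.array([1.0, 3.141592, -2.5e7], dtype=np.float32)
    print(bf16_bits_to_f32(f32_to_bf16_bits(x)))  # values rounded to bf16 precision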
@@ -1536,24 +1757,35 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

-        check_variable_and_dtype(intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear')
-        check_dtype(intermediate_var_0.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
+        check_variable_and_dtype(
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
+        )
+        check_dtype(
+            intermediate_var_0.dtype,
+            'dtype',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
+        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': Out_var},
            attrs=attrs,
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)
...
...
@@ -1567,13 +1799,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
        # matmulv2
...
...
@@ -1584,29 +1819,37 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
...
@@ -1615,7 +1858,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl1, self).__init__(name)
        self._forward_implemented = True
...
...
@@ -1638,7 +1880,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]
...
...
@@ -1653,50 +1896,59 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(
                backward_op.input("Y")[0], main_block
            )
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )
        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
...
...
@@ -1706,11 +1958,16 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res_cost = [cost_mapping, comm_op_cost_list]
        return res_cost
...
...
@@ -1721,16 +1978,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')
        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
...
...
@@ -1752,8 +2012,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
...
...
@@ -1778,28 +2039,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
...
...
@@ -1809,28 +2075,39 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

-        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
-        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
+        check_variable_and_dtype(
+            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
+        )
+        check_dtype(
+            X_var.dtype,
+            'dtype',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
+        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}
...
...
@@ -1839,27 +2116,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
...
...
@@ -1871,8 +2154,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)
...
...
@@ -1885,15 +2169,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
...
...
@@ -1905,21 +2193,26 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
...
@@ -1928,7 +2221,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl2, self).__init__(name)
...
...
@@ -1950,38 +2242,44 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
        process_mesh = dist_attr.process_mesh
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(
                backward_op.input("Y")[0], main_block
            )
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )
        res_cost = [cost_mapping]
...
...
@@ -1998,13 +2296,15 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True
...
...
@@ -2019,14 +2319,16 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
...
@@ -2050,16 +2352,18 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMul, self).__init__(op_type)
...
...
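For readers following the "column_parallel" implementations above: the input X stays replicated (the inserted c_identity op) and the weight is sharded along its last axis, so each rank produces a slice of the output columns. A small NumPy sketch of that partitioning idea follows; it only illustrates the math, assumes a mesh of two ranks, and is not the Paddle kernels themselves.

    import numpy as np

    np.random.seed(0)
    X = np.random.rand(4, 8).astype(np.float32)   # replicated input
    W = np.random.rand(8, 6).astype(np.float32)   # weight, sharded on the column axis

    W_shards = np.split(W, 2, axis=1)             # "rank 0" and "rank 1" each hold an 8x3 slice
    partial_outs = [X @ w for w in W_shards]      # each rank computes 4x3 output columns

    # gathering the column slices reproduces the serial result
    assert np.allclose(np.concatenate(partial_outs, axis=1), X @ W)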
@@ -2069,7 +2373,6 @@ register_distributed_operator_impl_container(DistributedMul("mul"))
# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl0, self).__init__(name)
        self._forward_implemented = True
...
...
@@ -2092,7 +2395,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]
...
...
@@ -2102,13 +2406,14 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
...
...
@@ -2121,40 +2426,52 @@ class DistributedMulImpl0(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(
                backward_op.input("Y")[0], main_block
            )
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )
        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
...
...
@@ -2163,10 +2480,12 @@ class DistributedMulImpl0(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]
        return res_cost
...
...
@@ -2181,7 +2500,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
...
...
@@ -2201,8 +2521,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
...
@@ -2229,28 +2550,33 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
...
...
@@ -2258,15 +2584,20 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
...
...
@@ -2274,31 +2605,39 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

-        check_variable_and_dtype(X_var, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
+        check_variable_and_dtype(
+            X_var,
+            'tensor',
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
+            '_c_identity',
+        )
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
...
...
@@ -2307,20 +2646,29 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

-        check_variable_and_dtype(intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear')
-        check_dtype(intermediate_var_0.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
+        check_variable_and_dtype(
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
+        )
+        check_dtype(
+            intermediate_var_0.dtype,
+            'dtype',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
+        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': intermediate_var_0, 'Y': Weight_var}
...
...
@@ -2334,16 +2682,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)
...
...
@@ -2362,13 +2709,16 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)
        # matmulv2
...
...
@@ -2379,29 +2729,37 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
...
@@ -2410,7 +2768,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl1, self).__init__(name)
        self._forward_implemented = True
...
...
@@ -2434,7 +2791,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]
...
...
@@ -2447,49 +2805,59 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(
                backward_op.input("Y")[0], main_block
            )
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )
        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
...
...
@@ -2499,12 +2867,17 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        # print("dist_matmul.py dist_op: ", dist_op)
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res_cost = [cost_mapping, comm_op_cost_list]
...
...
@@ -2520,7 +2893,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
...
...
@@ -2542,8 +2916,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
...
@@ -2570,28 +2945,33 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
...
...
@@ -2599,26 +2979,36 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

-        check_variable_and_dtype(X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
-        check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
+        check_variable_and_dtype(
+            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
+        )
+        check_dtype(
+            X_var.dtype,
+            'dtype',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
+        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}
...
...
@@ -2627,22 +3017,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        inputs_ref_shape = {}
        inputs_original_shape = {}
...
...
@@ -2651,16 +3045,18 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
...
...
@@ -2678,8 +3074,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)
...
...
@@ -2693,15 +3090,19 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
...
...
@@ -2713,21 +3114,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
...
@@ -2736,7 +3142,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl2, self).__init__(name)
...
...
@@ -2757,38 +3162,45 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        vars = main_block.vars

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(
                backward_op.input("Y")[0], main_block
            )
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )
        res_cost = [cost_mapping]
        return res_cost
...
...
@@ -2804,12 +3216,14 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True
...
...
@@ -2824,14 +3238,16 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
...
@@ -2855,8 +3271,10 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl(
    "mul", DistributedMulImpl2("replicate_parallel")
)
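The "row_parallel" registrations above shard the reduction axis instead: each rank multiplies its slice of the input with its slice of the weight rows and the partial results are summed, which is the role of the inserted c_allreduce_sum op across the model-parallel group. A NumPy sketch of that idea, assuming two ranks and standing in for the real collective:

    import numpy as np

    np.random.seed(0)
    X = np.random.rand(4, 8).astype(np.float32)
    W = np.random.rand(8, 6).astype(np.float32)

    X_shards = np.split(X, 2, axis=1)   # each rank holds a 4x4 slice of the input columns
    W_shards = np.split(W, 2, axis=0)   # ... and the matching 4x6 block of weight rows
    partials = [x @ w for x, w in zip(X_shards, W_shards)]

    # summing the partial products plays the role of c_allreduce_sum
    assert np.allclose(sum(partials), X @ W)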
python/paddle/distributed/auto_parallel/parallelizer_v2.py View file @ 6f3c9643
...
...
@@ -254,17 +254,26 @@ class Parallelizer:
                self._dist_context.serial_feed_vars["inputs"]
                + self._dist_context.serial_feed_vars["labels"]
            )
-            if config["use_pure_fp16"]:
+            self._logger.info(
+                "Applying AMP-{}-{} ...".format(
+                    config["dtype"], config['level']
+                ),
+            )
+            if config['level'] == "o1":
+                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
+                auto_parallel_amp_pass.apply(
+                    [main_program], [startup_program], self._pass_context
+                )
+                loss = auto_parallel_amp_pass.get_loss()
+            elif config['level'] in ['o2', 'o3']:
                config["base_opt"] = optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_fp16_pass.get_loss()
            else:
-                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
-                auto_parallel_amp_pass.apply(
-                    [main_program], [startup_program], self._pass_context)
+                raise ValueError("AMP level should be one of o1, o2, o3")

        # apply recompute pass
        # recompute is then train-only optimization
...
...
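With this change the parallelizer dispatches on config["dtype"] and config["level"] instead of the old use_pure_fp16 flag: level "o1" keeps the cast-based auto_parallel_amp pass, while "o2"/"o3" route to the auto_parallel_fp16 pass, which also carries the bf16 support added in this PR. A minimal sketch of how these knobs are meant to be set from user code; auto.Strategy and auto.Engine follow the public auto-parallel API, but treat the exact attribute names here as assumptions and the model/loss/optimizer objects as placeholders.

    from paddle.distributed.fleet import auto

    strategy = auto.Strategy()
    strategy.amp.enable = True
    strategy.amp.dtype = "bfloat16"   # or "float16"
    strategy.amp.level = "o2"         # "o1" -> auto_parallel_amp, "o2"/"o3" -> auto_parallel_fp16

    # engine = auto.Engine(model, loss, optimizer, strategy=strategy)
    # engine.fit(train_dataset, batch_size=..., epochs=...)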
python/paddle/distributed/passes/auto_parallel_amp.py View file @ 6f3c9643
...
...
@@ -18,25 +18,48 @@ from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import (
    get_loss_op,
    set_var_dist_attr,
)
from paddle.distributed.auto_parallel.utils import (
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _keep_fp32_input,
    _keep_fp32_output,
    find_op_index,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _valid_types,
    find_true_post_op,
    find_true_prev_op,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _is_in_black_varnames,
    _dtype_to_str,
    _rename_arg,
)
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
)
from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op

world_process_group = get_world_process_group()


class AMPState(object):
    def __init__(self, block):
        self._block = block
        self._op_fp16_dict = (
            {}
        )  # op_id --> True/False. 'True' means that the current op is in fp16 mode.
        self._var_name_dict = {}  # fwd_op_id --> {old_name: cast_name}
        self.is_train = False
...
...
@@ -55,7 +78,8 @@ class AMPState(object):
            elif int(op.attr('op_role')) == int(OpRole.Backward):
                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[
                        op.desc.original_id()
                    ]
                    if self._is_fp16_op(fwd_op_id) == True:
                        self._op_fp16_dict[op.desc.original_id()] = True
                    elif self._is_fp16_op(fwd_op_id) == False:
...
...
@@ -78,7 +102,8 @@ class AMPState(object):
            if op.type == 'create_py_reader' or op.type == 'read':
                continue
            if amp_lists.black_varnames is not None and _is_in_black_varnames(
                op, amp_lists
            ):
                self._op_fp16_dict[op.desc.original_id()] = False
                continue
            if op.type in amp_lists.black_list:
...
...
@@ -98,17 +123,24 @@ class AMPState(object):
                        continue
                    elif in_var.op is op:
                        prev_op = find_true_prev_op(
                            ops, op, in_var_name
                        )
                        if prev_op is None:
                            continue
                    else:
                        prev_op = in_var.op
                    # if it's one of inputs
                    if (
                        self._is_fp16_op(prev_op.desc.original_id()) == False
                        or prev_op.type in amp_lists.black_list
                    ):
                        is_black_op = True
                    elif (
                        self._is_fp16_op(prev_op.desc.original_id()) == True
                        or prev_op.type in amp_lists.white_list
                    ):
                        is_white_op = True
                if is_black_op:
                    self._op_fp16_dict[op.desc.original_id()] = False
...
...
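The block above decides, per forward op, whether it runs in low precision: ops on the black list (or touching black varnames) stay fp32, ops on the white list run in fp16/bf16, and a gray op is pulled to "black" as soon as any of its producers is black. A condensed sketch of that coloring rule follows; the function and its arguments are hypothetical, it is not the pass itself, and it ignores the extra cases the real code handles.

    def color_op(op_type, producer_colors, white_list, black_list):
        # producer_colors: colors already assigned to the ops feeding this op
        if op_type in black_list:
            return "black"        # keep fp32
        if op_type in white_list:
            return "white"        # run in fp16/bf16
        if "black" in producer_colors:
            return "black"        # a gray op follows its black producers
        return "white"

    print(color_op("elementwise_add", ["white", "black"], {"matmul_v2"}, {"softmax"}))  # black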
@@ -131,19 +163,28 @@ class AMPState(object):
                    break
            if self._is_fp16_op(op.desc.original_id()) == False:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP16,
                    core.VarDesc.VarType.FP32,
                    dist_context,
                )
            elif self._is_fp16_op(op.desc.original_id()) == True:
                num_cast_ops = self._insert_cast_op_forward(
                    op,
                    idx,
                    core.VarDesc.VarType.FP32,
                    core.VarDesc.VarType.FP16,
                    dist_context,
                )
            else:
                pass
            idx += num_cast_ops + 1
        self._block._sync_with_cpp()

    def _insert_cast_op_forward(
        self, op, idx, src_dtype, dst_dtype, dist_context
    ):
        """
        only for forward cast
        modified from paddle.fluid.contrib.mixed_precision
...
...
@@ -152,38 +193,45 @@ class AMPState(object):
        var_name_dict = {}
        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name
            ):
                continue
            for in_var_name in op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype == src_dtype:
                    cast_name = (
                        in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                    )
                    out_var = self._block.vars.get(cast_name)
                    var_name_dict[in_var.name] = cast_name
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op
                    )
                    assert consume_op_attr is not None
                    if out_var is None or out_var.dtype != dst_dtype:
                        # NOTE we make the cast op and var's dist attr as the op that consume the
                        # cast var instead of the op which generates the var
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name
                        )
                        assert in_var_dist_attr is not None
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr
                        )

                        out_var = self._block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient,
                        )
                        set_var_dist_attr(
                            dist_context, out_var, ref_mapping, ref_mesh
                        )

                        cast_op = self._block._insert_op_without_sync(
                            idx,
...
...
@@ -193,22 +241,29 @@ class AMPState(object):
attrs
=
{
"in_dtype"
:
in_var
.
dtype
,
"out_dtype"
:
out_var
.
dtype
,
})
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
num_cast_ops
+=
1
else
:
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var
.
name
)
in_var
.
name
)
consume_op_attr
.
set_input_dist_attr
(
cast_name
,
in_var_dist_attr
)
cast_name
,
in_var_dist_attr
)
_rename_arg
(
op
,
in_var
.
name
,
cast_name
)
else
:
if
op
.
has_attr
(
'in_dtype'
):
op
.
_set_attr
(
'in_dtype'
,
dst_dtype
)
self
.
_var_name_dict
[
op
.
desc
.
original_id
()]
=
var_name_dict
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
dst_dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
(
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
dst_dtype
==
core
.
VarDesc
.
VarType
.
FP16
):
for
out_name
in
op
.
output_names
:
if
_keep_fp32_output
(
op
,
out_name
):
continue
...
...
@@ -238,8 +293,9 @@ class AMPState(object):
# NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
# which will affect the dist_op to insert allreduce_sum op.
op_dist_attr
=
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
if
is_backward_op
(
grad_op
)
and
(
is_forward_op
(
ops
[
idx
-
1
])
or
is_loss_op
(
ops
[
idx
-
1
])):
if
is_backward_op
(
grad_op
)
and
(
is_forward_op
(
ops
[
idx
-
1
])
or
is_loss_op
(
ops
[
idx
-
1
])
):
if
not
op_dist_attr
.
is_recompute
:
appended_grad_times
+=
1
...
...
@@ -248,14 +304,22 @@ class AMPState(object):
if
grad_op_orig_id
in
dist_op_context
.
grad_op_id_to_op_id
:
if
self
.
_is_fp16_op
(
grad_op_orig_id
)
==
False
:
# fp32
num_cast_ops
=
self
.
_insert_cast_op_backward
(
grad_op
,
idx
,
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
FP32
,
dist_context
,
appended_grad_times
)
grad_op
,
idx
,
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
FP32
,
dist_context
,
appended_grad_times
,
)
elif
self
.
_is_fp16_op
(
grad_op_orig_id
)
==
True
:
# fp16
num_cast_ops
=
self
.
_insert_cast_op_backward
(
grad_op
,
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
dist_context
,
appended_grad_times
)
grad_op
,
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
dist_context
,
appended_grad_times
,
)
elif
grad_op
.
type
==
"sum"
:
in_var_name
=
grad_op
.
desc
.
input_arg_names
()[
0
]
src_dtype
=
self
.
_block
.
var
(
in_var_name
).
dtype
...
...
@@ -270,15 +334,24 @@ class AMPState(object):
else
:
raise
ValueError
(
"'{}' op is not supported in the complete amp pass."
.
format
(
grad_op
.
type
))
grad_op
.
type
)
)
idx
+=
num_cast_ops
+
1
self
.
_block
.
_sync_with_cpp
()
_update_backward_cast_ops
(
params_grads
,
dist_context
)
def
_insert_cast_op_backward
(
self
,
grad_op
,
idx
,
src_dtype
,
dst_dtype
,
dist_context
,
appended_grad_times
):
""" only for backward cast """
def
_insert_cast_op_backward
(
self
,
grad_op
,
idx
,
src_dtype
,
dst_dtype
,
dist_context
,
appended_grad_times
,
):
"""only for backward cast"""
def
_keep_fp32_input
(
op
,
in_name
):
op_type
=
op
.
type
...
...
@@ -299,7 +372,8 @@ class AMPState(object):
for
in_name
in
grad_op
.
input_names
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
grad_op
,
in_name
):
grad_op
,
in_name
):
for
in_var_name
in
grad_op
.
input
(
in_name
):
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
assert
in_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
...
...
@@ -309,24 +383,34 @@ class AMPState(object):
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
if
in_var
.
dtype
==
src_dtype
:
consume_op_attr
=
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
grad_op
)
if
in_var_name
in
self
.
_var_name_dict
[
fwd_op_id
]:
# NOTE: if in_var of consume grad_op has been casted before,
# it should be renamed and reset dist_attr.
cast_name
=
self
.
_var_name_dict
[
fwd_op_id
][
in_var_name
]
grad_op
.
desc
.
_rename_input
(
in_var_name
,
cast_name
)
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var_name
)
in_var_name
)
consume_op_attr
.
set_input_dist_attr
(
cast_name
,
in_var_dist_attr
)
cast_name
,
in_var_dist_attr
)
else
:
assert
in_var
.
dtype
==
dst_dtype
,
"op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}"
.
format
(
grad_op
.
type
,
in_name
,
dst_dtype
,
in_var
.
dtype
,
str
(
grad_op
))
assert
(
in_var
.
dtype
==
dst_dtype
),
"op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}"
.
format
(
grad_op
.
type
,
in_name
,
dst_dtype
,
in_var
.
dtype
,
str
(
grad_op
),
)
for
out_name
in
grad_op
.
output_names
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_output
(
grad_op
,
out_name
):
grad_op
,
out_name
):
for
out_var_name
in
grad_op
.
output
(
out_name
):
out_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name
)
assert
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
...
...
@@ -334,7 +418,7 @@ class AMPState(object):
for
out_var_name
in
grad_op
.
output
(
out_name
):
out_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name
)
out_var_name_prefix
=
out_var_name
[:
out_var_name
.
find
(
"@"
)]
out_var_name_prefix
=
out_var_name
[:
out_var_name
.
find
(
"@"
)]
fwd_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name_prefix
)
# NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
if
out_var
.
dtype
!=
fwd_var
.
dtype
:
...
...
@@ -345,34 +429,45 @@ class AMPState(object):
# NOTE: if out_var of consume grad_op has been casted before,
# it should be renamed and reset dist_attr, then we insert cast op to
# convert the cast_var to original dtype
consume_op_attr
=
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
consume_op_attr
=
(
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
)
fwd_cast_name
=
self
.
_var_name_dict
[
fwd_op_id
][
out_var_name_prefix
]
out_var_name_prefix
]
suffix
=
""
if
"@RENAME"
in
out_var_name
:
suffix
=
out_var_name
[
out_var_name
.
find
(
"@RENAME"
):]
suffix
=
out_var_name
[
out_var_name
.
find
(
"@RENAME"
)
:
]
cast_name
=
fwd_cast_name
+
"@GRAD"
+
suffix
cast_var
=
self
.
_block
.
vars
.
get
(
cast_name
)
if
cast_var
is
None
or
cast_var
.
dtype
!=
dst_dtype
:
grad_op
.
desc
.
_rename_output
(
out_var_name
,
cast_name
)
out_var_dist_attr
=
consume_op_attr
.
get_output_dist_attr
(
out_var_name
)
out_var_dist_attr
=
(
consume_op_attr
.
get_output_dist_attr
(
out_var_name
)
)
ref_mesh
=
out_var_dist_attr
.
process_mesh
ref_mapping
=
out_var_dist_attr
.
dims_mapping
consume_op_attr
.
set_output_dist_attr
(
cast_name
,
out_var_dist_attr
)
cast_name
,
out_var_dist_attr
)
assert
ref_mapping
is
not
None
cast_var
=
self
.
_block
.
create_var
(
name
=
cast_name
,
shape
=
out_var
.
shape
,
dtype
=
dst_dtype
,
persistable
=
False
,
stop_gradient
=
out_var
.
stop_gradient
)
set_var_dist_attr
(
dist_context
,
cast_var
,
ref_mapping
,
ref_mesh
)
stop_gradient
=
out_var
.
stop_gradient
,
)
set_var_dist_attr
(
dist_context
,
cast_var
,
ref_mapping
,
ref_mesh
)
dist_op_context
.
grad_var_to_var
[
appended_grad_times
][
cast_name
]
=
fwd_cast_name
appended_grad_times
][
cast_name
]
=
fwd_cast_name
cast_op
=
self
.
_block
.
_insert_op
(
idx
+
1
,
...
...
@@ -382,13 +477,15 @@ class AMPState(object):
attrs
=
{
"in_dtype"
:
cast_var
.
dtype
,
"out_dtype"
:
out_var
.
dtype
,
"op_role"
:
OpRole
.
Backward
})
"op_role"
:
OpRole
.
Backward
,
},
)
cast_op
.
_remove_attr
(
"op_role_var"
)
cast_op
.
_remove_attr
(
"op_namescope"
)
cast_op
.
_remove_attr
(
"with_quant_attr"
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
num_cast_ops
+=
1
else
:
assert
out_var
.
dtype
==
dst_dtype
...
...
@@ -409,15 +506,18 @@ def _update_backward_cast_ops(params_grads, dist_context):
for
p
,
g
in
params_grads
:
op
=
g
.
op
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
type
==
'cast'
:
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Backward
)
and
op
.
has_attr
(
'op_role_var'
):
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Backward
)
and
op
.
has_attr
(
'op_role_var'
):
op
.
_remove_attr
(
"op_role_var"
)
post_ops
=
find_true_post_op
(
main_block
.
ops
,
op
,
g
.
name
)
if
post_ops
:
raise
ValueError
(
"The cast op {0}'s output should not be"
raise
ValueError
(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}"
.
format
(
op
,
post_ops
[
0
]))
"is used by {1}"
.
format
(
op
,
post_ops
[
0
])
)
if
op
==
main_block
.
ops
[
-
1
]:
continue
...
...
@@ -425,23 +525,29 @@ def _update_backward_cast_ops(params_grads, dist_context):
# add new op in the python and cpp at the same time
new_op_desc
=
main_block
.
desc
.
append_op
()
new_op_desc
.
copy_from
(
op
.
desc
)
new_op
=
paddle
.
fluid
.
framework
.
Operator
(
block
=
main_block
,
new_op
=
paddle
.
fluid
.
framework
.
Operator
(
block
=
main_block
,
desc
=
new_op_desc
,
type
=
None
,
inputs
=
None
,
outputs
=
None
,
attrs
=
None
)
attrs
=
None
,
)
main_block
.
ops
.
append
(
new_op
)
# dist attr
param_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
p
)
output_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
main_block
.
var
(
op
.
output_arg_names
[
0
]))
main_block
.
var
(
op
.
output_arg_names
[
0
])
)
assert
param_dist_attr
is
not
None
assert
output_dist_attr
is
not
None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
new_op
,
param_dist_attr
.
process_mesh
,
param_dist_attr
.
dims_mapping
,
dist_context
)
new_op
,
param_dist_attr
.
process_mesh
,
param_dist_attr
.
dims_mapping
,
dist_context
,
)
output_dist_attr
.
process_mesh
=
param_dist_attr
.
process_mesh
output_dist_attr
.
dims_mapping
=
param_dist_attr
.
dims_mapping
...
...
@@ -462,26 +568,34 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
grads
=
[
g
for
_
,
g
in
params_grads
]
check_type
(
grads
,
'x'
,
(
tuple
,
list
),
'check_finite_and_unscale'
)
for
e
in
grads
:
check_variable_and_dtype
(
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
'check_finite_and_unscale'
)
check_variable_and_dtype
(
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
'check_finite_and_unscale'
,
)
found_inf
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
[
'find_infinite_scale'
,
'tmp'
])),
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
([
'find_infinite_scale'
,
'tmp'
])
),
shape
=
[
1
],
dtype
=
'bool'
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
False
)
stop_gradient
=
False
,
)
set_var_dist_attr
(
dist_context
,
found_inf
,
[
-
1
],
world_process_group
.
ranks
)
inputs
=
{
'X'
:
grads
,
'Scale'
:
loss_scaling
}
outputs
=
{
'Out'
:
grads
,
'FoundInfinite'
:
found_inf
}
attrs
=
{
'op_role'
:
OpRole
.
Optimize
}
new_op
=
main_block
.
append_op
(
type
=
'check_finite_and_unscale'
,
new_op
=
main_block
.
append_op
(
type
=
'check_finite_and_unscale'
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
attrs
=
attrs
,
)
new_op_dist_attr
=
OperatorDistributedAttribute
()
new_op_dist_attr
.
process_mesh
=
world_process_group
.
ranks
...
...
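For reference, the check_finite_and_unscale op appended above is expected to behave like the following NumPy sketch (a simplified stand-in, not the Paddle kernel): the gradients wired to 'Out' are the inputs divided by the loss-scaling factor, and 'FoundInfinite' reports whether any of them overflowed.

import numpy as np

def check_finite_and_unscale_ref(grads, loss_scaling):
    # grads: list of numpy arrays, loss_scaling: positive float
    found_inf = any(not np.all(np.isfinite(g)) for g in grads)
    unscaled = [g / loss_scaling for g in grads]
    return unscaled, found_inf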
@@ -491,17 +605,18 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
-       new_op_dist_attr.set_input_dims_mapping(g.name, g_dist_attr.dims_mapping)
-       new_op_dist_attr.set_output_dims_mapping(g.name, g_dist_attr.dims_mapping)
+       new_op_dist_attr.set_input_dims_mapping(
+           g.name, g_dist_attr.dims_mapping
+       )
+       new_op_dist_attr.set_output_dims_mapping(
+           g.name, g_dist_attr.dims_mapping
+       )
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
    return grads, found_inf


@register_pass("auto_parallel_amp")
class AMPPass(PassBase):
    def __init__(self):
        super(AMPPass, self).__init__()
        self.set_attr("loss", None)
...
@@ -517,6 +632,7 @@ class AMPPass(PassBase):
        self.set_attr("use_dynamic_loss_scaling", False)
        self.set_attr("input_data", [])
        self.set_attr("params_grads", [])
+       self.set_attr("dtype", "")  # fp16/bf16
        self._loss = None
        self._loss_scaling = None
        self._num_good_steps = None
...
@@ -524,6 +640,8 @@ class AMPPass(PassBase):
        self._loss = None

    def _check_self(self):
+       if self.get_attr("dtype") not in ["float16", "bfloat16"]:
+           return False
        if self.get_attr("init_loss_scaling") < 0:
            return False
        if self.get_attr("incr_every_n_steps") < 0:
...
@@ -548,11 +666,13 @@ class AMPPass(PassBase):
    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")
+       self.amp_dtype = self.get_attr("dtype")

        amp_lists = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
-           set(self.get_attr("custom_black_varnames")))
+           set(self.get_attr("custom_black_varnames")),
+       )

        with paddle.static.program_guard(main_program, startup_program):
            amp_state = AMPState(main_program.global_block())
...
@@ -566,10 +686,13 @@ class AMPPass(PassBase):
            self._init_amp_var()
            self._scale_loss()

-           if self.get_attr("use_dynamic_loss_scaling") or self.get_attr("init_loss_scaling") != 1.0:
+           if (
+               self.get_attr("use_dynamic_loss_scaling")
+               or self.get_attr("init_loss_scaling") != 1.0
+           ):
                grads, found_inf = _check_and_update_gradient(
-                   params_grads, self._loss_scaling, self.dist_context)
+                   params_grads, self._loss_scaling, self.dist_context
+               )

                if self.get_attr("use_dynamic_loss_scaling"):
                    self._update_loss_scaling(grads, found_inf)
...
@@ -580,9 +703,14 @@ class AMPPass(PassBase):
            shape=[1],
            value=self.get_attr("init_loss_scaling"),
            dtype='float32',
-           persistable=True)
-       set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
-                         world_process_group.ranks)
+           persistable=True,
+       )
+       set_var_dist_attr(
+           self.dist_context,
+           self._loss_scaling,
+           [-1],
+           world_process_group.ranks,
+       )

        if self.get_attr("use_dynamic_loss_scaling"):
            self._num_good_steps = paddle.static.create_global_var(
...
@@ -590,18 +718,28 @@ class AMPPass(PassBase):
                shape=[1],
                value=0,
                dtype='int32',
-               persistable=True)
-           set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
-                             world_process_group.ranks)
+               persistable=True,
+           )
+           set_var_dist_attr(
+               self.dist_context,
+               self._num_good_steps,
+               [-1],
+               world_process_group.ranks,
+           )

            self._num_bad_steps = paddle.static.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
-               persistable=True)
-           set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
-                             world_process_group.ranks)
+               persistable=True,
+           )
+           set_var_dist_attr(
+               self.dist_context,
+               self._num_bad_steps,
+               [-1],
+               world_process_group.ranks,
+           )

    def _scale_loss(self):
...
@@ -613,7 +751,8 @@ class AMPPass(PassBase):
        assert loss is not None
        loss_op = loss.op
-       loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(loss_op)
+       loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
+           loss_op
+       )

        if loss.dtype != core.VarDesc.VarType.FP32:
            # cast loss here will change the effective loss tensor for the computation graph
...
@@ -626,10 +765,12 @@ class AMPPass(PassBase):
            tmp_name = unique_name.generate(loss.name + ".cast_fp32")
            cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
-           loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(loss)
+           loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
+               loss
+           )
            ref_mesh = loss_op_dist_attr.process_mesh
-           self.dist_context.set_tensor_dist_attr_for_program(cast_loss, loss_dist_attr)
+           self.dist_context.set_tensor_dist_attr_for_program(
+               cast_loss, loss_dist_attr
+           )

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
            cast_op = main_block._insert_op(
...
@@ -641,16 +782,21 @@ class AMPPass(PassBase):
                    "in_dtype": loss.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
-               })
-           loss_op._set_attr(OP_ROLE_KEY,
-                             core.op_proto_and_checker_maker.OpRole.Forward)
+               },
+           )
+           loss_op._set_attr(
+               OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
+           )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-               cast_op, ref_mesh, [-1], self.dist_context)
+               cast_op, ref_mesh, [-1], self.dist_context
+           )
            loss = loss.astype('float32')

-       if self.get_attr("use_dynamic_loss_scaling") or self.get_attr("init_loss_scaling") != 1.0:
+       if self.amp_dtype == "float16" and (
+           self.get_attr("use_dynamic_loss_scaling")
+           or self.get_attr("init_loss_scaling") != 1.0
+       ):

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
...
@@ -660,63 +806,76 @@ class AMPPass(PassBase):
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
-               persistable=loss.persistable)
-           set_var_dist_attr(self.dist_context, self._scaled_loss, [-1], ref_mesh)
+               persistable=loss.persistable,
+           )
+           set_var_dist_attr(
+               self.dist_context, self._scaled_loss, [-1], ref_mesh
+           )

            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={'X': [loss], 'Y': [self._loss_scaling]},
                outputs={'Out': [self._scaled_loss]},
                attrs={
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
-               })
-           loss_op._set_attr(OP_ROLE_KEY,
-                             core.op_proto_and_checker_maker.OpRole.Forward)
+               },
+           )
+           loss_op._set_attr(
+               OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Forward
+           )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-               elementwise_mul_op, ref_mesh, [-1], self.dist_context)
+               elementwise_mul_op, ref_mesh, [-1], self.dist_context
+           )

            # backward
            first_backward_op = main_block.ops[loss_op_idx + 2]
-           assert first_backward_op.type == "fill_constant" and int(
-               first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
+           assert (
+               first_backward_op.type == "fill_constant"
+               and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
+           )
            self._scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
-               persistable=loss.persistable)
-           set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1], ref_mesh)
+               persistable=loss.persistable,
+           )
+           set_var_dist_attr(
+               self.dist_context, self._scaled_loss_grad, [-1], ref_mesh
+           )
            pre_grad_name = first_backward_op.output_arg_names[0]
-           first_backward_op._rename_output(pre_grad_name, self._scaled_loss_grad.name)
+           first_backward_op._rename_output(
+               pre_grad_name, self._scaled_loss_grad.name
+           )
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
-               loss_op_idx + 3)
+               loss_op_idx + 3
+           )
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
-           elementwise_mul_grad_op_desc.set_input('Out@GRAD', [self._scaled_loss_grad.name])
+           elementwise_mul_grad_op_desc.set_input(
+               'Out@GRAD', [self._scaled_loss_grad.name]
+           )
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
-           elementwise_mul_grad_op_desc.set_input('Y', [self._loss_scaling.name])
+           elementwise_mul_grad_op_desc.set_input(
+               'Y', [self._loss_scaling.name]
+           )
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(
-               OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
+               OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward
+           )
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            elementwise_mul_grad_op = paddle.fluid.framework.Operator(
-               main_block, elementwise_mul_grad_op_desc)
+               main_block, elementwise_mul_grad_op_desc
+           )
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-               elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)
+               elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
+           )
        else:
            self._scaled_loss = loss
...
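Note that the scaled-loss branch above is now guarded by self.amp_dtype == "float16": bfloat16 keeps the float32 exponent range, so loss scaling is only wired in for float16 runs. A minimal NumPy sketch of what the inserted elementwise_mul / elementwise_mul_grad pair computes (illustrative only, not the Program rewrite itself):

import numpy as np

def scaled_loss_step_ref(loss, loss_scaling):
    # forward: the pass multiplies the effective loss by the scaling factor
    scaled_loss = loss * loss_scaling
    # backward: d(scaled_loss)/d(loss) == loss_scaling, which is exactly what
    # the inserted elementwise_mul_grad produces for the original loss grad
    d_loss = np.ones_like(loss) * loss_scaling
    return scaled_loss, d_loss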
@@ -728,31 +887,39 @@ class AMPPass(PassBase):
        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()

-       check_variable_and_dtype(self._loss_scaling, "prev_loss_scaling",
-                                ['float32', 'float64'], "update_loss_scaling")
+       check_variable_and_dtype(
+           self._loss_scaling,
+           "prev_loss_scaling",
+           ['float32', 'float64'],
+           "update_loss_scaling",
+       )
        check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
        for e in grads:
-           check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
-                                    'update_loss_scaling')
+           check_variable_and_dtype(
+               e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling'
+           )
            if e.dtype == core.VarDesc.VarType.FP16:
-               assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \
-                   "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
+               assert (
+                   self._loss_scaling.dtype == core.VarDesc.VarType.FP32
+               ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
            else:
-               assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
+               assert (
+                   self._loss_scaling.dtype == e.dtype
+               ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

        inputs = {
            'X': grads,
            'FoundInfinite': found_inf,
            'PrevLossScaling': self._loss_scaling,
            'InGoodSteps': self._num_good_steps,
-           'InBadSteps': self._num_bad_steps
+           'InBadSteps': self._num_bad_steps,
        }
        outputs = {
            'Out': grads,
            'LossScaling': self._loss_scaling,
            'OutGoodSteps': self._num_good_steps,
-           'OutBadSteps': self._num_bad_steps
+           'OutBadSteps': self._num_bad_steps,
        }

        attrs = {
...
@@ -761,13 +928,15 @@ class AMPPass(PassBase):
            'incr_ratio': self.get_attr("incr_ratio"),
            'decr_ratio': self.get_attr("decr_ratio"),
            'stop_update': self.get_attr("stop_update"),
-           'op_role': OpRole.Optimize
+           'op_role': OpRole.Optimize,
        }

-       new_op = main_block.append_op(type='update_loss_scaling',
-                                     inputs=inputs,
-                                     outputs=outputs,
-                                     attrs=attrs)
+       new_op = main_block.append_op(
+           type='update_loss_scaling',
+           inputs=inputs,
+           outputs=outputs,
+           attrs=attrs,
+       )
        new_op_dist_attr = OperatorDistributedAttribute()
        new_op_dist_attr.process_mesh = world_process_group.ranks
...
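The update_loss_scaling attributes above (incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio) configure the usual dynamic loss-scaling rule. A simplified Python reference of that rule, not the exact Paddle kernel, might look like:

def update_loss_scaling_ref(
    scaling, good_steps, bad_steps, found_inf,
    incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
):
    # Grow the scale after enough clean steps, shrink it after enough
    # overflowed steps; counters reset whenever the scale changes.
    if found_inf:
        good_steps, bad_steps = 0, bad_steps + 1
        if bad_steps >= decr_every_n_nan_or_inf:
            scaling, bad_steps = scaling * decr_ratio, 0
    else:
        good_steps, bad_steps = good_steps + 1, 0
        if good_steps >= incr_every_n_steps:
            scaling, good_steps = scaling * incr_ratio, 0
    return scaling, good_steps, bad_steps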
@@ -777,10 +946,22 @@ class AMPPass(PassBase):
        for g in grads:
            g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
            assert g_dist_attr is not None
-           new_op_dist_attr.set_input_dims_mapping(g.name, g_dist_attr.dims_mapping)
-           new_op_dist_attr.set_output_dims_mapping(g.name, g_dist_attr.dims_mapping)
+           new_op_dist_attr.set_input_dims_mapping(
+               g.name, g_dist_attr.dims_mapping
+           )
+           new_op_dist_attr.set_output_dims_mapping(
+               g.name, g_dist_attr.dims_mapping
+           )
        self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)

        main_block._sync_with_cpp()
+
+   def get_loss(self):
+       # the amp might change the effective loss variable for network and
+       # therefore would affect the subsequent passes that rely on the loss.
+       # return the effective loss after amp pass.
+       if self._loss:
+           return self._loss
+       else:
+           return self.get_attr("loss")
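Because the pass may swap the loss for a cast or scaled variable, a later pass should re-read it through the new get_loss() accessor rather than caching the original attribute. A hypothetical call site (names are illustrative, not an existing API):

# amp_pass: the applied AMPPass instance, assumed to be in scope
loss_after_amp = amp_pass.get_loss()  # cast/scaled loss if AMP rewrote it, else the original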
python/paddle/distributed/passes/auto_parallel_fp16.py View file @ 6f3c9643
...
@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
-from paddle.fluid.contrib.mixed_precision.fp16_utils import (
+from paddle.fluid.contrib.mixed_precision.fp16_lists import (
    AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _keep_layer_norm_scale_bias_to_fp32,
    _need_keep_fp32,
    _valid_types,
-   _dtype_to_str,
)
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
...
@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
    'while',
    'cast',
]
+__target_dtype__ = None
+
+
+def _dtype_to_str(dtype):
+    """
+    Convert specific variable type to its corresponding string.
+    Args:
+        dtype (VarType): Variable type.
+    """
+    if dtype == core.VarDesc.VarType.FP16:
+        # TODO(Xreki): change the returned str to "bf16" for BF16 data type.
+        # Currently too many codes use "cast_fp16" as key.
+        return 'fp16'
+    elif dtype == core.VarDesc.VarType.BF16:
+        return 'bf16'
+    else:
+        return 'fp32'


def set_op_dtype_to_fp16(op):
...
@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
        op.has_attr('in_dtype') and op.attr('in_dtype') == core.VarDesc.VarType.FP32
    ):
-       op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
+       op._set_attr('in_dtype', __target_dtype__)
    if (
        op.has_attr('out_dtype')
        and op.attr('out_dtype') == core.VarDesc.VarType.FP32
    ):
-       op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
+       op._set_attr('out_dtype', __target_dtype__)
    if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
-       op._set_attr('dtype', core.VarDesc.VarType.FP16)
+       op._set_attr('dtype', __target_dtype__)
+
+   if __target_dtype__ == core.VarDesc.VarType.BF16:
+       if op.has_attr('use_mkldnn'):
+           op._set_attr('use_mkldnn', True)
+       if op.has_attr('mkldnn_data_type'):
+           op._set_attr('mkldnn_data_type', 'bfloat16')

    # adapot for backward op
...
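A small standalone sketch of how the module-level dtype switch and the local _dtype_to_str helper cooperate to name cast variables; the expected result matches the "mean_0.tmp_0.cast_bf16" example quoted in a comment later in this file (toy code using the same core import as the new test, not the pass itself):

from paddle.fluid.framework import core

# Pretend the pass selected bfloat16 (strategy.amp.dtype = "bfloat16")
target_dtype = core.VarDesc.VarType.BF16

def dtype_suffix(dtype):
    # same mapping as the _dtype_to_str helper above
    if dtype == core.VarDesc.VarType.FP16:
        return 'fp16'
    elif dtype == core.VarDesc.VarType.BF16:
        return 'bf16'
    return 'fp32'

cast_name = "mean_0.tmp_0" + '.cast_' + dtype_suffix(target_dtype)
assert cast_name == "mean_0.tmp_0.cast_bf16"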
@@ -156,6 +178,7 @@ class FP16State(object):
            list
        )  # {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
        self.is_train = False
+       self.out_var_op_deps = {}

    def _is_fp16_op(self, op_id):
        return self._op_fp16_dict.get(op_id, None)
...
@@ -169,6 +192,14 @@ class FP16State(object):
        # assume all backward block are behind forward blocks
        for block in self.program.blocks:
            for op in block.ops:
+               for name in op.output_arg_names:
+                   if name not in self.out_var_op_deps:
+                       self.out_var_op_deps[name] = [op.desc.original_id()]
+                   else:
+                       self.out_var_op_deps[name].extend(
+                           [op.desc.original_id()]
+                       )
+
                self._mark_op(op)

        # set forward tensor dtype
...
@@ -192,6 +223,18 @@ class FP16State(object):
            if op.type == "assign" and "array_" in op.input_arg_names[0]:
                self._op_fp16_dict[op.desc.original_id()] = False
                return
+           # If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
+           if op.type == "assign":
+               out_name = op.output_arg_names[0]
+               if len(self.out_var_op_deps[out_name]) > 1:
+                   if not self._op_fp16_dict[
+                       self.out_var_op_deps[out_name][0]
+                   ]:
+                       self._op_fp16_dict[op.desc.original_id()] = False
+                   else:
+                       self._op_fp16_dict[op.desc.original_id()] = True
+                   return
            if _need_keep_fp32(op, self.amp_list.unsupported_list, self.use_fp16_guard):
...
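A toy illustration of the out_var_op_deps bookkeeping that the in-place assign rule above relies on, using plain dicts instead of Program objects (assumed simplification, identifiers are illustrative):

out_var_op_deps = {}   # output var name -> ids of the ops that produce it
op_fp16_dict = {}      # op id -> True (low precision) / False (fp32)

def record_outputs(op_id, output_names):
    for name in output_names:
        out_var_op_deps.setdefault(name, []).append(op_id)

def mark_inplace_assign(op_id, out_name):
    # an inplace assign must follow the decision of the op that first created
    # its output var, otherwise the var would change dtype mid-graph
    producers = out_var_op_deps[out_name]
    if len(producers) > 1:
        op_fp16_dict[op_id] = op_fp16_dict[producers[0]]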
@@ -228,7 +271,7 @@ class FP16State(object):
            return

        if var.dtype == core.VarDesc.VarType.FP32:
-           var.desc.set_dtype(core.VarDesc.VarType.FP16)
+           var.desc.set_dtype(__target_dtype__)

    def resolute_tensor_dtype(self, block):
...
@@ -260,7 +303,7 @@ class FP16State(object):
                    out_var = block.vars.get(out_var_name)
                    if out_var is None or out_var.type not in _valid_types:
                        continue
-                   if out_var.dtype == core.VarDesc.VarType.FP16:
+                   if out_var.dtype == __target_dtype__:
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
            elif is_backward_op(op):
                if self._is_fp16_op(op.desc.original_id()) == True:
...
@@ -276,7 +319,7 @@ class FP16State(object):
                    out_var = block.vars.get(out_var_name)
                    if out_var is None or out_var.type not in _valid_types:
                        continue
-                   if out_var.dtype == core.VarDesc.VarType.FP16:
+                   if out_var.dtype == __target_dtype__:
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP32)

    def cast_block(self, block):
...
@@ -295,7 +338,7 @@ class FP16State(object):
                        op,
                        idx,
                        block,
-                       core.VarDesc.VarType.FP16,
+                       __target_dtype__,
                        core.VarDesc.VarType.FP32,
                        self.dist_context,
                    )
...
@@ -305,7 +348,7 @@ class FP16State(object):
                        idx,
                        block,
                        core.VarDesc.VarType.FP32,
-                       core.VarDesc.VarType.FP16,
+                       __target_dtype__,
                        self.dist_context,
                    )
            elif is_backward_op(op):
...
@@ -315,7 +358,7 @@ class FP16State(object):
                        op,
                        idx,
                        block,
-                       core.VarDesc.VarType.FP16,
+                       __target_dtype__,
                        core.VarDesc.VarType.FP32,
                        self.dist_context,
                    )
...
@@ -325,7 +368,7 @@ class FP16State(object):
                        idx,
                        block,
                        core.VarDesc.VarType.FP32,
-                       core.VarDesc.VarType.FP16,
+                       __target_dtype__,
                        self.dist_context,
                    )
            elif op.type == "sum":
...
@@ -399,6 +442,9 @@ class FP16State(object):
                    dist_context, cast_var, ref_mapping, ref_mesh
                )

+               op_namescope = "/"
+               if op.has_attr('op_namescope'):
+                   op_namescope = op.attr('op_namescope')
                cast_op = block._insert_op_without_sync(
                    idx,
                    type="cast",
...
@@ -410,6 +456,9 @@ class FP16State(object):
                        OP_ROLE_KEY: OpRole.Forward,
                    },
                )
+               cast_op._set_attr(
+                   'op_namescope', op_namescope
+               )  # for recompute
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    cast_op, ref_mesh, ref_mapping, dist_context
                )
...
@@ -455,22 +504,36 @@ class FP16State(object):
        ) in self.forward_input_cast_ops[forward_op_id]:

            # rename input
+           # some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
+           if op.type != "scale" and slot_name in op.input_names:
                assert src_name in op.input(
                    slot_name
-               ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
+               ), "var: {} not in op's {}. {}".format(
+                   src_name, slot_name, str(op)
+               )
                src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
                assert src_var_dist_attr is not None
                op._rename_input(src_name, cast_name)
                grad_op_attr.set_input_dist_attr(cast_name, src_var_dist_attr)

+           # NOTE Special for scale op, scale op's grad op is scale,
+           # so slot name map rule could not apply to grad scale op
+           # cast_name: mean_0.tmp_0.cast_bf16, src_name: mean_0.tmp_0, dst_dtype: paddle.bfloat16, src_dtype: paddle.float32, slot_name: X.
+           if op.type == "scale":
+               grad_slot_name = "X"
            # create cast grad
-           grad_slot_name = slot_name + "@GRAD"
-           assert grad_slot_name in op.output_names
+           else:
+               grad_slot_name = slot_name + "@GRAD"
+
+           if grad_slot_name in op.output_names:
+               # some forward input maybe stop_gradient=True, e.g. input_mask
+               if len(op.output(grad_slot_name)) == 0:
+                   var = block.var(src_name)
+                   assert var.stop_gradient is True
+                   continue
-               assert len(op.output(grad_slot_name)) == 1
+               assert (
+                   len(op.output(grad_slot_name)) == 1
+               ), "[{}], Current Op: {}".format(grad_slot_name, str(op))
                grad_name = op.output(grad_slot_name)[0]
                grad = block.var(grad_name)
                grad_dist_attr = grad_op_attr.get_output_dist_attr(grad_name)
...
@@ -492,7 +555,9 @@ class FP16State(object):
                    cast_grad, grad_dist_attr
                )
                op._rename_output(grad_name, cast_grad.name)
-               grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)
+               grad_op_attr.set_output_dist_attr(
+                   cast_grad.name, grad_dist_attr
+               )

                # add cast
                cast_op = block._insert_op_without_sync(
...
@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
def _split_grads(params_grads):
    grads = [g for _, g in params_grads]
    fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
-   fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
+   fp16_grads = [g for g in grads if g.dtype == __target_dtype__]
    assert len(fp32_grads) + len(fp16_grads) == len(
        grads
    ), "Data types of all grads must be either fp16 or fp32."
...
@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
    # TODO to support CUDAPinned/NPU/XPU Places
    if direction == "D2H":
        dst_place_type = 0
-   elif direction == "D2H":
-       dst_place_type = 1
    else:
        raise NotImplementedError(
-           "direction [{}] is not supported yet.".format(direction)
+           f"direction [{direction}] is not supported yet."
        )

    attrs = {'dst_place_type': dst_place_type}
    new_op = block._insert_op_without_sync(
        index=idx,
-       type='memcpy',
+       type='memcpy_d2h',
        inputs={'X': [src_var]},
        outputs={'Out': [output_var]},
        attrs=attrs,
...
@@ -678,17 +741,17 @@ def cast_startup_program():
    for op in startup_program.global_block().ops:
        if is_initialization_op(op):
            output_name = op.output_arg_names[0]
-           if (
-               param_to_dtype.get(output_name, None) == core.VarDesc.VarType.FP16
-           ):
+           if param_to_dtype.get(output_name, None) == __target_dtype__:
                assert op.has_attr(
                    'dtype'
                ), "initialization op is supported to has dtype attribute but got {}.".format(
                    str(op)
                )
+               out_var = startup_program.global_block().var(output_name)
+               if out_var.dtype == core.VarDesc.VarType.FP32:
+                   out_var.desc.set_dtype(__target_dtype__)
                if op.attr('dtype') == core.VarDesc.VarType.FP32:
-                   op._set_attr('dtype', core.VarDesc.VarType.FP16)
+                   op._set_attr('dtype', __target_dtype__)


@register_pass("auto_parallel_fp16")
...
@@ -701,14 +764,44 @@ class FP16Pass(AMPPass):
    # in distributed scenario, all ranks should have the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
+       self.target_dtype = self.get_attr("dtype")
        params_grads = self.get_attr("params_grads")

+       self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None)
+       if self.use_optimizer_fp16 is None:
+           self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"
+
+       # swith enviroment for fp16 / bf16.
+       if self.target_dtype == "float16":
+           __target_dtype = core.VarDesc.VarType.FP16
+       elif self.target_dtype == "bfloat16":
+           __target_dtype = core.VarDesc.VarType.BF16
+       else:
+           raise NotImplementedError(
+               "target dtype [{}] is for amp o2 not supported yet.".format(
+                   self.target_dtype
+               )
+           )
+       global __target_dtype__
+       __target_dtype__ = __target_dtype
+
        amp_list = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
            None,
+           dtype=self.target_dtype,
        )
+       amp_list.unsupported_list -= {
+           "conditional_block_grad",
+           "conditional_block",
+           "conditional_block_infer",
+           "select_input",
+           "while",
+           "while_grad",
+           "cast",
+           "tensor_array_to_tensor",
+           "lod_array_length",
+           "write_to_array",
+       }
        # NOTE don't not change input data dtype, since it is controled by dataloader
        # and which is out of control of FP16 Pass
        input_data_var_names = [var.name for var in self.get_attr("input_data")]
...
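The dtype switch above is driven entirely by the distributed strategy. A minimal configuration consistent with the new amp_o2_pass.py test further down in this commit would be (a sketch using only knobs that appear in this change set):

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"   # or "float16"; anything else trips the NotImplementedError above
amp.level = "o2"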
@@ -726,6 +819,7 @@ class FP16Pass(AMPPass):
        cast_startup_program()

        if is_train:
+           if self.target_dtype == "fp16":
                with paddle.static.program_guard(main_program, startup_program):
                    # TODO (JZ-LIANG)support cast forward program only when inference
                    self._init_amp_var()
...
@@ -801,10 +895,12 @@ class FP16Pass(AMPPass):
            # modify optimizer
            base_opt = self.get_attr("base_opt")
            base_opt._multi_precision = True
-           if self.get_attr("use_optimizer_fp16"):
+           if self.use_optimizer_fp16:
                base_opt._multi_precision = False
+
+           if self.target_dtype == "fp16":
                if isinstance(
-                   base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)
+                   base_opt,
+                   (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
                ):
                    with main_program._optimized_guard([]):
                        # found_inf = paddle.tensor.creation._memcpy(
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt View file @ 6f3c9643
...
@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
  set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                   TIMEOUT 50)
+ py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass ENVS ${dist_ENVS})
+ set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
+                                                  TIMEOUT 50)
  py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
                  ${dist_ENVS})
  set_tests_properties(test_iterable_dataset
...
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py 0 → 100644 View file @ 6f3c9643

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import re
import unittest

import numpy as np
from get_gpt_model import FakeDataset, generate_model

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.framework import core

paddle.enable_static()


def get_cuda_version():
    result = os.popen("nvcc --version").read()
    regex = r'release (\S+),'
    match = re.search(regex, result)
    if match:
        num = str(match.group(1))
        integer, decimal = num.split('.')
        return int(integer) * 1000 + int(float(decimal) * 10)
    else:
        return -1


def apply_pass(use_amp=False, amp_dtype="bfloat16"):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_amp:
        amp = strategy.amp
        amp.enable = True
        amp.dtype = amp_dtype
        amp.level = "o2"
        amp.custom_black_list = [
            'c_softmax_with_cross_entropy',
            'elementwise_div',
            'reduce_sum',
        ]
    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class TestShardingStage2WithNewEXE(unittest.TestCase):
    def setUp(self):
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2022)
        np.random.seed(2022)
        random.seed(2022)
        place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
        reset_prog()

        strategy = apply_pass(use_amp, amp_dtype)
        # clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        clip = None
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("mp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_bf16(self, program):
        num_bf16 = 0
        num_fp16 = 0
        num_fp32 = 0

        for p in program.all_parameters():
            if p.dtype == core.VarDesc.VarType.FP32:
                num_fp32 += 1
            if p.dtype == core.VarDesc.VarType.FP16:
                num_fp16 += 1
            if p.dtype == core.VarDesc.VarType.BF16:
                num_bf16 += 1

        self.assertEqual(num_bf16, 25)
        self.assertEqual(num_fp16, 0)
        self.assertEqual(num_fp32, 11)

    def test_param_grad_fuse_overlap(self):
        # std
        mp_engine = self.get_engine(use_amp=False)
        mp_history = mp_engine.fit(
            self.dataset,
            3,
            epochs=1,
            steps_per_epoch=self.batch_num,
            log_freq=1,
            batch_size=self.batch_size,
        )
        loss0 = mp_history.history['loss'][0]

        # bf16
        mp_bf16_engine = self.get_engine(use_amp=True)
        if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
            return

        mp_bf16_history = mp_bf16_engine.fit(
            self.dataset,
            3,
            epochs=1,
            steps_per_epoch=self.batch_num,
            log_freq=1,
            batch_size=self.batch_size,
        )
        loss1 = mp_bf16_history.history['loss'][0]
        np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)

        self.check_bf16(mp_bf16_engine.main_program)


if __name__ == "__main__":
    unittest.main()
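As a quick sanity check of the version encoding used by get_cuda_version() in amp_o2_pass.py above (the nvcc banner text here is illustrative):

import re

sample = "Cuda compilation tools, release 11.7, V11.7.64"
major, minor = re.search(r'release (\S+),', sample).group(1).split('.')
assert int(major) * 1000 + int(float(minor) * 10) == 11070
# hence the `get_cuda_version() < 11000` guard in the test skips pre-CUDA-11 machines.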
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py View file @ 6f3c9643
...
@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
        ]
        amp.init_loss_scaling = 32768
        amp.use_fp16_guard = False
-       amp.use_pure_fp16 = level in ["o2", "o3"]
+       amp.level = level
        amp.use_optimizer_fp16 = level == "o3"
        print("amp level: ", level)
    return strategy
...
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py 0 → 100644 View file @ 6f3c9643

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
import tempfile
import unittest


class TestAMPO2(unittest.TestCase):
    def test_bf16(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py View file @ 6f3c9643
...
@@ -13,13 +13,13 @@
# limitations under the License.

import os

# import yaml
import unittest

from paddle.distributed.fleet import auto


class TestStrategy(unittest.TestCase):
    def test_default_config(self):
        strategy = auto.Strategy()
...
@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
        amp = strategy.amp
        self.assertEqual(amp.enable, False)
+       self.assertAlmostEqual(amp.dtype, "float16")
+       self.assertAlmostEqual(amp.level, "o1")
        self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
        self.assertEqual(amp.incr_every_n_steps, 1000)
        self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
...
@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
        self.assertEqual(amp.custom_black_list, [])
        self.assertEqual(amp.custom_white_list, [])
        self.assertEqual(amp.custom_black_varnames, [])
-       self.assertEqual(amp.use_pure_fp16, False)
-       self.assertEqual(amp.use_fp16_guard, True)
+       self.assertEqual(amp.use_fp16_guard, False)
        self.assertEqual(amp.use_optimizer_fp16, False)

        sharding = strategy.sharding
...
@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
        amp.custom_white_list = ["x"]
        amp.custom_black_list = ["y"]
        amp.custom_black_varnames = ["z"]
-       amp.use_pure_fp16 = True
        amp.use_fp16_guard = False
        amp.use_optimizer_fp16 = True
        self.assertEqual(amp.enable, True)
...
@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
        self.assertEqual(amp.custom_white_list, ["x"])
        self.assertEqual(amp.custom_black_list, ["y"])
        self.assertEqual(amp.custom_black_varnames, ["z"])
-       self.assertEqual(amp.use_pure_fp16, True)
        self.assertEqual(amp.use_fp16_guard, False)
        self.assertEqual(amp.use_optimizer_fp16, True)
...