PaddlePaddle / Paddle — unverified commit 6f3c9643 (parent: 8cbc75ca)

Eb118 BF16 Adoption (#52827)

* pr1 * pr2 * pr3 * fixed unitest * adopt for scale

Authored by JZ-LIANG on Apr 14, 2023; committed via GitHub on Apr 14, 2023.

11 changed files with 1878 additions and 970 deletions:
python/paddle/distributed/auto_parallel/constants.py                      +4     -2
python/paddle/distributed/auto_parallel/operators/dist_embedding.py       +7     -4
python/paddle/distributed/auto_parallel/operators/dist_matmul.py          +1049  -631
python/paddle/distributed/auto_parallel/parallelizer_v2.py                +14    -5
python/paddle/distributed/passes/auto_parallel_amp.py                     +347   -166
python/paddle/distributed/passes/auto_parallel_fp16.py                    +252   -156
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt          +3     -0
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py          +142   -0
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py    +1     -1
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py     +55    -0
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py        +4     -5
python/paddle/distributed/auto_parallel/constants.py

@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
 #########################################
 AMP = "amp"
 set_field_default_config(AMP, "enable", False)
+set_field_default_config(AMP, "dtype", "float16")
+set_field_default_config(AMP, "level", "o1")
 set_field_default_config(AMP, "init_loss_scaling", 32768.0)
 set_field_default_config(AMP, "incr_every_n_steps", 1000)
 set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
 set_field_default_config(AMP, "custom_white_list", [])
 set_field_default_config(AMP, "custom_black_list", [])
 set_field_default_config(AMP, "custom_black_varnames", [])
-set_field_default_config(AMP, "use_pure_fp16", False)
-set_field_default_config(AMP, "use_fp16_guard", True)
+set_field_default_config(AMP, "use_fp16_guard", False)
+set_field_default_config(AMP, "use_bf16_guard", False)
 set_field_default_config(AMP, "use_optimizer_fp16", False)
 #########################################
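The two new defaults above ("dtype" and "level") are the switches the auto-parallel AMP passes consult to choose float16 vs bfloat16 and O1 vs O2 behaviour. A minimal sketch of setting them from user code, assuming Paddle's `paddle.distributed.fleet.auto` Strategy API (the attribute names mirror the constants above; the black-list entry is an illustrative value, not taken from this diff):

```python
# Sketch only: enable bf16 O2 mixed precision for auto-parallel training.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True            # default False (the "enable" field above)
amp.dtype = "bfloat16"       # default "float16" -- the new "dtype" field
amp.level = "o2"             # default "o1"      -- the new "level" field
amp.custom_black_list = ["softmax"]  # illustrative entry, not from this diff

# The strategy would then be handed to the usual auto-parallel workflow,
# e.g. auto.Engine(model, loss, optimizer, strategy=strategy).
```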
python/paddle/distributed/auto_parallel/operators/dist_embedding.py

@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         check_variable_and_dtype(
             Out_var,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             'c_allreduce_sum',
         )
@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         check_variable_and_dtype(
             Out_grad,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )
@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
             },
         )
         check_variable_and_dtype(
-            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
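The only behavioural change in this file is the extra 'uint16' entry in each whitelist: NumPy has no bfloat16 dtype, so Paddle's static-graph dtype strings denote bfloat16 tensors by their 16-bit carrier type, and the checks must accept it once the embedding path can run under bf16 AMP. A plain-NumPy sketch (an illustration of the encoding, not Paddle internals) of why a uint16 buffer can carry bfloat16 values:

```python
import numpy as np

# bfloat16 is float32 with the low 16 mantissa bits dropped. NumPy has no native
# bfloat16, so the raw bits are carried in a uint16 array -- which is why 'uint16'
# shows up in the dtype whitelists above.

def f32_to_bf16_bits(x: np.ndarray) -> np.ndarray:
    # Truncation for simplicity; frameworks usually round to nearest-even.
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray:
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.14159, -2.5], dtype=np.float32)
bits = f32_to_bf16_bits(x)
print(bits.dtype)               # uint16
print(bf16_bits_to_f32(bits))   # [ 1.        3.140625 -2.5     ] -- reduced precision
```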
python/paddle/distributed/auto_parallel/operators/dist_matmul.py

@@ -20,7 +20,11 @@ from .common import DistributedOperatorImpl
 from .common import register_distributed_operator_impl_container
 from .common import register_distributed_operator_impl
 from .common import gradient_synchronization
-from .common import set_comm_op_dist_attr_for_program, naive_copy_op_dist_attr_for_program, is_parameter_related
+from .common import (
+    set_comm_op_dist_attr_for_program,
+    naive_copy_op_dist_attr_for_program,
+    is_parameter_related,
+)
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
@@ -33,24 +37,39 @@ from paddle.fluid import core, unique_name
 from paddle.fluid.framework import _non_static_mode
 from paddle.fluid.framework import Program, Parameter, Variable, program_guard
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
-from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
+from paddle.distributed.fleet.meta_optimizers.common import (
+    OpRole,
+    OP_ROLE_KEY,
+    OP_ROLE_VAR_KEY,
+)
 from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank
 from .dist_default import DistributedDefaultImpl0
-from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
+from ..cost import (
+    build_comp_desc_from_dist_op,
+    build_comm_desc_from_dist_op,
+    build_dp_costs,
+)
 from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
 from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost
 from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
-from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost
+from paddle.distributed.auto_parallel.cost.comm_op_cost import (
+    AllreduceSumOpCost,
+    IdentityOpCost,
+)


 def trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping):
     if trans_x:
-        x_dims_mapping[-1], x_dims_mapping[-2] = x_dims_mapping[-2], x_dims_mapping[-1]
+        x_dims_mapping[-1], x_dims_mapping[-2] = (
+            x_dims_mapping[-2],
+            x_dims_mapping[-1],
+        )
     if trans_y:
-        y_dims_mapping[-1], y_dims_mapping[-2] = y_dims_mapping[-2], y_dims_mapping[-1]
+        y_dims_mapping[-1], y_dims_mapping[-2] = (
+            y_dims_mapping[-2],
+            y_dims_mapping[-1],
+        )


 def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
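The trans_x_y_dims_mapping hunk above only re-wraps the tuple assignment, but the operation is central to the rest of the file: a dims_mapping records, per tensor axis, which process-mesh axis shards it (with -1 meaning replicated), so a transposed operand needs the annotations of its last two axes swapped. A standalone illustration with made-up values:

```python
# dims_mapping convention used throughout this file: one entry per tensor axis,
# -1 = replicated, k >= 0 = sharded along mesh axis k (values below are made up).
x_dims_mapping = [0, -1, 1]   # batch sharded on mesh axis 0, last axis on mesh axis 1
trans_x = True

if trans_x:
    # Same swap as trans_x_y_dims_mapping: the transpose exchanges the last two
    # tensor axes, so their sharding annotations must be exchanged too.
    x_dims_mapping[-1], x_dims_mapping[-2] = (
        x_dims_mapping[-2],
        x_dims_mapping[-1],
    )

print(x_dims_mapping)  # [0, 1, -1]
```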
@@ -123,13 +142,17 @@ def _update_dims_mapping_for_matmul(dist_op):
@@ -152,17 +175,20 @@ def _update_dims_mapping_for_matmul(dist_op):
Formatting only: the compute_compatible_dims_mapping, compute_compatible_and_update_dim_mapping and early-return trans_x_y_dims_mapping calls are re-wrapped with one argument per line.

@@ -202,7 +228,8 @@ def _is_auto_compatible_for_matmul(dist_op):
@@ -234,22 +261,23 @@ def _is_auto_compatible_for_matmul(dist_op):
Formatting only: the copy.deepcopy(...) call for out_dims_mapping is re-wrapped and the redundant parentheses around the is_same comparisons are dropped.

@@ -265,8 +293,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
@@ -277,22 +306,26 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
@@ -302,7 +335,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Formatting only: the dist-attr and kwargs asserts (and their error messages) are re-wrapped across lines.

@@ -339,28 +373,34 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
The create_var, set_tensor_dist_attr_for_program and _get_comm_group calls are re-wrapped, and 'uint16' is added to the dtype whitelist of the _c_identity check:
-            check_variable_and_dtype(
-                Out_grad, 'tensor',
-                ['float16', 'float32', 'float64', 'int32', 'int64'], '_c_identity')
+            check_variable_and_dtype(
+                Out_grad,
+                'tensor',
+                ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
+                '_c_identity',
+            )

@@ -371,20 +411,29 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Same pattern around the c_identity op and the set_comm_op_dist_attr_for_program call:
-            check_variable_and_dtype(intermediate_var_0, 'x',
-                                     ['float16', 'float32', 'float64'], 'linear')
-            check_dtype(intermediate_var_0.dtype, 'dtype',
-                        ['float16', 'float32', 'float64'], 'linear')
+            check_variable_and_dtype(
+                intermediate_var_0,
+                'x',
+                ['float16', 'float32', 'float64', 'uint16'],
+                'linear',
+            )
+            check_dtype(
+                intermediate_var_0.dtype,
+                'dtype',
+                ['float16', 'float32', 'float64', 'uint16'],
+                'linear',
+            )

@@ -397,28 +446,36 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
@@ -428,15 +485,20 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
@@ -448,8 +510,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Formatting only: the column-parallel branch (create_var for X@GRAD, the c_allreduce_sum op attrs, set_comm_op_dist_attr_for_program) and the gradient_synchronization call are re-wrapped.

@@ -472,23 +535,25 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
Formatting only: the _get_comm_group call and the c_broadcast startup_block.append_op call are re-wrapped; the trailing context (class DistributedMatmul and its __init__) is unchanged.
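The remaining hunks patch the ColumnParallel (DistributedMatmulImpl0) and RowParallel (DistributedMatmulImpl1) implementations registered below. For orientation, a small single-process NumPy sketch of the math they distribute (illustration only, not Paddle code): column parallelism shards Y by columns and keeps X whole, which is why the forward pass only inserts a c_identity on X; row parallelism shards Y by rows and X by columns, so the rank-local partial products must be summed, which is the c_allreduce_sum.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6)).astype(np.float32)
Y = rng.standard_normal((6, 8)).astype(np.float32)

# Column parallel with 2 "ranks": each rank holds all of X and a column shard of Y;
# the local outputs are concatenated along the last axis.
col_out = np.concatenate([X @ y_shard for y_shard in np.split(Y, 2, axis=1)], axis=1)

# Row parallel with 2 "ranks": each rank holds a column shard of X and the matching
# row shard of Y; the local partial products are summed (the allreduce).
row_out = sum(x @ y for x, y in zip(np.split(X, 2, axis=1), np.split(Y, 2, axis=0)))

assert np.allclose(col_out, X @ Y, atol=1e-5)
assert np.allclose(row_out, X @ Y, atol=1e-5)
```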
@@ -498,7 +563,6 @@ register_distributed_operator_impl_container(DistributedMatmul("matmul"))
...
@@ -498,7 +563,6 @@ register_distributed_operator_impl_container(DistributedMatmul("matmul"))
# ColumnParallel
# ColumnParallel
class
DistributedMatmulImpl0
(
DistributedOperatorImpl
):
class
DistributedMatmulImpl0
(
DistributedOperatorImpl
):
def
__init__
(
self
,
name
):
def
__init__
(
self
,
name
):
super
(
DistributedMatmulImpl0
,
self
).
__init__
(
name
)
super
(
DistributedMatmulImpl0
,
self
).
__init__
(
name
)
self
.
_forward_implemented
=
True
self
.
_forward_implemented
=
True
...
@@ -521,7 +585,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -521,7 +585,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
])
backward_op
.
input
(
"Y"
)[
0
]
)
# col parallel: matmul + allreduce
# col parallel: matmul + allreduce
assert
Y_var_dim_mapping
[
0
]
<
0
assert
Y_var_dim_mapping
[
0
]
<
0
parallel_axis
=
Y_var_dim_mapping
[
1
]
parallel_axis
=
Y_var_dim_mapping
[
1
]
...
@@ -531,13 +596,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -531,13 +596,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert
len
(
backward_op
.
output
(
"X@GRAD"
))
==
1
assert
len
(
backward_op
.
output
(
"X@GRAD"
))
==
1
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_context
=
ctx
)
dist_op
=
dist_op
,
dist_context
=
ctx
)
process_mesh
=
dist_attr
.
process_mesh
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
processes
=
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
cost_mapping
=
build_comp_costs_from_descs
(
processes
,
desc_mapping
,
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
cluster
)
)
res
.
append
(
cost_mapping
)
res
.
append
(
cost_mapping
)
# calc comm op cost
# calc comm op cost
...
@@ -550,40 +616,52 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -550,40 +616,52 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx
,
ctx
,
var_names
,
var_names
,
attrs
=
attrs
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
AllreduceSumOpCost
,
c_allreduce_sum_desc_mapping
,
cluster
)
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
,
)
res
.
append
(
comm_op_cost_list
)
res
.
append
(
comm_op_cost_list
)
# need gradient allreduce
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
backward_op
.
input
(
"X"
)[
0
]
)
mesh_shape
=
process_mesh
.
topology
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
if
(
batch_size_axis
]
>
1
and
is_parameter_related
(
batch_size_axis
>
-
1
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
)
):
parallel_axis
=
batch_size_axis
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
build_dp_costs
(
cluster
)
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_context
=
ctx
)
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
cost_mapping
=
build_comp_costs_from_descs
(
desc_mapping
,
cluster
)
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
])[
-
1
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
1
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
input
(
"X"
)
var_names
=
serial_op
.
input
(
"X"
)
c_identity_desc_mapping
=
build_comm_desc_from_dist_op
(
c_identity_desc_mapping
=
build_comm_desc_from_dist_op
(
...
@@ -592,10 +670,12 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -592,10 +670,12 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx
,
ctx
,
var_names
,
var_names
,
attrs
=
attrs
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res_cost
=
[
comm_op_cost_list
,
cost_mapping
]
res_cost
=
[
comm_op_cost_list
,
cost_mapping
]
return
res_cost
return
res_cost
...
@@ -606,16 +686,19 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -606,16 +686,19 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
copy
.
deepcopy
(
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
)
y_dims_mapping
=
copy
.
deepcopy
(
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
)
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
trans_x_y_dims_mapping
(
trans_x
,
trans_y
,
x_dims_mapping
,
y_dims_mapping
)
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
if
is_dim_shard
(
y_dims_mapping
[
-
2
])
or
is_dim_replicate
(
y_dims_mapping
[
-
1
]):
y_dims_mapping
[
-
1
]
):
return
False
return
False
for
mapping
in
x_dims_mapping
[
1
:
-
1
]:
for
mapping
in
x_dims_mapping
[
1
:
-
1
]:
if
is_dim_shard
(
mapping
):
if
is_dim_shard
(
mapping
):
...
@@ -635,8 +718,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -635,8 +718,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return
True
return
True
def
is_auto_compatible
(
self
,
dist_op
):
def
is_auto_compatible
(
self
,
dist_op
):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
(
(
not
self
.
is_output_compatible
(
dist_op
)):
not
self
.
is_output_compatible
(
dist_op
)
):
return
False
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
return
False
return
False
...
@@ -661,28 +745,33 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -661,28 +745,33 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
src_op
=
dist_op_context
.
cur_src_op
src_op
=
dist_op_context
.
cur_src_op
rank_id
=
dist_op_context
.
rank_id
rank_id
=
dist_op_context
.
rank_id
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
assert
op_dist_attr
is
not
None
,
"backward op [{}] don't have dist attribute !"
.
format
(
assert
(
str
(
src_op
))
op_dist_attr
is
not
None
),
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
src_op
))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if
rank_id
not
in
op_dist_attr
.
process_mesh
.
processes
:
if
rank_id
not
in
op_dist_attr
.
process_mesh
.
processes
:
rank_id
=
_get_corresponding_rank
(
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
=
_get_corresponding_rank
(
rank_id
)
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
)
# check validation of inputs / outputs
# check validation of inputs / outputs
for
input_name
in
src_op
.
desc
.
input_names
():
for
input_name
in
src_op
.
desc
.
input_names
():
assert
input_name
in
kwargs
,
"input [{}] is not given"
.
format
(
assert
input_name
in
kwargs
,
"input [{}] is not given"
.
format
(
input_name
)
input_name
)
assert
len
(
kwargs
[
input_name
])
==
len
(
assert
len
(
kwargs
[
input_name
])
==
len
(
src_op
.
desc
.
input
(
input_name
)
src_op
.
desc
.
input
(
input_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
input_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
input_name
)
for
output_name
in
src_op
.
desc
.
output_names
():
for
output_name
in
src_op
.
desc
.
output_names
():
assert
output_name
in
kwargs
,
"input [{}] is not given"
.
format
(
assert
output_name
in
kwargs
,
"input [{}] is not given"
.
format
(
output_name
)
output_name
)
assert
len
(
kwargs
[
output_name
])
==
len
(
assert
len
(
kwargs
[
output_name
])
==
len
(
src_op
.
desc
.
output
(
output_name
)
src_op
.
desc
.
output
(
output_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
),
"number of tensor for input [{}] is not match"
.
format
(
output_name
)
output_name
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
var
(
kwargs
[
'Y'
][
0
])
...
@@ -692,18 +781,24 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -692,18 +781,24 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
Weight_var
.
name
)[
-
1
]
if
trans_y
:
if
trans_y
:
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
2
]
Weight_var
.
name
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
)[
-
2
]
matmul_col_dim_mapping
)
assert
(
matmul_col_dim_mapping
>=
0
),
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_group
=
op_dist_attr
.
process_mesh
.
processes
process_mesh_group
=
op_dist_attr
.
process_mesh
.
processes
parallel_axis
=
matmul_col_dim_mapping
parallel_axis
=
matmul_col_dim_mapping
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
group_ranks
=
_get_comm_group
(
parallel_axis
,
rank_id
)
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group
=
new_process_group
(
group_ranks
)
group
=
new_process_group
(
group_ranks
)
# infer new var shape with op dist attr
# infer new var shape with op dist attr
...
@@ -711,31 +806,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -711,31 +806,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert
x_tensor_dist_attr
is
not
None
assert
x_tensor_dist_attr
is
not
None
identity_var_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
X_var
.
name
)
identity_var_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
X_var
.
name
)
assert
identity_var_dist_attr
is
not
None
assert
identity_var_dist_attr
is
not
None
ref_shape_x
=
infer_shape
(
main_block
,
X_var
,
x_tensor_dist_attr
,
ref_shape_x
=
infer_shape
(
identity_var_dist_attr
)
main_block
,
X_var
,
x_tensor_dist_attr
,
identity_var_dist_attr
)
# infer out var shape with op dist attr
# infer out var shape with op dist attr
out_tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
Out_var
)
out_tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
Out_var
)
assert
out_tensor_dist_attr
is
not
None
assert
out_tensor_dist_attr
is
not
None
out_var_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
Out_var
.
name
)
out_var_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
Out_var
.
name
)
assert
out_var_dist_attr
is
not
None
assert
out_var_dist_attr
is
not
None
ref_shape_out
=
infer_shape
(
main_block
,
Out_var
,
out_tensor_dist_attr
,
ref_shape_out
=
infer_shape
(
out_var_dist_attr
)
main_block
,
Out_var
,
out_tensor_dist_attr
,
out_var_dist_attr
)
intermediate_var_0
=
main_block
.
create_var
(
intermediate_var_0
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
name
=
unique_name
.
generate_with_ignorable_key
(
[
"c_identity"
,
'tmp'
])),
"."
.
join
([
"c_identity"
,
'tmp'
])
),
dtype
=
X_var
.
dtype
,
dtype
=
X_var
.
dtype
,
shape
=
X_var
.
shape
,
shape
=
X_var
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
persistable
=
False
,
stop_gradient
=
X_var
.
stop_gradient
)
stop_gradient
=
X_var
.
stop_gradient
,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr
# set intermediate_var_0's dist_attr with X_var's dist_attr
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
ctx
.
set_tensor_dist_attr_for_program
(
identity_var_dist_attr
)
intermediate_var_0
,
identity_var_dist_attr
)
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
'tensor'
,
X_var
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'_c_identity'
)
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
)
c_identity_op
=
main_block
.
append_op
(
c_identity_op
=
main_block
.
append_op
(
type
=
'c_identity'
,
type
=
'c_identity'
,
...
@@ -745,26 +848,34 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -745,26 +848,34 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'ring_id'
:
group
.
id
,
'ring_id'
:
group
.
id
,
'use_calc_stream'
:
True
,
'use_calc_stream'
:
True
,
'use_model_parallel'
:
True
,
'use_model_parallel'
:
True
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
),
})
},
)
if
intermediate_var_0
.
shape
!=
ref_shape_x
:
if
intermediate_var_0
.
shape
!=
ref_shape_x
:
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
intermediate_var_0
.
desc
.
set_shape
(
ref_shape_x
)
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
check_variable_and_dtype
(
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
intermediate_var_0
,
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
check_dtype
(
intermediate_var_0
.
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
attrs
=
{
attrs
=
{
'transpose_X'
:
trans_x
,
'transpose_X'
:
trans_x
,
'transpose_Y'
:
trans_y
,
'transpose_Y'
:
trans_y
,
'alpha'
:
1
,
'alpha'
:
1
,
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
OP_ROLE_KEY
:
src_op
.
attr
(
'op_role'
)
,
}
}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
inputs
=
{
'X'
:
[
intermediate_var_0
],
'Y'
:
[
Weight_var
]}
matmul_op
=
main_block
.
append_op
(
type
=
'matmul'
,
matmul_op
=
main_block
.
append_op
(
inputs
=
inputs
,
type
=
'matmul'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
Out_var
},
attrs
=
attrs
outputs
=
{
'Out'
:
Out_var
},
)
attrs
=
attrs
)
if
Out_var
.
shape
!=
ref_shape_out
:
if
Out_var
.
shape
!=
ref_shape_out
:
Out_var
.
desc
.
set_shape
(
ref_shape_out
)
Out_var
.
desc
.
set_shape
(
ref_shape_out
)
...
@@ -778,13 +889,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -778,13 +889,16 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
input_varname
=
c_identity_op
.
desc
.
input_arg_names
()[
0
]
input_varname
=
c_identity_op
.
desc
.
input_arg_names
()[
0
]
input_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
input_varname
)
input_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
input_varname
)
assert
input_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
assert
input_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
op_dist_attr
identity_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
)
input_dist_attr
)
identity_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
# output
# output
output_varname
=
c_identity_op
.
desc
.
output_arg_names
()[
0
]
output_varname
=
c_identity_op
.
desc
.
output_arg_names
()[
0
]
identity_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
identity_op_dist_attr
.
set_output_dist_attr
(
input_dist_attr
)
output_varname
,
input_dist_attr
)
# set op dist attr
# set op dist attr
ctx
.
set_op_dist_attr_for_program
(
c_identity_op
,
identity_op_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
c_identity_op
,
identity_op_dist_attr
)
...
@@ -797,31 +911,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -797,31 +911,39 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
for
input_varname
in
matmul_op
.
desc
.
input_arg_names
():
for
input_varname
in
matmul_op
.
desc
.
input_arg_names
():
if
input_varname
in
src_op
.
desc
.
input_arg_names
():
if
input_varname
in
src_op
.
desc
.
input_arg_names
():
input_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
input_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
input_varname
)
input_varname
)
assert
input_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
assert
input_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
op_dist_attr
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
)
input_dist_attr
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
input_dist_attr
)
else
:
else
:
input_var
=
main_block
.
var
(
input_varname
)
input_var
=
main_block
.
var
(
input_varname
)
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
input_var
)
input_var
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
)
tensor_dist_attr
)
matmul_op_dist_attr
.
set_input_dist_attr
(
input_varname
,
tensor_dist_attr
)
# output
# output
output_varname
=
matmul_op
.
desc
.
output_arg_names
()[
0
]
output_varname
=
matmul_op
.
desc
.
output_arg_names
()[
0
]
output_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
output_varname
)
output_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
output_varname
)
assert
output_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
assert
output_dist_attr
is
not
None
,
"dist_attr is {}"
.
format
(
op_dist_attr
)
op_dist_attr
matmul_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
)
output_dist_attr
)
matmul_op_dist_attr
.
set_output_dist_attr
(
output_varname
,
output_dist_attr
)
# set op dist attr
# set op dist attr
ctx
.
set_op_dist_attr_for_program
(
matmul_op
,
matmul_op_dist_attr
)
ctx
.
set_op_dist_attr_for_program
(
matmul_op
,
matmul_op_dist_attr
)
# init param sync
# init param sync
if
Weight_var
.
is_parameter
and
not
op_dist_attr
.
is_recompute
:
if
Weight_var
.
is_parameter
and
not
op_dist_attr
.
is_recompute
:
_init_param_sync
(
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
_init_param_sync
(
rank_id
)
Weight_var
,
dist_op_context
,
startup_block
,
ctx
,
rank_id
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
*
args
,
**
kwargs
):
def
backward
(
ctx
,
*
args
,
**
kwargs
):
...
@@ -830,7 +952,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
...
@@ -830,7 +952,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# RowParallel
# RowParallel
class
DistributedMatmulImpl1
(
DistributedOperatorImpl
):
class
DistributedMatmulImpl1
(
DistributedOperatorImpl
):
def
__init__
(
self
,
name
):
def
__init__
(
self
,
name
):
super
(
DistributedMatmulImpl1
,
self
).
__init__
(
name
)
super
(
DistributedMatmulImpl1
,
self
).
__init__
(
name
)
self
.
_forward_implemented
=
True
self
.
_forward_implemented
=
True
...
@@ -853,7 +974,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -853,7 +974,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
main_block
=
backward_op
.
block
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
])
backward_op
.
input
(
"Y"
)[
0
]
)
assert
Y_var_dim_mapping
[
1
]
<
0
assert
Y_var_dim_mapping
[
1
]
<
0
parallel_axis
=
Y_var_dim_mapping
[
0
]
parallel_axis
=
Y_var_dim_mapping
[
0
]
...
@@ -866,50 +988,60 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -866,50 +988,60 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx
,
ctx
,
var_names
,
var_names
,
attrs
=
attrs
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
process_mesh
=
dist_attr
.
process_mesh
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
processes
=
process_mesh
.
processes
comm_op_cost_list
=
build_comm_costs_from_descs
(
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res
.
append
(
comm_op_cost_list
)
res
.
append
(
comm_op_cost_list
)
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_context
=
ctx
)
dist_op
=
dist_op
,
dist_context
=
ctx
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
)
processes
,
desc_mapping
,
cost_mapping
=
build_comp_costs_from_descs
(
cluster
)
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
res
.
append
(
cost_mapping
)
# need gradient allreduce
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
backward_op
.
input
(
"X"
)[
0
]
)
mesh_shape
=
process_mesh
.
topology
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
if
(
batch_size_axis
]
>
1
and
is_parameter_related
(
batch_size_axis
>
-
1
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
)
):
parallel_axis
=
batch_size_axis
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
build_dp_costs
(
cluster
)
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_context
=
ctx
)
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
cost_mapping
=
build_comp_costs_from_descs
(
desc_mapping
,
cluster
)
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
# calc comm op cost
# calc comm op cost
serial_op
=
dist_op
.
serial_op
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
])[
-
2
]
serial_op
.
input
(
"Y"
)[
0
]
)[
-
2
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
output
(
"Out"
)
var_names
=
serial_op
.
output
(
"Out"
)
...
@@ -919,11 +1051,16 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -919,11 +1051,16 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
ctx
,
ctx
,
var_names
,
var_names
,
attrs
=
attrs
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
parallel_axis
=
parallel_axis
,
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
AllreduceSumOpCost
,
cluster
)
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
,
)
res_cost
=
[
cost_mapping
,
comm_op_cost_list
]
res_cost
=
[
cost_mapping
,
comm_op_cost_list
]
...
@@ -935,16 +1072,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
...
@@ -935,16 +1072,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
x_dims_mapping
=
copy
.
deepcopy
(
x_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
x_name
))
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
)
y_dims_mapping
=
copy
.
deepcopy
(
y_dims_mapping
=
copy
.
deepcopy
(
op_dist_attr
.
get_input_dims_mapping
(
y_name
))
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
)
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
        trans_y = op_desc.attr('transpose_Y')

        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
...
@@ -966,8 +1106,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False
...
@@ -992,28 +1133,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block.var(kwargs['Y'][0])
...
@@ -1023,29 +1169,40 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'transpose_X': trans_x,
            'transpose_Y': trans_y,
            'alpha': 1,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}
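The only functional change in the hunk above is the extra 'uint16' entry in the two dtype allow-lists. Paddle stores bfloat16 tensors with a uint16 dtype tag (NumPy has no native bfloat16), so the linear checks must accept 'uint16' once AMP runs at O2 in bf16. A minimal NumPy sketch of that bit-level correspondence, purely illustrative and not Paddle API:

import numpy as np

def float32_to_bfloat16_bits(x):
    # keep the upper 16 bits of the fp32 bit pattern; this is the uint16
    # payload a bfloat16 tensor carries (plain truncation here, for brevity)
    u32 = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (u32 >> 16).astype(np.uint16)

def bfloat16_bits_to_float32(bits):
    # zero-fill the low 16 bits to get back an ordinary fp32 value
    return (bits.astype(np.uint32) << 16).view(np.float32)

vals = np.array([1.0, 3.14159, -2.5e-3], dtype=np.float32)
print(bfloat16_bits_to_float32(float32_to_bfloat16_bits(vals)))  # ~3 significant digits survive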
...
@@ -1054,27 +1211,33 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_op = main_block.append_op(
            type='matmul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
...
@@ -1086,8 +1249,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)
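For context on why a c_allreduce_sum immediately follows the partial matmul in this row-parallel path: X is sharded along its last axis and the weight along its rows, so each rank computes one additive term of X @ W and the terms must be summed across the model-parallel group. A small NumPy sketch of that identity (a hypothetical 2-way split, not Paddle code):

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6)).astype(np.float32)
W = rng.standard_normal((6, 8)).astype(np.float32)

# rank 0 and rank 1 each hold half of X's columns and the matching rows of W
partial0 = X[:, :3] @ W[:3, :]
partial1 = X[:, 3:] @ W[3:, :]
allreduced = partial0 + partial1  # what c_allreduce_sum produces

assert np.allclose(allreduced, X @ W, atol=1e-4)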
...
@@ -1100,15 +1264,19 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        for input_varname in matmul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmul_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmul_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

        # allreduce
...
@@ -1120,21 +1288,26 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
@@ -1143,7 +1316,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulImpl2, self).__init__(name)
...
@@ -1164,38 +1336,45 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
        vars = main_block.vars

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulOpCost, ctx, processes, desc_mapping, cluster
        )
        res_cost = [cost_mapping]
        return res_cost
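A note on the build_dp_costs branch above: when the batch axis of X is sharded over a mesh dimension with more than one process and Y is a parameter, Y@GRAD has to be allreduced across that data-parallel group, and the cost model charges that communication to the plan. A rough, standalone estimate of the wire volume of such a ring allreduce (the helper and constants are illustrative, not the actual cost model):

def ring_allreduce_bytes_per_rank(buffer_bytes, world_size):
    # each rank sends and receives 2 * (p - 1) / p of the buffer in a ring allreduce
    if world_size <= 1:
        return 0.0
    return 2.0 * (world_size - 1) / world_size * buffer_bytes

# e.g. a 4096 x 4096 fp32 weight gradient allreduced across 8 data-parallel ranks
grad_bytes = 4096 * 4096 * 4
print(ring_allreduce_bytes_per_rank(grad_bytes, 8) / 2**20, "MiB per rank")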
...
@@ -1211,13 +1390,15 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True
...
@@ -1231,14 +1412,16 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
@@ -1262,16 +1445,18 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul", DistributedMatmulImpl2("replicate_parallel")
)


class DistributedMatmulV2(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMatmulV2, self).__init__(op_type)
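The three register_distributed_operator_impl calls above bind the column-parallel, row-parallel and replicate-parallel implementations to the serial 'matmul' op type; at partition time the framework walks the registered impls and picks one whose compatibility checks accept the op's dims_mapping. A minimal sketch of that dispatch idea (names are illustrative, not the actual Paddle registry):

_IMPLS = {}

def register_impl(op_type, impl):
    # several parallel implementations may be registered for one serial op type
    _IMPLS.setdefault(op_type, []).append(impl)

def pick_impl(op_type, dist_op):
    # the first implementation whose compatibility check accepts the sharding wins
    for impl in _IMPLS.get(op_type, []):
        if impl.is_auto_compatible(dist_op):
            return impl
    raise ValueError("no compatible distributed impl for " + op_type)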
...
@@ -1281,7 +1466,6 @@ register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))
# ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl0, self).__init__(name)
        self._forward_implemented = True
...
@@ -1304,7 +1488,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        # col parallel: matmul + allreduce
...
@@ -1318,12 +1503,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
...
@@ -1336,45 +1522,55 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res.append(comm_op_cost_list)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        # TODO: trans shape if trans_x or trans_y is True
        comp_desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        comp_cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, comp_desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
...
@@ -1384,9 +1580,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, comp_cost_mapping]

        return res_cost
...
@@ -1397,16 +1595,19 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')

        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
...
@@ -1426,8 +1627,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
...
@@ -1452,28 +1654,33 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
...
@@ -1483,18 +1690,24 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-1]
        if trans_y:
            matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-2]
        assert (
            matmul_col_dim_mapping >= 0
        ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_col_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_col_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        # infer new var shape with op dist attr
...
@@ -1502,31 +1715,39 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        assert x_tensor_dist_attr is not None
        identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name)
        assert identity_var_dist_attr is not None
        ref_shape_x = infer_shape(
            main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr
        )
        # infer out var shape with op dist attr
        out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var)
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape_out = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_identity", 'tmp'])
            ),
            dtype=X_var.dtype,
            shape=X_var.shape,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=X_var.stop_gradient,
        )
        # set intermediate_var_0's dist_attr with X_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, identity_var_dist_attr
        )

        check_variable_and_dtype(
            X_var,
            'tensor',
            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
            '_c_identity',
        )
        c_identity_op = main_block.append_op(
            type='c_identity',
            inputs={'X': [X_var]},
...
@@ -1536,24 +1757,35 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': Out_var},
            attrs=attrs,
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)
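The column-parallel forward above places a c_identity in front of the local matmul_v2: every rank in the model-parallel group sees the full input X while the weight is split by columns, so each rank produces its own slice of the output and no reduction is needed (broadly, the identity's backward is where X@GRAD gets allreduced). A NumPy sketch of the column-split pattern (hypothetical 2-way split, not Paddle code):

import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((4, 6)).astype(np.float32)
W = rng.standard_normal((6, 8)).astype(np.float32)

# both ranks see the full X (the c_identity); each owns half of W's columns
out0 = X @ W[:, :4]   # rank 0's output shard
out1 = X @ W[:, 4:]   # rank 1's output shard

assert np.allclose(np.concatenate([out0, out1], axis=1), X @ W, atol=1e-4)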
...
@@ -1567,13 +1799,16 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(
            input_varname, input_dist_attr
        )
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(
            output_varname, input_dist_attr
        )
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
...
@@ -1584,29 +1819,37 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        for input_varname in matmul_v2_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in matmul_v2_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
@@ -1615,7 +1858,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl1, self).__init__(name)
        self._forward_implemented = True
...
@@ -1638,7 +1880,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]
...
@@ -1653,50 +1896,59 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2GradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        process_mesh = dist_attr.process_mesh
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MatmulV2OpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
...
@@ -1706,11 +1958,16 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res_cost = [cost_mapping, comm_op_cost_list]

        return res_cost
...
@@ -1721,16 +1978,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        x_name = op_desc.input('X')[0]
        y_name = op_desc.input('Y')[0]
        x_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(x_name)
        )
        y_dims_mapping = copy.deepcopy(
            op_dist_attr.get_input_dims_mapping(y_name)
        )
        trans_x = op_desc.attr('trans_x')
        trans_y = op_desc.attr('trans_y')

        trans_x_y_dims_mapping(trans_x, trans_y, x_dims_mapping, y_dims_mapping)

        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
...
@@ -1752,8 +2012,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False

        if not _is_auto_compatible_for_matmul(dist_op):
            return False
...
@@ -1778,28 +2039,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
...
@@ -1809,28 +2075,39 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        if trans_y:
            matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
                Weight_var.name
            )[-1]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        attrs = {
            'trans_x': trans_x,
            'trans_y': trans_y,
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}
...
@@ -1839,27 +2116,33 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        matmul_v2_op = main_block.append_op(
            type='matmul_v2',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
...
@@ -1871,8 +2154,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)
...
@@ -1885,15 +2169,19 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        for input_varname in matmul_v2_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = matmul_v2_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

        # allreduce
...
@@ -1905,21 +2193,26 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
@@ -1928,7 +2221,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMatmulV2Impl2, self).__init__(name)
...
@@ -1950,38 +2242,44 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
...
@@ -1950,38 +2242,44 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
process_mesh
=
dist_attr
.
process_mesh
process_mesh
=
dist_attr
.
process_mesh
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_context
=
ctx
)
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
process_mesh
.
processes
processes
=
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulV2GradOpCost
,
ctx
,
cost_mapping
=
build_comp_costs_from_descs
(
processes
,
desc_mapping
,
MatmulV2GradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
cluster
)
)
res
.
append
(
cost_mapping
)
res
.
append
(
cost_mapping
)
# need gradient allreduce
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
backward_op
.
input
(
"X"
)[
0
]
)
mesh_shape
=
process_mesh
.
topology
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
if
(
batch_size_axis
]
>
1
and
is_parameter_related
(
batch_size_axis
>
-
1
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
)
):
parallel_axis
=
batch_size_axis
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
build_dp_costs
(
cluster
)
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_context
=
ctx
)
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulV2OpCost
,
ctx
,
cost_mapping
=
build_comp_costs_from_descs
(
processes
,
desc_mapping
,
MatmulV2OpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
cluster
)
)
res_cost
=
[
cost_mapping
]
res_cost
=
[
cost_mapping
]
...
@@ -1998,13 +2296,15 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
...
@@ -1998,13 +2296,15 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
if
is_dim_shard
(
x_dims_mapping
[
-
1
]):
return
False
return
False
if
is_valid_list_index
(
x_dims_mapping
,
-
2
)
and
is_dim_shard
(
if
is_valid_list_index
(
x_dims_mapping
,
-
2
)
and
is_dim_shard
(
x_dims_mapping
[
-
2
]):
x_dims_mapping
[
-
2
]
):
return
False
return
False
if
is_dim_shard
(
y_dims_mapping
[
-
1
]):
if
is_dim_shard
(
y_dims_mapping
[
-
1
]):
return
False
return
False
if
is_valid_list_index
(
y_dims_mapping
,
-
2
)
and
is_dim_shard
(
if
is_valid_list_index
(
y_dims_mapping
,
-
2
)
and
is_dim_shard
(
y_dims_mapping
[
-
2
]):
y_dims_mapping
[
-
2
]
):
return
False
return
False
return
True
return
True
...
@@ -2019,14 +2319,16 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
...
@@ -2019,14 +2319,16 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
if
is_dim_shard
(
out_dims_mapping
[
-
1
]):
if
is_dim_shard
(
out_dims_mapping
[
-
1
]):
return
False
return
False
if
is_valid_list_index
(
out_dims_mapping
,
-
2
)
and
is_dim_shard
(
if
is_valid_list_index
(
out_dims_mapping
,
-
2
)
and
is_dim_shard
(
out_dims_mapping
[
-
2
]):
out_dims_mapping
[
-
2
]
):
return
False
return
False
return
True
return
True
def
is_auto_compatible
(
self
,
dist_op
):
def
is_auto_compatible
(
self
,
dist_op
):
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
\
if
(
not
self
.
is_input_compatible
(
dist_op
))
or
(
(
not
self
.
is_output_compatible
(
dist_op
)):
not
self
.
is_output_compatible
(
dist_op
)
):
return
False
return
False
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
if
not
_is_auto_compatible_for_matmul
(
dist_op
):
...
@@ -2050,16 +2352,18 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
...
@@ -2050,16 +2352,18 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl0("column_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl1("row_parallel")
)
register_distributed_operator_impl(
    "matmul_v2", DistributedMatmulV2Impl2("replicate_parallel")
)


class DistributedMul(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMul, self).__init__(op_type)
...
@@ -2069,7 +2373,6 @@ register_distributed_operator_impl_container(DistributedMul("mul"))

# ColumnParallel
class DistributedMulImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl0, self).__init__(name)
        self._forward_implemented = True
...
@@ -2092,7 +2395,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        # col parallel: matmul + allreduce
        assert Y_var_dim_mapping[0] < 0
        parallel_axis = Y_var_dim_mapping[1]
...
@@ -2102,13 +2406,14 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        assert len(backward_op.output("X@GRAD")) == 1

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # calc comm op cost
...
@@ -2121,40 +2426,52 @@ class DistributedMulImpl0(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res.append(comm_op_cost_list)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.input("X")
        c_identity_desc_mapping = build_comm_desc_from_dist_op(
...
@@ -2163,10 +2480,12 @@ class DistributedMulImpl0(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res_cost = [comm_op_cost_list, cost_mapping]

        return res_cost
...
@@ -2181,7 +2500,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(
            y_dims_mapping[-1]
        ):
            return False
        for mapping in x_dims_mapping[1:-1]:
            if is_dim_shard(mapping):
...
@@ -2201,8 +2521,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
@@ -2229,28 +2550,33 @@ class DistributedMulImpl0(DistributedOperatorImpl):
=
dist_op_context
.
cur_src_op
src_op
=
dist_op_context
.
cur_src_op
rank_id
=
dist_op_context
.
rank_id
rank_id
=
dist_op_context
.
rank_id
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
op_dist_attr
=
ctx
.
get_op_dist_attr_for_program
(
src_op
)
assert
op_dist_attr
is
not
None
,
"backward op [{}] don't have dist attribute !"
.
format
(
assert
(
str
(
src_op
))
op_dist_attr
is
not
None
),
"backward op [{}] don't have dist attribute !"
.
format
(
str
(
src_op
))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if
rank_id
not
in
op_dist_attr
.
process_mesh
.
processes
:
if
rank_id
not
in
op_dist_attr
.
process_mesh
.
processes
:
rank_id
=
_get_corresponding_rank
(
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
=
_get_corresponding_rank
(
rank_id
)
ctx
,
op_dist_attr
.
process_mesh
,
rank_id
)
# check validation of inputs / outputs
# check validation of inputs / outputs
for
input_name
in
src_op
.
desc
.
input_names
():
for
input_name
in
src_op
.
desc
.
input_names
():
assert
input_name
in
kwargs
,
"input [{}] is not given"
.
format
(
assert
input_name
in
kwargs
,
"input [{}] is not given"
.
format
(
input_name
)
input_name
)
assert
len
(
kwargs
[
input_name
])
==
len
(
assert
len
(
kwargs
[
input_name
])
==
len
(
src_op
.
desc
.
input
(
input_name
)
src_op
.
desc
.
input
(
input_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
input_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
input_name
)
for
output_name
in
src_op
.
desc
.
output_names
():
for
output_name
in
src_op
.
desc
.
output_names
():
assert
output_name
in
kwargs
,
"input [{}] is not given"
.
format
(
assert
output_name
in
kwargs
,
"input [{}] is not given"
.
format
(
output_name
)
output_name
)
assert
len
(
kwargs
[
output_name
])
==
len
(
assert
len
(
kwargs
[
output_name
])
==
len
(
src_op
.
desc
.
output
(
output_name
)
src_op
.
desc
.
output
(
output_name
)
),
"number of tensor for input [{}] is not match"
.
format
(
),
"number of tensor for input [{}] is not match"
.
format
(
output_name
)
output_name
)
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
X_var
=
main_block
.
var
(
kwargs
[
'X'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
Weight_var
=
main_block
.
_var_recursive
(
kwargs
[
'Y'
][
0
])
...
@@ -2258,15 +2584,20 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2258,15 +2584,20 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# TODO infer logic comm presentation
# TODO infer logic comm presentation
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
matmul_col_dim_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
Weight_var
.
name
)[
-
1
]
Weight_var
.
name
assert
matmul_col_dim_mapping
>=
0
,
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
)[
-
1
]
matmul_col_dim_mapping
)
assert
(
matmul_col_dim_mapping
>=
0
),
"col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]"
.
format
(
matmul_col_dim_mapping
)
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_shape
=
op_dist_attr
.
process_mesh
.
topology
process_mesh_group
=
op_dist_attr
.
process_mesh
.
processes
process_mesh_group
=
op_dist_attr
.
process_mesh
.
processes
parallel_axis
=
matmul_col_dim_mapping
parallel_axis
=
matmul_col_dim_mapping
group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
group_ranks
=
_get_comm_group
(
parallel_axis
,
rank_id
)
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
group
=
new_process_group
(
group_ranks
)
group
=
new_process_group
(
group_ranks
)
# infer new var shape with op dist attr
# infer new var shape with op dist attr
...
@@ -2274,31 +2605,39 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2274,31 +2605,39 @@ class DistributedMulImpl0(DistributedOperatorImpl):
assert
x_tensor_dist_attr
is
not
None
assert
x_tensor_dist_attr
is
not
None
identity_var_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
X_var
.
name
)
identity_var_dist_attr
=
op_dist_attr
.
get_input_dist_attr
(
X_var
.
name
)
assert
identity_var_dist_attr
is
not
None
assert
identity_var_dist_attr
is
not
None
ref_shape_x
=
infer_shape
(
main_block
,
X_var
,
x_tensor_dist_attr
,
ref_shape_x
=
infer_shape
(
identity_var_dist_attr
)
main_block
,
X_var
,
x_tensor_dist_attr
,
identity_var_dist_attr
)
# infer out var shape with op dist attr
# infer out var shape with op dist attr
out_tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
Out_var
)
out_tensor_dist_attr
=
ctx
.
get_tensor_dist_attr_for_program
(
Out_var
)
assert
out_tensor_dist_attr
is
not
None
assert
out_tensor_dist_attr
is
not
None
out_var_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
Out_var
.
name
)
out_var_dist_attr
=
op_dist_attr
.
get_output_dist_attr
(
Out_var
.
name
)
assert
out_var_dist_attr
is
not
None
assert
out_var_dist_attr
is
not
None
ref_shape_out
=
infer_shape
(
main_block
,
Out_var
,
out_tensor_dist_attr
,
ref_shape_out
=
infer_shape
(
out_var_dist_attr
)
main_block
,
Out_var
,
out_tensor_dist_attr
,
out_var_dist_attr
)
intermediate_var_0
=
main_block
.
create_var
(
intermediate_var_0
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
name
=
unique_name
.
generate_with_ignorable_key
(
[
"c_identity"
,
'tmp'
])),
"."
.
join
([
"c_identity"
,
'tmp'
])
),
dtype
=
X_var
.
dtype
,
dtype
=
X_var
.
dtype
,
shape
=
X_var
.
shape
,
shape
=
X_var
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
persistable
=
False
,
stop_gradient
=
X_var
.
stop_gradient
)
stop_gradient
=
X_var
.
stop_gradient
,
)
# set intermediate_var_0's dist_attr with X_var's dist_attr
# set intermediate_var_0's dist_attr with X_var's dist_attr
ctx
.
set_tensor_dist_attr_for_program
(
intermediate_var_0
,
ctx
.
set_tensor_dist_attr_for_program
(
identity_var_dist_attr
)
intermediate_var_0
,
identity_var_dist_attr
)
check_variable_and_dtype
(
check_variable_and_dtype
(
X_var
,
'tensor'
,
X_var
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'_c_identity'
)
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
)
c_identity_op
=
main_block
.
append_op
(
c_identity_op
=
main_block
.
append_op
(
type
=
'c_identity'
,
type
=
'c_identity'
,
inputs
=
{
'X'
:
[
X_var
]},
inputs
=
{
'X'
:
[
X_var
]},
...
@@ -2307,20 +2646,29 @@ class DistributedMulImpl0(DistributedOperatorImpl):
...
@@ -2307,20 +2646,29 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if intermediate_var_0.shape != ref_shape_x:
            intermediate_var_0.desc.set_shape(ref_shape_x)

        check_variable_and_dtype(
            intermediate_var_0,
            'x',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        check_dtype(
            intermediate_var_0.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': intermediate_var_0, 'Y': Weight_var}
...
@@ -2334,16 +2682,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs
        )
        if Out_var.shape != ref_shape_out:
            Out_var.desc.set_shape(ref_shape_out)
...
@@ -2362,13 +2709,16 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        input_varname = c_identity_op.desc.input_arg_names()[0]
        input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
        assert input_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        identity_op_dist_attr.set_input_dist_attr(input_varname, input_dist_attr)
        # output
        output_varname = c_identity_op.desc.output_arg_names()[0]
        identity_op_dist_attr.set_output_dist_attr(output_varname, input_dist_attr)
        ctx.set_op_dist_attr_for_program(c_identity_op, identity_op_dist_attr)

        # matmulv2
...
@@ -2379,29 +2729,37 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        for input_varname in mul_op.desc.input_arg_names():
            if input_varname in src_op.desc.input_arg_names():
                input_dist_attr = op_dist_attr.get_input_dist_attr(
                    input_varname
                )
                assert input_dist_attr is not None, "dist_attr is {}".format(
                    op_dist_attr
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, input_dist_attr
                )
            else:
                input_var = main_block.var(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
                matmulv2_op_dist_attr.set_input_dist_attr(
                    input_varname, tensor_dist_attr
                )
        for output_varname in mul_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
@@ -2410,7 +2768,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# RowParallel
class DistributedMulImpl1(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl1, self).__init__(name)
        self._forward_implemented = True
...
@@ -2434,7 +2791,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        main_block = backward_op.block
        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
        assert Y_var_dim_mapping[1] < 0
        parallel_axis = Y_var_dim_mapping[0]
...
@@ -2447,49 +2805,59 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        processes = process_mesh.processes
        comm_op_cost_list = build_comm_costs_from_descs(
            IdentityOpCost, ctx, processes, c_identity_desc_mapping, cluster
        )
        res.append(comm_op_cost_list)

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )

        # calc comm op cost
        serial_op = dist_op.serial_op
        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
        attrs = {"use_calc_stream": True, "use_model_parallel": True}
        var_names = serial_op.output("Out")
...
@@ -2499,12 +2867,17 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            ctx,
            var_names,
            attrs=attrs,
            parallel_axis=parallel_axis,
        )
        # print("dist_matmul.py dist_op: ", dist_op)
        comm_op_cost_list = build_comm_costs_from_descs(
            AllreduceSumOpCost,
            ctx,
            processes,
            c_allreduce_sum_desc_mapping,
            cluster,
        )
        res_cost = [cost_mapping, comm_op_cost_list]
...
@@ -2520,7 +2893,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        if is_dim_replicate(x_dims_mapping[-1]):
            return False
        if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(
            y_dims_mapping[-1]
        ):
            return False
        # Other dimensions must be replicate except the batch dimension
        for mapping in x_dims_mapping[1:-1]:
...
@@ -2542,8 +2916,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
@@ -2570,28 +2945,33 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        assert (
            op_dist_attr is not None
        ), "backward op [{}] don't have dist attribute !".format(str(src_op))

        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
        if rank_id not in op_dist_attr.process_mesh.processes:
            rank_id = _get_corresponding_rank(
                ctx, op_dist_attr.process_mesh, rank_id
            )

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name
            )
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensor for input [{}] is not match".format(input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "input [{}] is not given".format(
                output_name
            )
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensor for input [{}] is not match".format(
                output_name
            )

        X_var = main_block.var(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
...
@@ -2599,26 +2979,36 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
            Weight_var.name
        )[-2]
        assert (
            matmul_row_dim_mapping >= 0
        ), "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
            matmul_row_dim_mapping
        )
        process_mesh_shape = op_dist_attr.process_mesh.topology
        process_mesh_group = op_dist_attr.process_mesh.processes

        parallel_axis = matmul_row_dim_mapping
        group_ranks = _get_comm_group(
            process_mesh_group, process_mesh_shape, parallel_axis, rank_id
        )
        group = new_process_group(group_ranks)

        check_variable_and_dtype(
            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
        )
        check_dtype(
            X_var.dtype,
            'dtype',
            ['float16', 'float32', 'float64', 'uint16'],
            'linear',
        )
        # attrs = {'trans_x': False, 'trans_y': False}
        attrs = {
            "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
            OP_ROLE_KEY: src_op.attr('op_role'),
        }
        inputs = {'X': X_var, 'Y': Weight_var}
...
@@ -2627,22 +3017,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        assert out_tensor_dist_attr is not None
        out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert out_var_dist_attr is not None
        ref_shape = infer_shape(
            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
        )

        intermediate_var_0 = main_block.create_var(
            name=unique_name.generate_with_ignorable_key(
                ".".join(["c_allreduce_sum", 'tmp'])
            ),
            shape=Out_var.shape,
            dtype=Out_var.dtype,
            type=Out_var.type,
            lod_level=Out_var.lod_level,
            persistable=False,
            is_data=False,
            need_check_feed=Out_var.desc.need_check_feed(),
        )
        # set intermediate_var_0's dist_attr with Out_var's dist_attr
        ctx.set_tensor_dist_attr_for_program(
            intermediate_var_0, out_var_dist_attr
        )

        inputs_ref_shape = {}
        inputs_original_shape = {}
...
@@ -2651,16 +3045,18 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            inputs_original_shape[var_name] = var.shape
            input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
            input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name)
            input_ref_shape = infer_shape(
                main_block, var, input_tensor_dist_attr, input_var_dist_attr
            )
            inputs_ref_shape[var_name] = input_ref_shape
            var.desc.set_shape(input_ref_shape)

        mul_op = main_block.append_op(
            type='mul',
            inputs=inputs,
            outputs={'Out': intermediate_var_0},
            attrs=attrs,
        )
        if intermediate_var_0.shape != ref_shape:
            intermediate_var_0.desc.set_shape(ref_shape)
...
@@ -2678,8 +3074,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
                'ring_id': group.id,
                'use_calc_stream': True,
                'use_model_parallel': True,
                OP_ROLE_KEY: src_op.attr('op_role'),
            },
        )
        if Out_var.shape != ref_shape:
            Out_var.desc.set_shape(ref_shape)
...
@@ -2693,15 +3090,19 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        for input_varname in mul_op.desc.input_arg_names():
            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
            assert input_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            matmulv2_op_dist_attr.set_input_dist_attr(
                input_varname, input_dist_attr
            )
        output_varname = mul_op.desc.output_arg_names()[0]
        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
        assert output_dist_attr is not None, "dist_attr is {}".format(
            op_dist_attr
        )
        matmulv2_op_dist_attr.set_output_dist_attr(
            output_varname, output_dist_attr
        )
        ctx.set_op_dist_attr_for_program(mul_op, matmulv2_op_dist_attr)

        # allreduce
...
@@ -2713,21 +3114,26 @@ class DistributedMulImpl1(DistributedOperatorImpl):
            input_var = main_block.var(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
                input_varname, tensor_dist_attr
            )
        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
            assert output_dist_attr is not None, "dist_attr is {}".format(
                op_dist_attr
            )
            allreduce_op_dist_attr.set_output_dist_attr(
                output_varname, output_dist_attr
            )
        ctx.set_op_dist_attr_for_program(
            c_allreduce_sum_op, allreduce_op_dist_attr
        )

        # init param sync
        if Weight_var.is_parameter and not op_dist_attr.is_recompute:
            _init_param_sync(
                Weight_var, dist_op_context, startup_block, ctx, rank_id
            )

    @staticmethod
    def backward(ctx, *args, **kwargs):
...
@@ -2736,7 +3142,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):

# ReplicateParallel
class DistributedMulImpl2(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMulImpl2, self).__init__(name)
...
@@ -2757,38 +3162,45 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        vars = main_block.vars

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        process_mesh = dist_attr.process_mesh
        processes = process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulGradOpCost, ctx, processes, desc_mapping, cluster
        )
        res.append(cost_mapping)

        # need gradient allreduce
        var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("X")[0]
        )
        mesh_shape = process_mesh.topology
        batch_size_axis = var_dim_mapping[0]
        if (
            batch_size_axis > -1
            and mesh_shape[batch_size_axis] > 1
            and is_parameter_related(backward_op.input("Y")[0], main_block)
        ):
            parallel_axis = batch_size_axis
            attrs = {"use_calc_stream": True}
            var_names = [backward_op.output('Y@GRAD')[0]]
            build_dp_costs(
                res, dist_op, ctx, var_names, attrs, parallel_axis, cluster
            )
        return res

    def calc_fwd_cost(self, dist_op, ctx, cluster):
        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
            dist_op=dist_op, dist_context=ctx
        )
        processes = dist_op.dist_attr.process_mesh.processes
        cost_mapping = build_comp_costs_from_descs(
            MulOpCost, ctx, processes, desc_mapping, cluster
        )
        res_cost = [cost_mapping]

        return res_cost
...
@@ -2804,12 +3216,14 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        if is_dim_shard(x_dims_mapping[-1]):
            return False
        if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
            x_dims_mapping[-2]
        ):
            return False
        if is_dim_shard(y_dims_mapping[-1]):
            return False
        if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
            y_dims_mapping[-2]
        ):
            return False
        return True
...
@@ -2824,14 +3238,16 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        if is_dim_shard(out_dims_mapping[-1]):
            return False
        if is_valid_list_index(out_dims_mapping, -2) and is_dim_shard(
            out_dims_mapping[-2]
        ):
            return False
        return True

    def is_auto_compatible(self, dist_op):
        if (not self.is_input_compatible(dist_op)) or (
            not self.is_output_compatible(dist_op)
        ):
            return False
        if not _is_auto_compatible_for_matmul(dist_op):
...
@@ -2855,8 +3271,10 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        _right_operand_parameter_matmul_backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
    "mul", DistributedMulImpl0("column_parallel")
)
register_distributed_operator_impl("mul", DistributedMulImpl1("row_parallel"))
register_distributed_operator_impl(
    "mul", DistributedMulImpl2("replicate_parallel")
)
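
Annotation (not part of the diff): the `mul` registrations above mirror the `matmul_v2` ones, one `DistributedOperatorImpl` per sharding strategy, attached to the op type by `register_distributed_operator_impl`. The compatibility checks these impls share are driven by `dims_mapping` values, where -1 marks a replicated tensor dimension and a non-negative value names the process-mesh axis that dimension is sharded along. The sketch below re-implements the replicate-parallel input check with plain Python lists; the two helpers are stand-ins mirroring how `is_dim_shard` and `is_valid_list_index` are used above, not the Paddle originals.

# Standalone illustration of the replicate-parallel compatibility rule.
def is_dim_shard(mapping):
    # -1 means replicated; any mesh-axis index means sharded
    return mapping != -1


def is_valid_list_index(lst, index):
    return -len(lst) <= index < len(lst)


def replicate_parallel_input_ok(x_dims_mapping, y_dims_mapping):
    # the last two dims of X and Y must not be sharded for the replicate impl
    if is_dim_shard(x_dims_mapping[-1]):
        return False
    if is_valid_list_index(x_dims_mapping, -2) and is_dim_shard(
        x_dims_mapping[-2]
    ):
        return False
    if is_dim_shard(y_dims_mapping[-1]):
        return False
    if is_valid_list_index(y_dims_mapping, -2) and is_dim_shard(
        y_dims_mapping[-2]
    ):
        return False
    return True


print(replicate_parallel_input_ok([0, -1, -1], [-1, -1]))  # True: only the batch dim is sharded
print(replicate_parallel_input_ok([-1, -1, 0], [-1, -1]))  # False: last dim of X is sharded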
python/paddle/distributed/auto_parallel/parallelizer_v2.py
...
@@ -254,17 +254,26 @@ class Parallelizer:
                self._dist_context.serial_feed_vars["inputs"]
                + self._dist_context.serial_feed_vars["labels"]
            )
            self._logger.info(
                "Applying AMP-{}-{} ...".format(
                    config["dtype"], config['level']
                ),
            )
            if config['level'] == "o1":
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_amp_pass.get_loss()
            elif config['level'] in ['o2', 'o3']:
                config["base_opt"] = optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_fp16_pass.get_loss()
            else:
                raise ValueError("AMP level should be one of o1, o2, o3")

        # apply recompute pass
        # recompute is then train-only optimization
...
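
Annotation (not part of the diff): this branch is what routes an auto-parallel job to the right mixed-precision pass. `config["level"] == "o1"` keeps the cast-based `auto_parallel_amp` pass, while "o2"/"o3" hand the base optimizer to the `auto_parallel_fp16` pass. A minimal sketch of how such a config is typically produced is shown below; it assumes the `paddle.distributed.fleet.auto` Strategy front end (not shown in this hunk) and is illustrative only.

# Sketch, assuming the auto-parallel Strategy front end is available as
# paddle.distributed.fleet.auto.  The amp section of the strategy becomes the
# `config` dict read above: config["dtype"] and config["level"] pick the pass.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"  # becomes config["dtype"]
amp.level = "o2"        # "o1" -> auto_parallel_amp pass; "o2"/"o3" -> auto_parallel_fp16 pass
# The strategy is then handed to auto.Engine(..., strategy=strategy) before training.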
python/paddle/distributed/passes/auto_parallel_amp.py
...
@@ -18,25 +18,48 @@ from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import (
    get_loss_op,
    set_var_dist_attr,
)
from paddle.distributed.auto_parallel.utils import (
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _keep_fp32_input,
    _keep_fp32_output,
    find_op_index,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _valid_types,
    find_true_post_op,
    find_true_prev_op,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _is_in_black_varnames,
    _dtype_to_str,
    _rename_arg,
)
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
)
from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op

world_process_group = get_world_process_group()


class AMPState(object):
    def __init__(self, block):
        self._block = block
        self._op_fp16_dict = (
            {}
        )  # op_id --> True/False. 'True' means that the current op is in fp16 mode.
        self._var_name_dict = {}  # fwd_op_id --> {old_name: cast_name}
        self.is_train = False
...
@@ -55,7 +78,8 @@ class AMPState(object):
            elif int(op.attr('op_role')) == int(OpRole.Backward):
                if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
                    fwd_op_id = dist_op_context.grad_op_id_to_op_id[
                        op.desc.original_id()
                    ]
                    if self._is_fp16_op(fwd_op_id) == True:
                        self._op_fp16_dict[op.desc.original_id()] = True
                    elif self._is_fp16_op(fwd_op_id) == False:
...
@@ -78,7 +102,8 @@ class AMPState(object):
            if op.type == 'create_py_reader' or op.type == 'read':
                continue
            if amp_lists.black_varnames is not None and _is_in_black_varnames(
                op, amp_lists
            ):
                self._op_fp16_dict[op.desc.original_id()] = False
                continue
            if op.type in amp_lists.black_list:
...
@@ -98,17 +123,24 @@ class AMPState(object):
                        continue
                    elif in_var.op is op:
                        prev_op = find_true_prev_op(ops, op, in_var_name)
                        if prev_op is None:
                            continue
                    else:
                        prev_op = in_var.op
                    # if it's one of inputs
                    if (
                        self._is_fp16_op(prev_op.desc.original_id()) == False
                        or prev_op.type in amp_lists.black_list
                    ):
                        is_black_op = True
                    elif (
                        self._is_fp16_op(prev_op.desc.original_id()) == True
                        or prev_op.type in amp_lists.white_list
                    ):
                        is_white_op = True

            if is_black_op:
                self._op_fp16_dict[op.desc.original_id()] = False
@@ -131,19 +163,28 @@ class AMPState(object):
...
@@ -131,19 +163,28 @@ class AMPState(object):
break
break
if
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
False
:
if
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
False
:
num_cast_ops
=
self
.
_insert_cast_op_forward
(
num_cast_ops
=
self
.
_insert_cast_op_forward
(
op
,
idx
,
core
.
VarDesc
.
VarType
.
FP16
,
op
,
core
.
VarDesc
.
VarType
.
FP32
,
dist_context
)
idx
,
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
FP32
,
dist_context
,
)
elif
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
True
:
elif
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
True
:
num_cast_ops
=
self
.
_insert_cast_op_forward
(
num_cast_ops
=
self
.
_insert_cast_op_forward
(
op
,
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
op
,
core
.
VarDesc
.
VarType
.
FP16
,
dist_context
)
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
dist_context
,
)
else
:
else
:
pass
pass
idx
+=
num_cast_ops
+
1
idx
+=
num_cast_ops
+
1
self
.
_block
.
_sync_with_cpp
()
self
.
_block
.
_sync_with_cpp
()
def
_insert_cast_op_forward
(
self
,
op
,
idx
,
src_dtype
,
dst_dtype
,
def
_insert_cast_op_forward
(
dist_context
):
self
,
op
,
idx
,
src_dtype
,
dst_dtype
,
dist_context
):
"""
"""
only for forward cast
only for forward cast
modified from paddle.fluid.contrib.mixed_precision
modified from paddle.fluid.contrib.mixed_precision
...
@@ -152,38 +193,45 @@ class AMPState(object):
...
@@ -152,38 +193,45 @@ class AMPState(object):
var_name_dict
=
{}
var_name_dict
=
{}
for
in_name
in
op
.
input_names
:
for
in_name
in
op
.
input_names
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
op
,
in_name
):
op
,
in_name
):
continue
continue
for
in_var_name
in
op
.
input
(
in_name
):
for
in_var_name
in
op
.
input
(
in_name
):
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
if
in_var
.
type
not
in
_valid_types
or
in_var
.
dtype
==
dst_dtype
:
if
in_var
.
type
not
in
_valid_types
or
in_var
.
dtype
==
dst_dtype
:
continue
continue
if
in_var
.
dtype
==
src_dtype
:
if
in_var
.
dtype
==
src_dtype
:
cast_name
=
in_var
.
name
+
'.cast_'
+
_dtype_to_str
(
cast_name
=
(
dst_dtype
)
in_var
.
name
+
'.cast_'
+
_dtype_to_str
(
dst_dtype
)
)
out_var
=
self
.
_block
.
vars
.
get
(
cast_name
)
out_var
=
self
.
_block
.
vars
.
get
(
cast_name
)
var_name_dict
[
in_var
.
name
]
=
cast_name
var_name_dict
[
in_var
.
name
]
=
cast_name
consume_op_attr
=
dist_context
.
get_op_dist_attr_for_program
(
consume_op_attr
=
dist_context
.
get_op_dist_attr_for_program
(
op
)
op
)
assert
consume_op_attr
is
not
None
assert
consume_op_attr
is
not
None
if
out_var
is
None
or
out_var
.
dtype
!=
dst_dtype
:
if
out_var
is
None
or
out_var
.
dtype
!=
dst_dtype
:
# NOTE we make the cast op and var's dist attr as the op that consume the
# NOTE we make the cast op and var's dist attr as the op that consume the
# cast var instead of the op which generates the var
# cast var instead of the op which generates the var
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var
.
name
)
in_var
.
name
)
assert
in_var_dist_attr
is
not
None
assert
in_var_dist_attr
is
not
None
ref_mesh
=
in_var_dist_attr
.
process_mesh
ref_mesh
=
in_var_dist_attr
.
process_mesh
ref_mapping
=
in_var_dist_attr
.
dims_mapping
ref_mapping
=
in_var_dist_attr
.
dims_mapping
consume_op_attr
.
set_input_dist_attr
(
consume_op_attr
.
set_input_dist_attr
(
cast_name
,
in_var_dist_attr
)
cast_name
,
in_var_dist_attr
)
out_var
=
self
.
_block
.
create_var
(
out_var
=
self
.
_block
.
create_var
(
name
=
cast_name
,
name
=
cast_name
,
dtype
=
dst_dtype
,
dtype
=
dst_dtype
,
persistable
=
False
,
persistable
=
False
,
stop_gradient
=
in_var
.
stop_gradient
)
stop_gradient
=
in_var
.
stop_gradient
,
set_var_dist_attr
(
dist_context
,
out_var
,
ref_mapping
,
)
ref_mesh
)
set_var_dist_attr
(
dist_context
,
out_var
,
ref_mapping
,
ref_mesh
)
cast_op
=
self
.
_block
.
_insert_op_without_sync
(
cast_op
=
self
.
_block
.
_insert_op_without_sync
(
idx
,
idx
,
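The `in_var.name + '.cast_' + _dtype_to_str(dst_dtype)` convention above lets the pass reuse an existing cast output instead of casting the same FP32 input twice. A standalone sketch of that naming/reuse rule, with a plain dict standing in for the Block's variable table and a simplified `_dtype_to_str`:

# Sketch of the cast-variable naming and reuse convention (assumed toy block).
def _dtype_to_str(dtype):
    return {"float16": "fp16", "bfloat16": "bf16"}.get(dtype, "fp32")

block_vars = {}

def get_or_create_cast_var(var_name, dst_dtype):
    cast_name = var_name + ".cast_" + _dtype_to_str(dst_dtype)
    if cast_name not in block_vars:          # create the cast output only once
        block_vars[cast_name] = {"dtype": dst_dtype, "persistable": False}
    return cast_name

print(get_or_create_cast_var("embedding_0.tmp_0", "float16"))
# embedding_0.tmp_0.cast_fp16 -- a second call returns the same variable name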
...
@@ -193,22 +241,29 @@ class AMPState(object):
...
@@ -193,22 +241,29 @@ class AMPState(object):
attrs
=
{
attrs
=
{
"in_dtype"
:
in_var
.
dtype
,
"in_dtype"
:
in_var
.
dtype
,
"out_dtype"
:
out_var
.
dtype
,
"out_dtype"
:
out_var
.
dtype
,
})
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
num_cast_ops
+=
1
num_cast_ops
+=
1
else
:
else
:
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var
.
name
)
in_var
.
name
)
consume_op_attr
.
set_input_dist_attr
(
consume_op_attr
.
set_input_dist_attr
(
cast_name
,
in_var_dist_attr
)
cast_name
,
in_var_dist_attr
)
_rename_arg
(
op
,
in_var
.
name
,
cast_name
)
_rename_arg
(
op
,
in_var
.
name
,
cast_name
)
else
:
else
:
if
op
.
has_attr
(
'in_dtype'
):
if
op
.
has_attr
(
'in_dtype'
):
op
.
_set_attr
(
'in_dtype'
,
dst_dtype
)
op
.
_set_attr
(
'in_dtype'
,
dst_dtype
)
self
.
_var_name_dict
[
op
.
desc
.
original_id
()]
=
var_name_dict
self
.
_var_name_dict
[
op
.
desc
.
original_id
()]
=
var_name_dict
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
dst_dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
(
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
dst_dtype
==
core
.
VarDesc
.
VarType
.
FP16
):
for
out_name
in
op
.
output_names
:
for
out_name
in
op
.
output_names
:
if
_keep_fp32_output
(
op
,
out_name
):
if
_keep_fp32_output
(
op
,
out_name
):
continue
continue
...
@@ -238,8 +293,9 @@ class AMPState(object):
...
@@ -238,8 +293,9 @@ class AMPState(object):
# NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
# NOTE: the map in `grad_var_to_var` may be changed when the var is casted,
# which will affect the dist_op to insert allreduce_sum op.
# which will affect the dist_op to insert allreduce_sum op.
op_dist_attr
=
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
op_dist_attr
=
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
if
is_backward_op
(
grad_op
)
and
(
is_forward_op
(
ops
[
idx
-
1
])
if
is_backward_op
(
grad_op
)
and
(
or
is_loss_op
(
ops
[
idx
-
1
])):
is_forward_op
(
ops
[
idx
-
1
])
or
is_loss_op
(
ops
[
idx
-
1
])
):
if
not
op_dist_attr
.
is_recompute
:
if
not
op_dist_attr
.
is_recompute
:
appended_grad_times
+=
1
appended_grad_times
+=
1
...
@@ -248,14 +304,22 @@ class AMPState(object):
...
@@ -248,14 +304,22 @@ class AMPState(object):
if
grad_op_orig_id
in
dist_op_context
.
grad_op_id_to_op_id
:
if
grad_op_orig_id
in
dist_op_context
.
grad_op_id_to_op_id
:
if
self
.
_is_fp16_op
(
grad_op_orig_id
)
==
False
:
# fp32
if
self
.
_is_fp16_op
(
grad_op_orig_id
)
==
False
:
# fp32
num_cast_ops
=
self
.
_insert_cast_op_backward
(
num_cast_ops
=
self
.
_insert_cast_op_backward
(
grad_op
,
idx
,
core
.
VarDesc
.
VarType
.
FP16
,
grad_op
,
core
.
VarDesc
.
VarType
.
FP32
,
dist_context
,
idx
,
appended_grad_times
)
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
FP32
,
dist_context
,
appended_grad_times
,
)
elif
self
.
_is_fp16_op
(
grad_op_orig_id
)
==
True
:
# fp16
elif
self
.
_is_fp16_op
(
grad_op_orig_id
)
==
True
:
# fp16
num_cast_ops
=
self
.
_insert_cast_op_backward
(
num_cast_ops
=
self
.
_insert_cast_op_backward
(
grad_op
,
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
grad_op
,
core
.
VarDesc
.
VarType
.
FP16
,
dist_context
,
idx
,
appended_grad_times
)
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
dist_context
,
appended_grad_times
,
)
elif
grad_op
.
type
==
"sum"
:
elif
grad_op
.
type
==
"sum"
:
in_var_name
=
grad_op
.
desc
.
input_arg_names
()[
0
]
in_var_name
=
grad_op
.
desc
.
input_arg_names
()[
0
]
src_dtype
=
self
.
_block
.
var
(
in_var_name
).
dtype
src_dtype
=
self
.
_block
.
var
(
in_var_name
).
dtype
...
@@ -270,15 +334,24 @@ class AMPState(object):
...
@@ -270,15 +334,24 @@ class AMPState(object):
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"'{}' op is not supported in the complete amp pass."
.
format
(
"'{}' op is not supported in the complete amp pass."
.
format
(
grad_op
.
type
))
grad_op
.
type
)
)
idx
+=
num_cast_ops
+
1
idx
+=
num_cast_ops
+
1
self
.
_block
.
_sync_with_cpp
()
self
.
_block
.
_sync_with_cpp
()
_update_backward_cast_ops
(
params_grads
,
dist_context
)
_update_backward_cast_ops
(
params_grads
,
dist_context
)
def
_insert_cast_op_backward
(
self
,
grad_op
,
idx
,
src_dtype
,
dst_dtype
,
def
_insert_cast_op_backward
(
dist_context
,
appended_grad_times
):
self
,
""" only for backward cast """
grad_op
,
idx
,
src_dtype
,
dst_dtype
,
dist_context
,
appended_grad_times
,
):
"""only for backward cast"""
def
_keep_fp32_input
(
op
,
in_name
):
def
_keep_fp32_input
(
op
,
in_name
):
op_type
=
op
.
type
op_type
=
op
.
type
...
@@ -299,7 +372,8 @@ class AMPState(object):
...
@@ -299,7 +372,8 @@ class AMPState(object):
for
in_name
in
grad_op
.
input_names
:
for
in_name
in
grad_op
.
input_names
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
grad_op
,
in_name
):
grad_op
,
in_name
):
for
in_var_name
in
grad_op
.
input
(
in_name
):
for
in_var_name
in
grad_op
.
input
(
in_name
):
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
assert
in_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
assert
in_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
...
@@ -309,24 +383,34 @@ class AMPState(object):
...
@@ -309,24 +383,34 @@ class AMPState(object):
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
in_var
=
self
.
_block
.
_find_var_recursive
(
in_var_name
)
if
in_var
.
dtype
==
src_dtype
:
if
in_var
.
dtype
==
src_dtype
:
consume_op_attr
=
dist_context
.
get_op_dist_attr_for_program
(
consume_op_attr
=
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
grad_op
)
if
in_var_name
in
self
.
_var_name_dict
[
fwd_op_id
]:
if
in_var_name
in
self
.
_var_name_dict
[
fwd_op_id
]:
# NOTE: if in_var of consume grad_op has been casted before,
# NOTE: if in_var of consume grad_op has been casted before,
# it should be renamed and reset dist_attr.
# it should be renamed and reset dist_attr.
cast_name
=
self
.
_var_name_dict
[
fwd_op_id
][
in_var_name
]
cast_name
=
self
.
_var_name_dict
[
fwd_op_id
][
in_var_name
]
grad_op
.
desc
.
_rename_input
(
in_var_name
,
cast_name
)
grad_op
.
desc
.
_rename_input
(
in_var_name
,
cast_name
)
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var_dist_attr
=
consume_op_attr
.
get_input_dist_attr
(
in_var_name
)
in_var_name
)
consume_op_attr
.
set_input_dist_attr
(
consume_op_attr
.
set_input_dist_attr
(
cast_name
,
in_var_dist_attr
)
cast_name
,
in_var_dist_attr
)
else
:
else
:
assert
in_var
.
dtype
==
dst_dtype
,
"op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}"
.
format
(
assert
(
grad_op
.
type
,
in_name
,
dst_dtype
,
in_var
.
dtype
,
in_var
.
dtype
==
dst_dtype
str
(
grad_op
))
),
"op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}"
.
format
(
grad_op
.
type
,
in_name
,
dst_dtype
,
in_var
.
dtype
,
str
(
grad_op
),
)
for
out_name
in
grad_op
.
output_names
:
for
out_name
in
grad_op
.
output_names
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_output
(
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_output
(
grad_op
,
out_name
):
grad_op
,
out_name
):
for
out_var_name
in
grad_op
.
output
(
out_name
):
for
out_var_name
in
grad_op
.
output
(
out_name
):
out_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name
)
out_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name
)
assert
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
assert
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
...
@@ -334,7 +418,7 @@ class AMPState(object):
...
@@ -334,7 +418,7 @@ class AMPState(object):
for
out_var_name
in
grad_op
.
output
(
out_name
):
for
out_var_name
in
grad_op
.
output
(
out_name
):
out_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name
)
out_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name
)
out_var_name_prefix
=
out_var_name
[:
out_var_name
.
find
(
"@"
)]
out_var_name_prefix
=
out_var_name
[:
out_var_name
.
find
(
"@"
)]
fwd_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name_prefix
)
fwd_var
=
self
.
_block
.
_find_var_recursive
(
out_var_name_prefix
)
# NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
# NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
if
out_var
.
dtype
!=
fwd_var
.
dtype
:
if
out_var
.
dtype
!=
fwd_var
.
dtype
:
...
@@ -345,34 +429,45 @@ class AMPState(object):
...
@@ -345,34 +429,45 @@ class AMPState(object):
# NOTE: if out_var of consume grad_op has been casted before,
# NOTE: if out_var of consume grad_op has been casted before,
# it should be renamed and reset dist_attr, then we insert cast op to
# it should be renamed and reset dist_attr, then we insert cast op to
# convert the cast_var to original dtype
# convert the cast_var to original dtype
consume_op_attr
=
dist_context
.
get_op_dist_attr_for_program
(
consume_op_attr
=
(
grad_op
)
dist_context
.
get_op_dist_attr_for_program
(
grad_op
)
)
fwd_cast_name
=
self
.
_var_name_dict
[
fwd_op_id
][
fwd_cast_name
=
self
.
_var_name_dict
[
fwd_op_id
][
out_var_name_prefix
]
out_var_name_prefix
]
suffix
=
""
suffix
=
""
if
"@RENAME"
in
out_var_name
:
if
"@RENAME"
in
out_var_name
:
suffix
=
out_var_name
[
out_var_name
.
find
(
"@RENAME"
):]
suffix
=
out_var_name
[
out_var_name
.
find
(
"@RENAME"
)
:
]
cast_name
=
fwd_cast_name
+
"@GRAD"
+
suffix
cast_name
=
fwd_cast_name
+
"@GRAD"
+
suffix
cast_var
=
self
.
_block
.
vars
.
get
(
cast_name
)
cast_var
=
self
.
_block
.
vars
.
get
(
cast_name
)
if
cast_var
is
None
or
cast_var
.
dtype
!=
dst_dtype
:
if
cast_var
is
None
or
cast_var
.
dtype
!=
dst_dtype
:
grad_op
.
desc
.
_rename_output
(
out_var_name
,
cast_name
)
grad_op
.
desc
.
_rename_output
(
out_var_name
,
cast_name
)
out_var_dist_attr
=
consume_op_attr
.
get_output_dist_attr
(
out_var_dist_attr
=
(
out_var_name
)
consume_op_attr
.
get_output_dist_attr
(
out_var_name
)
)
ref_mesh
=
out_var_dist_attr
.
process_mesh
ref_mesh
=
out_var_dist_attr
.
process_mesh
ref_mapping
=
out_var_dist_attr
.
dims_mapping
ref_mapping
=
out_var_dist_attr
.
dims_mapping
consume_op_attr
.
set_output_dist_attr
(
consume_op_attr
.
set_output_dist_attr
(
cast_name
,
out_var_dist_attr
)
cast_name
,
out_var_dist_attr
)
assert
ref_mapping
is
not
None
assert
ref_mapping
is
not
None
cast_var
=
self
.
_block
.
create_var
(
cast_var
=
self
.
_block
.
create_var
(
name
=
cast_name
,
name
=
cast_name
,
shape
=
out_var
.
shape
,
shape
=
out_var
.
shape
,
dtype
=
dst_dtype
,
dtype
=
dst_dtype
,
persistable
=
False
,
persistable
=
False
,
stop_gradient
=
out_var
.
stop_gradient
)
stop_gradient
=
out_var
.
stop_gradient
,
set_var_dist_attr
(
dist_context
,
cast_var
,
)
ref_mapping
,
ref_mesh
)
set_var_dist_attr
(
dist_context
,
cast_var
,
ref_mapping
,
ref_mesh
)
dist_op_context
.
grad_var_to_var
[
dist_op_context
.
grad_var_to_var
[
appended_grad_times
][
cast_name
]
=
fwd_cast_name
appended_grad_times
][
cast_name
]
=
fwd_cast_name
cast_op
=
self
.
_block
.
_insert_op
(
cast_op
=
self
.
_block
.
_insert_op
(
idx
+
1
,
idx
+
1
,
...
@@ -382,13 +477,15 @@ class AMPState(object):
...
@@ -382,13 +477,15 @@ class AMPState(object):
attrs
=
{
attrs
=
{
"in_dtype"
:
cast_var
.
dtype
,
"in_dtype"
:
cast_var
.
dtype
,
"out_dtype"
:
out_var
.
dtype
,
"out_dtype"
:
out_var
.
dtype
,
"op_role"
:
OpRole
.
Backward
"op_role"
:
OpRole
.
Backward
,
})
},
)
cast_op
.
_remove_attr
(
"op_role_var"
)
cast_op
.
_remove_attr
(
"op_role_var"
)
cast_op
.
_remove_attr
(
"op_namescope"
)
cast_op
.
_remove_attr
(
"op_namescope"
)
cast_op
.
_remove_attr
(
"with_quant_attr"
)
cast_op
.
_remove_attr
(
"with_quant_attr"
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
num_cast_ops
+=
1
num_cast_ops
+=
1
else
:
else
:
assert
out_var
.
dtype
==
dst_dtype
assert
out_var
.
dtype
==
dst_dtype
...
@@ -409,15 +506,18 @@ def _update_backward_cast_ops(params_grads, dist_context):
...
@@ -409,15 +506,18 @@ def _update_backward_cast_ops(params_grads, dist_context):
for
p
,
g
in
params_grads
:
for
p
,
g
in
params_grads
:
op
=
g
.
op
op
=
g
.
op
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
type
==
'cast'
:
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
type
==
'cast'
:
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Backward
)
and
op
.
has_attr
(
OpRole
.
Backward
)
and
op
.
has_attr
(
'op_role_var'
):
'op_role_var'
):
op
.
_remove_attr
(
"op_role_var"
)
op
.
_remove_attr
(
"op_role_var"
)
post_ops
=
find_true_post_op
(
main_block
.
ops
,
op
,
g
.
name
)
post_ops
=
find_true_post_op
(
main_block
.
ops
,
op
,
g
.
name
)
if
post_ops
:
if
post_ops
:
raise
ValueError
(
"The cast op {0}'s output should not be"
raise
ValueError
(
"used by a non-optimize op, however, it"
"The cast op {0}'s output should not be"
"is used by {1}"
.
format
(
op
,
post_ops
[
0
]))
"used by a non-optimize op, however, it"
"is used by {1}"
.
format
(
op
,
post_ops
[
0
])
)
if
op
==
main_block
.
ops
[
-
1
]:
if
op
==
main_block
.
ops
[
-
1
]:
continue
continue
...
@@ -425,23 +525,29 @@ def _update_backward_cast_ops(params_grads, dist_context):
...
@@ -425,23 +525,29 @@ def _update_backward_cast_ops(params_grads, dist_context):
# add new op in the python and cpp at the same time
# add new op in the python and cpp at the same time
new_op_desc
=
main_block
.
desc
.
append_op
()
new_op_desc
=
main_block
.
desc
.
append_op
()
new_op_desc
.
copy_from
(
op
.
desc
)
new_op_desc
.
copy_from
(
op
.
desc
)
new_op
=
paddle
.
fluid
.
framework
.
Operator
(
block
=
main_block
,
new_op
=
paddle
.
fluid
.
framework
.
Operator
(
desc
=
new_op_desc
,
block
=
main_block
,
type
=
None
,
desc
=
new_op_desc
,
inputs
=
None
,
type
=
None
,
outputs
=
None
,
inputs
=
None
,
attrs
=
None
)
outputs
=
None
,
attrs
=
None
,
)
main_block
.
ops
.
append
(
new_op
)
main_block
.
ops
.
append
(
new_op
)
# dist attr
# dist attr
param_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
p
)
param_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
p
)
output_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
output_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
main_block
.
var
(
op
.
output_arg_names
[
0
]))
main_block
.
var
(
op
.
output_arg_names
[
0
])
)
assert
param_dist_attr
is
not
None
assert
param_dist_attr
is
not
None
assert
output_dist_attr
is
not
None
assert
output_dist_attr
is
not
None
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
new_op
,
param_dist_attr
.
process_mesh
,
new_op
,
param_dist_attr
.
dims_mapping
,
dist_context
)
param_dist_attr
.
process_mesh
,
param_dist_attr
.
dims_mapping
,
dist_context
,
)
output_dist_attr
.
process_mesh
=
param_dist_attr
.
process_mesh
output_dist_attr
.
process_mesh
=
param_dist_attr
.
process_mesh
output_dist_attr
.
dims_mapping
=
param_dist_attr
.
dims_mapping
output_dist_attr
.
dims_mapping
=
param_dist_attr
.
dims_mapping
...
@@ -462,26 +568,34 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
...
@@ -462,26 +568,34 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
grads
=
[
g
for
_
,
g
in
params_grads
]
grads
=
[
g
for
_
,
g
in
params_grads
]
check_type
(
grads
,
'x'
,
(
tuple
,
list
),
'check_finite_and_unscale'
)
check_type
(
grads
,
'x'
,
(
tuple
,
list
),
'check_finite_and_unscale'
)
for
e
in
grads
:
for
e
in
grads
:
check_variable_and_dtype
(
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
check_variable_and_dtype
(
'check_finite_and_unscale'
)
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
'check_finite_and_unscale'
,
)
found_inf
=
main_block
.
create_var
(
found_inf
=
main_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
name
=
unique_name
.
generate_with_ignorable_key
(
[
'find_infinite_scale'
,
'tmp'
])),
"."
.
join
([
'find_infinite_scale'
,
'tmp'
])
),
shape
=
[
1
],
shape
=
[
1
],
dtype
=
'bool'
,
dtype
=
'bool'
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
persistable
=
False
,
stop_gradient
=
False
)
stop_gradient
=
False
,
)
set_var_dist_attr
(
dist_context
,
found_inf
,
[
-
1
],
world_process_group
.
ranks
)
set_var_dist_attr
(
dist_context
,
found_inf
,
[
-
1
],
world_process_group
.
ranks
)
inputs
=
{
'X'
:
grads
,
'Scale'
:
loss_scaling
}
inputs
=
{
'X'
:
grads
,
'Scale'
:
loss_scaling
}
outputs
=
{
'Out'
:
grads
,
'FoundInfinite'
:
found_inf
}
outputs
=
{
'Out'
:
grads
,
'FoundInfinite'
:
found_inf
}
attrs
=
{
'op_role'
:
OpRole
.
Optimize
}
attrs
=
{
'op_role'
:
OpRole
.
Optimize
}
new_op
=
main_block
.
append_op
(
type
=
'check_finite_and_unscale'
,
new_op
=
main_block
.
append_op
(
inputs
=
inputs
,
type
=
'check_finite_and_unscale'
,
outputs
=
outputs
,
inputs
=
inputs
,
attrs
=
attrs
)
outputs
=
outputs
,
attrs
=
attrs
,
)
new_op_dist_attr
=
OperatorDistributedAttribute
()
new_op_dist_attr
=
OperatorDistributedAttribute
()
new_op_dist_attr
.
process_mesh
=
world_process_group
.
ranks
new_op_dist_attr
.
process_mesh
=
world_process_group
.
ranks
...
@@ -491,17 +605,18 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
...
@@ -491,17 +605,18 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
for
g
in
grads
:
for
g
in
grads
:
g_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
g
)
g_dist_attr
=
dist_context
.
get_tensor_dist_attr_for_program
(
g
)
assert
g_dist_attr
is
not
None
assert
g_dist_attr
is
not
None
new_op_dist_attr
.
set_input_dims_mapping
(
g
.
name
,
new_op_dist_attr
.
set_input_dims_mapping
(
g_dist_attr
.
dims_mapping
)
g
.
name
,
g_dist_attr
.
dims_mapping
new_op_dist_attr
.
set_output_dims_mapping
(
g
.
name
,
)
g_dist_attr
.
dims_mapping
)
new_op_dist_attr
.
set_output_dims_mapping
(
g
.
name
,
g_dist_attr
.
dims_mapping
)
dist_context
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
dist_context
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
return
grads
,
found_inf
return
grads
,
found_inf
@
register_pass
(
"auto_parallel_amp"
)
@
register_pass
(
"auto_parallel_amp"
)
class
AMPPass
(
PassBase
):
class
AMPPass
(
PassBase
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
AMPPass
,
self
).
__init__
()
super
(
AMPPass
,
self
).
__init__
()
self
.
set_attr
(
"loss"
,
None
)
self
.
set_attr
(
"loss"
,
None
)
...
@@ -517,6 +632,7 @@ class AMPPass(PassBase):
...
@@ -517,6 +632,7 @@ class AMPPass(PassBase):
self
.
set_attr
(
"use_dynamic_loss_scaling"
,
False
)
self
.
set_attr
(
"use_dynamic_loss_scaling"
,
False
)
self
.
set_attr
(
"input_data"
,
[])
self
.
set_attr
(
"input_data"
,
[])
self
.
set_attr
(
"params_grads"
,
[])
self
.
set_attr
(
"params_grads"
,
[])
self
.
set_attr
(
"dtype"
,
""
)
# fp16/bf16
self
.
_loss
=
None
self
.
_loss
=
None
self
.
_loss_scaling
=
None
self
.
_loss_scaling
=
None
self
.
_num_good_steps
=
None
self
.
_num_good_steps
=
None
...
@@ -524,6 +640,8 @@ class AMPPass(PassBase):
...
@@ -524,6 +640,8 @@ class AMPPass(PassBase):
self
.
_loss
=
None
self
.
_loss
=
None
def
_check_self
(
self
):
def
_check_self
(
self
):
if
self
.
get_attr
(
"dtype"
)
not
in
[
"float16"
,
"bfloat16"
]:
return
False
if
self
.
get_attr
(
"init_loss_scaling"
)
<
0
:
if
self
.
get_attr
(
"init_loss_scaling"
)
<
0
:
return
False
return
False
if
self
.
get_attr
(
"incr_every_n_steps"
)
<
0
:
if
self
.
get_attr
(
"incr_every_n_steps"
)
<
0
:
...
@@ -548,11 +666,13 @@ class AMPPass(PassBase):
...
@@ -548,11 +666,13 @@ class AMPPass(PassBase):
def
_apply_single_impl
(
self
,
main_program
,
startup_program
,
context
):
def
_apply_single_impl
(
self
,
main_program
,
startup_program
,
context
):
self
.
dist_context
=
self
.
get_attr
(
"dist_context"
)
self
.
dist_context
=
self
.
get_attr
(
"dist_context"
)
params_grads
=
self
.
get_attr
(
"params_grads"
)
params_grads
=
self
.
get_attr
(
"params_grads"
)
self
.
amp_dtype
=
self
.
get_attr
(
"dtype"
)
amp_lists
=
AutoMixedPrecisionLists
(
amp_lists
=
AutoMixedPrecisionLists
(
set
(
self
.
get_attr
(
"custom_white_list"
)),
set
(
self
.
get_attr
(
"custom_white_list"
)),
set
(
self
.
get_attr
(
"custom_black_list"
)),
set
(
self
.
get_attr
(
"custom_black_list"
)),
set
(
self
.
get_attr
(
"custom_black_varnames"
)))
set
(
self
.
get_attr
(
"custom_black_varnames"
)),
)
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
amp_state
=
AMPState
(
main_program
.
global_block
())
amp_state
=
AMPState
(
main_program
.
global_block
())
...
@@ -566,10 +686,13 @@ class AMPPass(PassBase):
...
@@ -566,10 +686,13 @@ class AMPPass(PassBase):
self
.
_init_amp_var
()
self
.
_init_amp_var
()
self
.
_scale_loss
()
self
.
_scale_loss
()
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
if
(
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
:
self
.
get_attr
(
"use_dynamic_loss_scaling"
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
):
grads
,
found_inf
=
_check_and_update_gradient
(
grads
,
found_inf
=
_check_and_update_gradient
(
params_grads
,
self
.
_loss_scaling
,
self
.
dist_context
)
params_grads
,
self
.
_loss_scaling
,
self
.
dist_context
)
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
self
.
_update_loss_scaling
(
grads
,
found_inf
)
self
.
_update_loss_scaling
(
grads
,
found_inf
)
...
@@ -580,9 +703,14 @@ class AMPPass(PassBase):
...
@@ -580,9 +703,14 @@ class AMPPass(PassBase):
shape
=
[
1
],
shape
=
[
1
],
value
=
self
.
get_attr
(
"init_loss_scaling"
),
value
=
self
.
get_attr
(
"init_loss_scaling"
),
dtype
=
'float32'
,
dtype
=
'float32'
,
persistable
=
True
)
persistable
=
True
,
set_var_dist_attr
(
self
.
dist_context
,
self
.
_loss_scaling
,
[
-
1
],
)
world_process_group
.
ranks
)
set_var_dist_attr
(
self
.
dist_context
,
self
.
_loss_scaling
,
[
-
1
],
world_process_group
.
ranks
,
)
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
):
self
.
_num_good_steps
=
paddle
.
static
.
create_global_var
(
self
.
_num_good_steps
=
paddle
.
static
.
create_global_var
(
...
@@ -590,18 +718,28 @@ class AMPPass(PassBase):
...
@@ -590,18 +718,28 @@ class AMPPass(PassBase):
shape
=
[
1
],
shape
=
[
1
],
value
=
0
,
value
=
0
,
dtype
=
'int32'
,
dtype
=
'int32'
,
persistable
=
True
)
persistable
=
True
,
set_var_dist_attr
(
self
.
dist_context
,
self
.
_num_good_steps
,
[
-
1
],
)
world_process_group
.
ranks
)
set_var_dist_attr
(
self
.
dist_context
,
self
.
_num_good_steps
,
[
-
1
],
world_process_group
.
ranks
,
)
self
.
_num_bad_steps
=
paddle
.
static
.
create_global_var
(
self
.
_num_bad_steps
=
paddle
.
static
.
create_global_var
(
name
=
unique_name
.
generate
(
"num_bad_steps"
),
name
=
unique_name
.
generate
(
"num_bad_steps"
),
shape
=
[
1
],
shape
=
[
1
],
value
=
0
,
value
=
0
,
dtype
=
'int32'
,
dtype
=
'int32'
,
persistable
=
True
)
persistable
=
True
,
set_var_dist_attr
(
self
.
dist_context
,
self
.
_num_bad_steps
,
[
-
1
],
)
world_process_group
.
ranks
)
set_var_dist_attr
(
self
.
dist_context
,
self
.
_num_bad_steps
,
[
-
1
],
world_process_group
.
ranks
,
)
def
_scale_loss
(
self
):
def
_scale_loss
(
self
):
...
@@ -613,7 +751,8 @@ class AMPPass(PassBase):
...
@@ -613,7 +751,8 @@ class AMPPass(PassBase):
assert
loss
is
not
None
assert
loss
is
not
None
loss_op
=
loss
.
op
loss_op
=
loss
.
op
loss_op_dist_attr
=
self
.
dist_context
.
get_op_dist_attr_for_program
(
loss_op_dist_attr
=
self
.
dist_context
.
get_op_dist_attr_for_program
(
loss_op
)
loss_op
)
if
loss
.
dtype
!=
core
.
VarDesc
.
VarType
.
FP32
:
if
loss
.
dtype
!=
core
.
VarDesc
.
VarType
.
FP32
:
# cast loss here will change the effective loss tensor for the computation graph
# cast loss here will change the effective loss tensor for the computation graph
...
@@ -626,10 +765,12 @@ class AMPPass(PassBase):
...
@@ -626,10 +765,12 @@ class AMPPass(PassBase):
tmp_name
=
unique_name
.
generate
(
loss
.
name
+
".cast_fp32"
)
tmp_name
=
unique_name
.
generate
(
loss
.
name
+
".cast_fp32"
)
cast_loss
=
main_block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
cast_loss
=
main_block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
loss_dist_attr
=
self
.
dist_context
.
get_tensor_dist_attr_for_program
(
loss_dist_attr
=
self
.
dist_context
.
get_tensor_dist_attr_for_program
(
loss
)
loss
)
ref_mesh
=
loss_op_dist_attr
.
process_mesh
ref_mesh
=
loss_op_dist_attr
.
process_mesh
self
.
dist_context
.
set_tensor_dist_attr_for_program
(
self
.
dist_context
.
set_tensor_dist_attr_for_program
(
cast_loss
,
loss_dist_attr
)
cast_loss
,
loss_dist_attr
)
loss_op_idx
=
find_op_index
(
main_block
.
desc
,
loss_op
.
desc
)
loss_op_idx
=
find_op_index
(
main_block
.
desc
,
loss_op
.
desc
)
cast_op
=
main_block
.
_insert_op
(
cast_op
=
main_block
.
_insert_op
(
...
@@ -641,16 +782,21 @@ class AMPPass(PassBase):
...
@@ -641,16 +782,21 @@ class AMPPass(PassBase):
"in_dtype"
:
loss
.
dtype
,
"in_dtype"
:
loss
.
dtype
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
,
'op_role'
:
loss_op
.
all_attrs
()[
OP_ROLE_KEY
],
'op_role'
:
loss_op
.
all_attrs
()[
OP_ROLE_KEY
],
})
},
)
loss_op
.
_set_attr
(
OP_ROLE_KEY
,
loss_op
.
_set_attr
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
)
OP_ROLE_KEY
,
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
[
-
1
],
self
.
dist_context
)
cast_op
,
ref_mesh
,
[
-
1
],
self
.
dist_context
)
loss
=
loss
.
astype
(
'float32'
)
loss
=
loss
.
astype
(
'float32'
)
if
self
.
get_attr
(
"use_dynamic_loss_scaling"
if
self
.
amp_dtype
==
"float16"
and
(
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
:
self
.
get_attr
(
"use_dynamic_loss_scaling"
)
or
self
.
get_attr
(
"init_loss_scaling"
)
!=
1.0
):
loss_op_idx
=
find_op_index
(
main_block
.
desc
,
loss_op
.
desc
)
loss_op_idx
=
find_op_index
(
main_block
.
desc
,
loss_op
.
desc
)
...
@@ -660,63 +806,76 @@ class AMPPass(PassBase):
...
@@ -660,63 +806,76 @@ class AMPPass(PassBase):
name
=
unique_name
.
generate
(
"scaled_loss"
),
name
=
unique_name
.
generate
(
"scaled_loss"
),
shape
=
loss
.
shape
,
shape
=
loss
.
shape
,
dtype
=
loss
.
dtype
,
dtype
=
loss
.
dtype
,
persistable
=
loss
.
persistable
)
persistable
=
loss
.
persistable
,
set_var_dist_attr
(
self
.
dist_context
,
self
.
_scaled_loss
,
[
-
1
],
)
ref_mesh
)
set_var_dist_attr
(
self
.
dist_context
,
self
.
_scaled_loss
,
[
-
1
],
ref_mesh
)
elementwise_mul_op
=
main_block
.
_insert_op
(
elementwise_mul_op
=
main_block
.
_insert_op
(
loss_op_idx
+
1
,
loss_op_idx
+
1
,
type
=
'elementwise_mul'
,
type
=
'elementwise_mul'
,
inputs
=
{
inputs
=
{
'X'
:
[
loss
],
'Y'
:
[
self
.
_loss_scaling
]},
'X'
:
[
loss
],
'Y'
:
[
self
.
_loss_scaling
]
},
outputs
=
{
'Out'
:
[
self
.
_scaled_loss
]},
outputs
=
{
'Out'
:
[
self
.
_scaled_loss
]},
attrs
=
{
attrs
=
{
'op_role'
:
loss_op
.
all_attrs
()[
OP_ROLE_KEY
],
'op_role'
:
loss_op
.
all_attrs
()[
OP_ROLE_KEY
],
})
},
loss_op
.
_set_attr
(
OP_ROLE_KEY
,
)
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
)
loss_op
.
_set_attr
(
OP_ROLE_KEY
,
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
elementwise_mul_op
,
ref_mesh
,
[
-
1
],
self
.
dist_context
)
elementwise_mul_op
,
ref_mesh
,
[
-
1
],
self
.
dist_context
)
# backward
# backward
first_backward_op
=
main_block
.
ops
[
loss_op_idx
+
2
]
first_backward_op
=
main_block
.
ops
[
loss_op_idx
+
2
]
assert
first_backward_op
.
type
==
"fill_constant"
and
int
(
assert
(
first_backward_op
.
all_attrs
()[
OP_ROLE_KEY
])
==
257
first_backward_op
.
type
==
"fill_constant"
and
int
(
first_backward_op
.
all_attrs
()[
OP_ROLE_KEY
])
==
257
)
self
.
_scaled_loss_grad
=
main_block
.
create_var
(
self
.
_scaled_loss_grad
=
main_block
.
create_var
(
name
=
unique_name
.
generate
(
"scaled_loss"
)
+
"@GRAD"
,
name
=
unique_name
.
generate
(
"scaled_loss"
)
+
"@GRAD"
,
shape
=
loss
.
shape
,
shape
=
loss
.
shape
,
dtype
=
loss
.
dtype
,
dtype
=
loss
.
dtype
,
persistable
=
loss
.
persistable
)
persistable
=
loss
.
persistable
,
set_var_dist_attr
(
self
.
dist_context
,
self
.
_scaled_loss_grad
,
[
-
1
],
)
ref_mesh
)
set_var_dist_attr
(
self
.
dist_context
,
self
.
_scaled_loss_grad
,
[
-
1
],
ref_mesh
)
pre_grad_name
=
first_backward_op
.
output_arg_names
[
0
]
pre_grad_name
=
first_backward_op
.
output_arg_names
[
0
]
first_backward_op
.
_rename_output
(
pre_grad_name
,
first_backward_op
.
_rename_output
(
self
.
_scaled_loss_grad
.
name
)
pre_grad_name
,
self
.
_scaled_loss_grad
.
name
)
# FIXME(JZ-LIANG) a trick to insert backward op
# FIXME(JZ-LIANG) a trick to insert backward op
main_block
.
_sync_with_cpp
()
main_block
.
_sync_with_cpp
()
elementwise_mul_grad_op_desc
=
main_block
.
desc
.
_insert_op
(
elementwise_mul_grad_op_desc
=
main_block
.
desc
.
_insert_op
(
loss_op_idx
+
3
)
loss_op_idx
+
3
)
elementwise_mul_grad_op_desc
.
set_type
(
"elementwise_mul_grad"
)
elementwise_mul_grad_op_desc
.
set_type
(
"elementwise_mul_grad"
)
elementwise_mul_grad_op_desc
.
set_input
(
elementwise_mul_grad_op_desc
.
set_input
(
'Out@GRAD'
,
[
self
.
_scaled_loss_grad
.
name
])
'Out@GRAD'
,
[
self
.
_scaled_loss_grad
.
name
]
)
elementwise_mul_grad_op_desc
.
set_input
(
'X'
,
[
loss
.
name
])
elementwise_mul_grad_op_desc
.
set_input
(
'X'
,
[
loss
.
name
])
elementwise_mul_grad_op_desc
.
set_input
(
'Y'
,
elementwise_mul_grad_op_desc
.
set_input
(
[
self
.
_loss_scaling
.
name
])
'Y'
,
[
self
.
_loss_scaling
.
name
]
)
elementwise_mul_grad_op_desc
.
set_output
(
'X@GRAD'
,
[
pre_grad_name
])
elementwise_mul_grad_op_desc
.
set_output
(
'X@GRAD'
,
[
pre_grad_name
])
elementwise_mul_grad_op_desc
.
set_output
(
'Y@GRAD'
,
[])
elementwise_mul_grad_op_desc
.
set_output
(
'Y@GRAD'
,
[])
elementwise_mul_grad_op_desc
.
_set_attr
(
elementwise_mul_grad_op_desc
.
_set_attr
(
OP_ROLE_KEY
,
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
)
OP_ROLE_KEY
,
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
)
elementwise_mul_grad_op_desc
.
_set_attr
(
'axis'
,
-
1
)
elementwise_mul_grad_op_desc
.
_set_attr
(
'axis'
,
-
1
)
elementwise_mul_grad_op
=
paddle
.
fluid
.
framework
.
Operator
(
elementwise_mul_grad_op
=
paddle
.
fluid
.
framework
.
Operator
(
main_block
,
elementwise_mul_grad_op_desc
)
main_block
,
elementwise_mul_grad_op_desc
)
main_block
.
ops
.
insert
(
loss_op_idx
+
3
,
elementwise_mul_grad_op
)
main_block
.
ops
.
insert
(
loss_op_idx
+
3
,
elementwise_mul_grad_op
)
main_block
.
_sync_with_cpp
()
main_block
.
_sync_with_cpp
()
elementwise_mul_grad_op
=
main_block
.
ops
[
loss_op_idx
+
3
]
elementwise_mul_grad_op
=
main_block
.
ops
[
loss_op_idx
+
3
]
assert
elementwise_mul_grad_op
.
type
==
"elementwise_mul_grad"
assert
elementwise_mul_grad_op
.
type
==
"elementwise_mul_grad"
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
elementwise_mul_grad_op
,
ref_mesh
,
[
-
1
],
self
.
dist_context
)
elementwise_mul_grad_op
,
ref_mesh
,
[
-
1
],
self
.
dist_context
)
else
:
else
:
self
.
_scaled_loss
=
loss
self
.
_scaled_loss
=
loss
...
@@ -728,31 +887,39 @@ class AMPPass(PassBase):
...
@@ -728,31 +887,39 @@ class AMPPass(PassBase):
main_block
=
paddle
.
static
.
default_main_program
().
global_block
()
main_block
=
paddle
.
static
.
default_main_program
().
global_block
()
main_block
.
_sync_with_cpp
()
main_block
.
_sync_with_cpp
()
check_variable_and_dtype
(
self
.
_loss_scaling
,
"prev_loss_scaling"
,
check_variable_and_dtype
(
[
'float32'
,
'float64'
],
"update_loss_scaling"
)
self
.
_loss_scaling
,
"prev_loss_scaling"
,
[
'float32'
,
'float64'
],
"update_loss_scaling"
,
)
check_type
(
grads
,
'x'
,
(
tuple
,
list
),
'update_loss_scaling'
)
check_type
(
grads
,
'x'
,
(
tuple
,
list
),
'update_loss_scaling'
)
for
e
in
grads
:
for
e
in
grads
:
check_variable_and_dtype
(
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
check_variable_and_dtype
(
'update_loss_scaling'
)
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
'update_loss_scaling'
)
if
e
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
e
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
assert
self
.
_loss_scaling
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
,
\
assert
(
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
self
.
_loss_scaling
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
),
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else
:
else
:
assert
self
.
_loss_scaling
.
dtype
==
e
.
dtype
,
"The dtype of prev_loss_scaling should be equal to the dtype of x."
assert
(
self
.
_loss_scaling
.
dtype
==
e
.
dtype
),
"The dtype of prev_loss_scaling should be equal to the dtype of x."
inputs
=
{
inputs
=
{
'X'
:
grads
,
'X'
:
grads
,
'FoundInfinite'
:
found_inf
,
'FoundInfinite'
:
found_inf
,
'PrevLossScaling'
:
self
.
_loss_scaling
,
'PrevLossScaling'
:
self
.
_loss_scaling
,
'InGoodSteps'
:
self
.
_num_good_steps
,
'InGoodSteps'
:
self
.
_num_good_steps
,
'InBadSteps'
:
self
.
_num_bad_steps
'InBadSteps'
:
self
.
_num_bad_steps
,
}
}
outputs
=
{
outputs
=
{
'Out'
:
grads
,
'Out'
:
grads
,
'LossScaling'
:
self
.
_loss_scaling
,
'LossScaling'
:
self
.
_loss_scaling
,
'OutGoodSteps'
:
self
.
_num_good_steps
,
'OutGoodSteps'
:
self
.
_num_good_steps
,
'OutBadSteps'
:
self
.
_num_bad_steps
'OutBadSteps'
:
self
.
_num_bad_steps
,
}
}
attrs
=
{
attrs
=
{
...
@@ -761,13 +928,15 @@ class AMPPass(PassBase):
...
@@ -761,13 +928,15 @@ class AMPPass(PassBase):
'incr_ratio'
:
self
.
get_attr
(
"incr_ratio"
),
'incr_ratio'
:
self
.
get_attr
(
"incr_ratio"
),
'decr_ratio'
:
self
.
get_attr
(
"decr_ratio"
),
'decr_ratio'
:
self
.
get_attr
(
"decr_ratio"
),
'stop_update'
:
self
.
get_attr
(
"stop_update"
),
'stop_update'
:
self
.
get_attr
(
"stop_update"
),
'op_role'
:
OpRole
.
Optimize
'op_role'
:
OpRole
.
Optimize
,
}
}
new_op
=
main_block
.
append_op
(
type
=
'update_loss_scaling'
,
new_op
=
main_block
.
append_op
(
inputs
=
inputs
,
type
=
'update_loss_scaling'
,
outputs
=
outputs
,
inputs
=
inputs
,
attrs
=
attrs
)
outputs
=
outputs
,
attrs
=
attrs
,
)
new_op_dist_attr
=
OperatorDistributedAttribute
()
new_op_dist_attr
=
OperatorDistributedAttribute
()
new_op_dist_attr
.
process_mesh
=
world_process_group
.
ranks
new_op_dist_attr
.
process_mesh
=
world_process_group
.
ranks
...
@@ -777,10 +946,22 @@ class AMPPass(PassBase):
...
@@ -777,10 +946,22 @@ class AMPPass(PassBase):
for
g
in
grads
:
for
g
in
grads
:
g_dist_attr
=
self
.
dist_context
.
get_tensor_dist_attr_for_program
(
g
)
g_dist_attr
=
self
.
dist_context
.
get_tensor_dist_attr_for_program
(
g
)
assert
g_dist_attr
is
not
None
assert
g_dist_attr
is
not
None
new_op_dist_attr
.
set_input_dims_mapping
(
g
.
name
,
new_op_dist_attr
.
set_input_dims_mapping
(
g_dist_attr
.
dims_mapping
)
g
.
name
,
g_dist_attr
.
dims_mapping
new_op_dist_attr
.
set_output_dims_mapping
(
g
.
name
,
)
g_dist_attr
.
dims_mapping
)
new_op_dist_attr
.
set_output_dims_mapping
(
g
.
name
,
g_dist_attr
.
dims_mapping
)
self
.
dist_context
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
self
.
dist_context
.
set_op_dist_attr_for_program
(
new_op
,
new_op_dist_attr
)
main_block
.
_sync_with_cpp
()
main_block
.
_sync_with_cpp
()
def
get_loss
(
self
):
# the amp might change the effective loss variable for network and
# therefore would affect the subsequent passes that rely on the loss.
# return the effective loss after amp pass.
if
self
.
_loss
:
return
self
.
_loss
else
:
return
self
.
get_attr
(
"loss"
)
python/paddle/distributed/passes/auto_parallel_fp16.py
浏览文件 @
6f3c9643
...
@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
...
@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
from
paddle.distributed.auto_parallel.process_group
import
(
from
paddle.distributed.auto_parallel.process_group
import
(
get_world_process_group
,
get_world_process_group
,
)
)
from
paddle.fluid.contrib.mixed_precision.fp16_
util
s
import
(
from
paddle.fluid.contrib.mixed_precision.fp16_
list
s
import
(
AutoMixedPrecisionLists
,
AutoMixedPrecisionLists
,
)
)
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
(
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
(
_keep_layer_norm_scale_bias_to_fp32
,
_keep_layer_norm_scale_bias_to_fp32
,
_need_keep_fp32
,
_need_keep_fp32
,
_valid_types
,
_valid_types
,
_dtype_to_str
,
)
)
from
paddle.distributed.auto_parallel.dist_attribute
import
(
from
paddle.distributed.auto_parallel.dist_attribute
import
(
OperatorDistributedAttribute
,
OperatorDistributedAttribute
,
...
@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
...
@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
'while'
,
'while'
,
'cast'
,
'cast'
,
]
]
__target_dtype__
=
None
def
_dtype_to_str
(
dtype
):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too many codes use "cast_fp16" as key.
return
'fp16'
elif
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
return
'bf16'
else
:
return
'fp32'
def
set_op_dtype_to_fp16
(
op
):
def
set_op_dtype_to_fp16
(
op
):
...
@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
...
@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
op
.
has_attr
(
'in_dtype'
)
op
.
has_attr
(
'in_dtype'
)
and
op
.
attr
(
'in_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
attr
(
'in_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
):
):
op
.
_set_attr
(
'in_dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'in_dtype'
,
__target_dtype__
)
if
(
if
(
op
.
has_attr
(
'out_dtype'
)
op
.
has_attr
(
'out_dtype'
)
and
op
.
attr
(
'out_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
attr
(
'out_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
):
):
op
.
_set_attr
(
'out_dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'out_dtype'
,
__target_dtype__
)
if
op
.
has_attr
(
'dtype'
)
and
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
if
op
.
has_attr
(
'dtype'
)
and
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
op
.
_set_attr
(
'dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'dtype'
,
__target_dtype__
)
if
__target_dtype__
==
core
.
VarDesc
.
VarType
.
BF16
:
if
op
.
has_attr
(
'use_mkldnn'
):
op
.
_set_attr
(
'use_mkldnn'
,
True
)
if
op
.
has_attr
(
'mkldnn_data_type'
):
op
.
_set_attr
(
'mkldnn_data_type'
,
'bfloat16'
)
# adapot for backward op
# adapot for backward op
...
@@ -156,6 +178,7 @@ class FP16State(object):
...
@@ -156,6 +178,7 @@ class FP16State(object):
list
list
)
# {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
)
# {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
self
.
is_train
=
False
self
.
is_train
=
False
self
.
out_var_op_deps
=
{}
def
_is_fp16_op
(
self
,
op_id
):
def
_is_fp16_op
(
self
,
op_id
):
return
self
.
_op_fp16_dict
.
get
(
op_id
,
None
)
return
self
.
_op_fp16_dict
.
get
(
op_id
,
None
)
...
@@ -169,6 +192,14 @@ class FP16State(object):
...
@@ -169,6 +192,14 @@ class FP16State(object):
# assume all backward block are behind forward blocks
# assume all backward block are behind forward blocks
for
block
in
self
.
program
.
blocks
:
for
block
in
self
.
program
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
for
name
in
op
.
output_arg_names
:
if
name
not
in
self
.
out_var_op_deps
:
self
.
out_var_op_deps
[
name
]
=
[
op
.
desc
.
original_id
()]
else
:
self
.
out_var_op_deps
[
name
].
extend
(
[
op
.
desc
.
original_id
()]
)
self
.
_mark_op
(
op
)
self
.
_mark_op
(
op
)
# set forward tensor dtype
# set forward tensor dtype
...
@@ -192,6 +223,18 @@ class FP16State(object):
...
@@ -192,6 +223,18 @@ class FP16State(object):
if
op
.
type
==
"assign"
and
"array_"
in
op
.
input_arg_names
[
0
]:
if
op
.
type
==
"assign"
and
"array_"
in
op
.
input_arg_names
[
0
]:
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
False
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
False
return
return
# If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
if
op
.
type
==
"assign"
:
out_name
=
op
.
output_arg_names
[
0
]
if
len
(
self
.
out_var_op_deps
[
out_name
])
>
1
:
if
not
self
.
_op_fp16_dict
[
self
.
out_var_op_deps
[
out_name
][
0
]
]:
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
False
else
:
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
True
return
if
_need_keep_fp32
(
if
_need_keep_fp32
(
op
,
self
.
amp_list
.
unsupported_list
,
self
.
use_fp16_guard
op
,
self
.
amp_list
.
unsupported_list
,
self
.
use_fp16_guard
):
):
...
@@ -228,7 +271,7 @@ class FP16State(object):
...
@@ -228,7 +271,7 @@ class FP16State(object):
return
return
if
var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
if
var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP16
)
var
.
desc
.
set_dtype
(
__target_dtype__
)
def
resolute_tensor_dtype
(
self
,
block
):
def
resolute_tensor_dtype
(
self
,
block
):
...
@@ -260,7 +303,7 @@ class FP16State(object):
...
@@ -260,7 +303,7 @@ class FP16State(object):
out_var
=
block
.
vars
.
get
(
out_var_name
)
out_var
=
block
.
vars
.
get
(
out_var_name
)
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
out_var
.
dtype
==
__target_dtype__
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
elif
is_backward_op
(
op
):
elif
is_backward_op
(
op
):
if
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
True
:
if
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
True
:
...
@@ -276,7 +319,7 @@ class FP16State(object):
...
@@ -276,7 +319,7 @@ class FP16State(object):
out_var
=
block
.
vars
.
get
(
out_var_name
)
out_var
=
block
.
vars
.
get
(
out_var_name
)
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
out_var
.
dtype
==
__target_dtype__
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
def
cast_block
(
self
,
block
):
def
cast_block
(
self
,
block
):
...
@@ -295,7 +338,7 @@ class FP16State(object):
...
@@ -295,7 +338,7 @@ class FP16State(object):
op
,
op
,
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
self
.
dist_context
,
self
.
dist_context
,
)
)
...
@@ -305,7 +348,7 @@ class FP16State(object):
...
@@ -305,7 +348,7 @@ class FP16State(object):
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
self
.
dist_context
,
self
.
dist_context
,
)
)
elif
is_backward_op
(
op
):
elif
is_backward_op
(
op
):
...
@@ -315,7 +358,7 @@ class FP16State(object):
...
@@ -315,7 +358,7 @@ class FP16State(object):
op
,
op
,
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
self
.
dist_context
,
self
.
dist_context
,
)
)
...
@@ -325,7 +368,7 @@ class FP16State(object):
...
@@ -325,7 +368,7 @@ class FP16State(object):
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
self
.
dist_context
,
self
.
dist_context
,
)
)
elif
op
.
type
==
"sum"
:
elif
op
.
type
==
"sum"
:
...
@@ -399,6 +442,9 @@ class FP16State(object):
...
@@ -399,6 +442,9 @@ class FP16State(object):
dist_context
,
cast_var
,
ref_mapping
,
ref_mesh
dist_context
,
cast_var
,
ref_mapping
,
ref_mesh
)
)
op_namescope
=
"/"
if
op
.
has_attr
(
'op_namescope'
):
op_namescope
=
op
.
attr
(
'op_namescope'
)
cast_op
=
block
.
_insert_op_without_sync
(
cast_op
=
block
.
_insert_op_without_sync
(
idx
,
idx
,
type
=
"cast"
,
type
=
"cast"
,
...
@@ -410,6 +456,9 @@ class FP16State(object):
...
@@ -410,6 +456,9 @@ class FP16State(object):
OP_ROLE_KEY
:
OpRole
.
Forward
,
OP_ROLE_KEY
:
OpRole
.
Forward
,
},
},
)
)
cast_op
.
_set_attr
(
'op_namescope'
,
op_namescope
)
# for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
)
...
@@ -455,63 +504,79 @@ class FP16State(object):
...
@@ -455,63 +504,79 @@ class FP16State(object):
)
in
self
.
forward_input_cast_ops
[
forward_op_id
]:
)
in
self
.
forward_input_cast_ops
[
forward_op_id
]:
# rename input
# rename input
assert
src_name
in
op
.
input
(
# some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
slot_name
if
op
.
type
!=
"scale"
and
slot_name
in
op
.
input_names
:
),
"var: {} not in op's {}. {}"
.
format
(
src_name
,
slot_name
,
str
(
op
))
src_var_dist_attr
=
grad_op_attr
.
get_input_dist_attr
(
src_name
)
assert
src_var_dist_attr
is
not
None
op
.
_rename_input
(
src_name
,
cast_name
)
grad_op_attr
.
set_input_dist_attr
(
cast_name
,
src_var_dist_attr
)
assert
src_name
in
op
.
input
(
slot_name
),
"var: {} not in op's {}. {}"
.
format
(
src_name
,
slot_name
,
str
(
op
)
)
src_var_dist_attr
=
grad_op_attr
.
get_input_dist_attr
(
src_name
)
assert
src_var_dist_attr
is
not
None
op
.
_rename_input
(
src_name
,
cast_name
)
grad_op_attr
.
set_input_dist_attr
(
cast_name
,
src_var_dist_attr
)
# NOTE Special for scale op, scale op's grad op is scale,
# so slot name map rule could not apply to grad scale op
# cast_name: mean_0.tmp_0.cast_bf16, src_name: mean_0.tmp_0, dst_dtype: paddle.bfloat16, src_dtype: paddle.float32, slot_name: X.
if
op
.
type
==
"scale"
:
grad_slot_name
=
"X"
# create cast grad
# create cast grad
grad_slot_name
=
slot_name
+
"@GRAD"
else
:
assert
grad_slot_name
in
op
.
output_names
grad_slot_name
=
slot_name
+
"@GRAD"
if
len
(
op
.
output
(
grad_slot_name
))
==
0
:
var
=
block
.
var
(
src_name
)
assert
var
.
stop_gradient
is
True
continue
assert
len
(
op
.
output
(
grad_slot_name
))
==
1
grad_name
=
op
.
output
(
grad_slot_name
)[
0
]
grad
=
block
.
var
(
grad_name
)
grad_dist_attr
=
grad_op_attr
.
get_output_dist_attr
(
grad_name
)
assert
grad_dist_attr
is
not
None
,
"{}"
.
format
(
grad_name
)
ref_mesh
=
grad_dist_attr
.
process_mesh
ref_mapping
=
grad_dist_attr
.
dims_mapping
cast_grad
=
block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
""
.
join
([
cast_name
,
'@GRAD'
])
),
dtype
=
dst_dtype
,
shape
=
grad
.
shape
,
type
=
grad
.
type
,
persistable
=
grad
.
persistable
,
stop_gradient
=
grad
.
stop_gradient
,
)
dist_context
.
set_tensor_dist_attr_for_program
(
cast_grad
,
grad_dist_attr
)
op
.
_rename_output
(
grad_name
,
cast_grad
.
name
)
grad_op_attr
.
set_output_dist_attr
(
cast_grad
.
name
,
grad_dist_attr
)
# add cast
cast_op
=
block
.
_insert_op_without_sync
(
idx
+
1
,
type
=
"cast"
,
inputs
=
{
"X"
:
[
cast_grad
.
name
]},
outputs
=
{
"Out"
:
[
grad
.
name
]},
attrs
=
{
"in_dtype"
:
dst_dtype
,
"out_dtype"
:
src_dtype
,
OP_ROLE_KEY
:
OpRole
.
Backward
,
},
)
grad
.
desc
.
set_dtype
(
src_dtype
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
if
grad_slot_name
in
op
.
output_names
:
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
# some forward input maybe stop_gradient=True, e.g. input_mask
)
if
len
(
op
.
output
(
grad_slot_name
))
==
0
:
num_cast_ops
+=
1
continue
assert
(
len
(
op
.
output
(
grad_slot_name
))
==
1
),
"[{}], Current Op: {}"
.
format
(
grad_slot_name
,
str
(
op
))
grad_name
=
op
.
output
(
grad_slot_name
)[
0
]
grad
=
block
.
var
(
grad_name
)
grad_dist_attr
=
grad_op_attr
.
get_output_dist_attr
(
grad_name
)
assert
grad_dist_attr
is
not
None
,
"{}"
.
format
(
grad_name
)
ref_mesh
=
grad_dist_attr
.
process_mesh
ref_mapping
=
grad_dist_attr
.
dims_mapping
cast_grad
=
block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
""
.
join
([
cast_name
,
'@GRAD'
])
),
dtype
=
dst_dtype
,
shape
=
grad
.
shape
,
type
=
grad
.
type
,
persistable
=
grad
.
persistable
,
stop_gradient
=
grad
.
stop_gradient
,
)
dist_context
.
set_tensor_dist_attr_for_program
(
cast_grad
,
grad_dist_attr
)
op
.
_rename_output
(
grad_name
,
cast_grad
.
name
)
grad_op_attr
.
set_output_dist_attr
(
cast_grad
.
name
,
grad_dist_attr
)
# add cast
cast_op
=
block
.
_insert_op_without_sync
(
idx
+
1
,
type
=
"cast"
,
inputs
=
{
"X"
:
[
cast_grad
.
name
]},
outputs
=
{
"Out"
:
[
grad
.
name
]},
attrs
=
{
"in_dtype"
:
dst_dtype
,
"out_dtype"
:
src_dtype
,
OP_ROLE_KEY
:
OpRole
.
Backward
,
},
)
grad
.
desc
.
set_dtype
(
src_dtype
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
num_cast_ops
+=
1
return
num_cast_ops
return
num_cast_ops
...
@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
...
@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
def
_split_grads
(
params_grads
):
def
_split_grads
(
params_grads
):
grads
=
[
g
for
_
,
g
in
params_grads
]
grads
=
[
g
for
_
,
g
in
params_grads
]
fp32_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
]
fp32_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
]
fp16_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
]
fp16_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
__target_dtype__
]
assert
len
(
fp32_grads
)
+
len
(
fp16_grads
)
==
len
(
assert
len
(
fp32_grads
)
+
len
(
fp16_grads
)
==
len
(
grads
grads
),
"Data types of all grads must be either fp16 or fp32."
),
"Data types of all grads must be either fp16 or fp32."
...
@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
...
@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
# TODO to support CUDAPinned/NPU/XPU Places
# TODO to support CUDAPinned/NPU/XPU Places
if
direction
==
"D2H"
:
if
direction
==
"D2H"
:
dst_place_type
=
0
dst_place_type
=
0
elif
direction
==
"D2H"
:
dst_place_type
=
1
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"direction [{}] is not supported yet."
.
format
(
direction
)
f
"direction [
{
direction
}
] is not supported yet."
)
)
attrs
=
{
'dst_place_type'
:
dst_place_type
}
attrs
=
{
'dst_place_type'
:
dst_place_type
}
new_op
=
block
.
_insert_op_without_sync
(
new_op
=
block
.
_insert_op_without_sync
(
index
=
idx
,
index
=
idx
,
type
=
'memcpy'
,
type
=
'memcpy
_d2h
'
,
inputs
=
{
'X'
:
[
src_var
]},
inputs
=
{
'X'
:
[
src_var
]},
outputs
=
{
'Out'
:
[
output_var
]},
outputs
=
{
'Out'
:
[
output_var
]},
attrs
=
attrs
,
attrs
=
attrs
,
...
@@ -678,17 +741,17 @@ def cast_startup_program():
...
@@ -678,17 +741,17 @@ def cast_startup_program():
for
op
in
startup_program
.
global_block
().
ops
:
for
op
in
startup_program
.
global_block
().
ops
:
if
is_initialization_op
(
op
):
if
is_initialization_op
(
op
):
output_name
=
op
.
output_arg_names
[
0
]
output_name
=
op
.
output_arg_names
[
0
]
if
(
if
param_to_dtype
.
get
(
output_name
,
None
)
==
__target_dtype__
:
param_to_dtype
.
get
(
output_name
,
None
)
==
core
.
VarDesc
.
VarType
.
FP16
):
assert
op
.
has_attr
(
assert
op
.
has_attr
(
'dtype'
'dtype'
),
"initialization op is supported to has dtype attribute but got {}."
.
format
(
),
"initialization op is supported to has dtype attribute but got {}."
.
format
(
str
(
op
)
str
(
op
)
)
)
out_var
=
startup_program
.
global_block
().
var
(
output_name
)
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
out_var
.
desc
.
set_dtype
(
__target_dtype__
)
if
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
if
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
op
.
_set_attr
(
'dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'dtype'
,
__target_dtype__
)
@
register_pass
(
"auto_parallel_fp16"
)
@
register_pass
(
"auto_parallel_fp16"
)
...
@@ -701,14 +764,44 @@ class FP16Pass(AMPPass):
    # in distributed scenario, all ranks should have the same modification.
    def _apply_single_impl(self, main_program, startup_program, context):
        self.dist_context = self.get_attr("dist_context")
+        self.target_dtype = self.get_attr("dtype")
        params_grads = self.get_attr("params_grads")

+        self.use_optimizer_fp16 = self.get_attr("use_optimizer_fp16", None)
+        if self.use_optimizer_fp16 is None:
+            self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"
+
+        # switch environment for fp16 / bf16.
+        if self.target_dtype == "float16":
+            __target_dtype = core.VarDesc.VarType.FP16
+        elif self.target_dtype == "bfloat16":
+            __target_dtype = core.VarDesc.VarType.BF16
+        else:
+            raise NotImplementedError(
+                "target dtype [{}] is not supported yet for amp o2.".format(
+                    self.target_dtype
+                )
+            )
+        global __target_dtype__
+        __target_dtype__ = __target_dtype
+
        amp_list = AutoMixedPrecisionLists(
            set(self.get_attr("custom_white_list")),
            set(self.get_attr("custom_black_list")),
-            None,
+            dtype=self.target_dtype,
        )
+        amp_list.unsupported_list -= {
+            "conditional_block_grad",
+            "conditional_block",
+            "conditional_block_infer",
+            "select_input",
+            "while",
+            "while_grad",
+            "cast",
+            "tensor_array_to_tensor",
+            "lod_array_length",
+            "write_to_array",
+        }
        # NOTE: don't change the input data dtype, since it is controlled by the dataloader
        # and is out of the control of the FP16 pass.
        input_data_var_names = [var.name for var in self.get_attr("input_data")]
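For orientation, a minimal usage sketch of the strategy surface this branch of the pass reacts to. amp.dtype and amp.level are the fields added in this commit (see constants.py and the new tests); the black-list entry is taken from amp_o2_pass.py further below. Selecting "bfloat16" with level "o2" is what makes __target_dtype__ resolve to core.VarDesc.VarType.BF16 here.

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
amp.dtype = "bfloat16"   # new field: "float16" or "bfloat16"
amp.level = "o2"         # new field: "o1", "o2" or "o3"
amp.custom_black_list = ['c_softmax_with_cross_entropy']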
...
@@ -726,93 +819,96 @@ class FP16Pass(AMPPass):
        cast_startup_program()

        if is_train:
+           if self.target_dtype == "fp16":
                with paddle.static.program_guard(main_program, startup_program):
                    # TODO (JZ-LIANG) support cast forward program only when inference
                    self._init_amp_var()
                    self._scale_loss()

                    grads, fp32_grads, fp16_grads = _split_grads(params_grads)

                    if (
                        self.get_attr("use_dynamic_loss_scaling")
                        or self.get_attr("init_loss_scaling") != 1.0
                    ):
                        found_infs = []
                        if fp32_grads:
                            with main_program._optimized_guard([]):
                                _, found_inf_fp32 = _check_and_update_gradient(
                                    fp32_grads,
                                    self._loss_scaling,
                                    "@fp32",
                                    self.dist_context,
                                )
                            found_infs.append(found_inf_fp32)
                        if fp16_grads:
                            with main_program._optimized_guard([]):
                                _, found_inf_fp16 = _check_and_update_gradient(
                                    fp16_grads,
                                    self._loss_scaling,
                                    "@fp16",
                                    self.dist_context,
                                )
                            found_infs.append(found_inf_fp16)
                        with main_program._optimized_guard([]):
                            block = main_program.global_block()
                            all_infs = paddle.fluid.layers.concat(found_infs)
                            set_var_dist_attr(
                                self.dist_context,
                                all_infs,
                                [-1],
                                world_process_group.ranks,
                            )
                            new_op = block.ops[-1]
                            assert new_op.type == "concat"
                            _set_op_dist_attr_with_ranks(
                                new_op,
                                world_process_group.ranks,
                                block,
                                self.dist_context,
                            )

                            found_inf = paddle.fluid.layers.reduce_any(all_infs)
                            set_var_dist_attr(
                                self.dist_context,
                                found_inf,
                                [-1],
                                world_process_group.ranks,
                            )
                            new_op = block.ops[-1]
                            assert new_op.type == "reduce_any"
                            _set_op_dist_attr_with_ranks(
                                new_op,
                                world_process_group.ranks,
                                block,
                                self.dist_context,
                            )

                    if self.get_attr("use_dynamic_loss_scaling"):
                        with main_program._optimized_guard([]):
                            if fp32_grads:
                                self._update_loss_scaling(fp32_grads, found_inf)
                            if fp16_grads:
                                self._update_loss_scaling(fp16_grads, found_inf)

            # modify optimizer
            base_opt = self.get_attr("base_opt")
            base_opt._multi_precision = True
-           if self.get_attr("use_optimizer_fp16"):
+           if self.use_optimizer_fp16:
                base_opt._multi_precision = False

+           if self.target_dtype == "fp16":
                if isinstance(
                    base_opt,
                    (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
                ):
                    with main_program._optimized_guard([]):
                        # found_inf = paddle.tensor.creation._memcpy(
                        #     found_inf, paddle.CPUPlace())
                        insert_idx = _get_memcopy_idx(block, found_inf)
                        found_inf = _insert_memcopy(
                            block, insert_idx, found_inf, self.dist_context
                        )
                        base_opt._set_auxiliary_var('found_inf', found_inf.name)
                elif hasattr(base_opt, "_set_auxiliary_var"):
                    base_opt._set_auxiliary_var('found_inf', found_inf.name)
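The overflow handling above reduces to one flag per gradient dtype group, which is then concatenated and collapsed with reduce_any. A framework-free sketch of that pattern, using numpy stand-ins purely for illustration (this is not the pass itself):

import numpy as np

def found_inf(fp32_grads, fp16_grads):
    flags = []
    # one flag per dtype group, as _check_and_update_gradient produces above
    for group in (fp32_grads, fp16_grads):
        if group:
            flags.append(any(not np.all(np.isfinite(g)) for g in group))
    # combined flag, mirroring the concat + reduce_any sequence
    return any(flags)

print(found_inf([np.ones(3)], [np.array([np.inf, 1.0])]))  # True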
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt — View file @ 6f3c9643
...
@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
   set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                    TIMEOUT 50)
+  py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass ENVS ${dist_ENVS})
+  set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
+                                                   TIMEOUT 50)
   py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
                   ${dist_ENVS})
   set_tests_properties(test_iterable_dataset
...
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py (new file, mode 0 → 100644) — View file @ 6f3c9643

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import re
import unittest

import numpy as np
from get_gpt_model import FakeDataset, generate_model

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.framework import core

paddle.enable_static()


def get_cuda_version():
    result = os.popen("nvcc --version").read()
    regex = r'release (\S+),'
    match = re.search(regex, result)
    if match:
        num = str(match.group(1))
        integer, decimal = num.split('.')
        return int(integer) * 1000 + int(float(decimal) * 10)
    else:
        return -1


def apply_pass(use_amp=False, amp_dtype="bfloat16"):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_amp:
        amp = strategy.amp
        amp.enable = True
        amp.dtype = amp_dtype
        amp.level = "o2"
        amp.custom_black_list = [
            'c_softmax_with_cross_entropy',
            'elementwise_div',
            'reduce_sum',
        ]
    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class TestShardingStage2WithNewEXE(unittest.TestCase):
    def setUp(self):
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2022)
        np.random.seed(2022)
        random.seed(2022)
        place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
        reset_prog()

        strategy = apply_pass(use_amp, amp_dtype)
        # clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        clip = None
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("mp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_bf16(self, program):
        num_bf16 = 0
        num_fp16 = 0
        num_fp32 = 0

        for p in program.all_parameters():
            if p.dtype == core.VarDesc.VarType.FP32:
                num_fp32 += 1
            if p.dtype == core.VarDesc.VarType.FP16:
                num_fp16 += 1
            if p.dtype == core.VarDesc.VarType.BF16:
                num_bf16 += 1

        self.assertEqual(num_bf16, 25)
        self.assertEqual(num_fp16, 0)
        self.assertEqual(num_fp32, 11)

    def test_param_grad_fuse_overlap(self):
        # std
        mp_engine = self.get_engine(use_amp=False)
        mp_history = mp_engine.fit(
            self.dataset,
            3,
            epochs=1,
            steps_per_epoch=self.batch_num,
            log_freq=1,
            batch_size=self.batch_size,
        )
        loss0 = mp_history.history['loss'][0]

        # bf16
        mp_bf16_engine = self.get_engine(use_amp=True)
        if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
            return

        mp_bf16_history = mp_bf16_engine.fit(
            self.dataset,
            3,
            epochs=1,
            steps_per_epoch=self.batch_num,
            log_freq=1,
            batch_size=self.batch_size,
        )
        loss1 = mp_bf16_history.history['loss'][0]
        np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)

        self.check_bf16(mp_bf16_engine.main_program)


if __name__ == "__main__":
    unittest.main()
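A quick, self-contained check of the version parsing in get_cuda_version above, using an assumed sample of nvcc output; CUDA 11.7 maps to 11070, so the bf16 branch of the test only runs when the result is at least 11000 (CUDA 11.0 or newer).

import re

sample = "Cuda compilation tools, release 11.7, V11.7.99"  # assumed nvcc output
integer, decimal = re.search(r'release (\S+),', sample).group(1).split('.')
print(int(integer) * 1000 + int(float(decimal) * 10))  # 11070 -> passes the >= 11000 gate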
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py — View file @ 6f3c9643
...
@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
        ]
        amp.init_loss_scaling = 32768
        amp.use_fp16_guard = False
-        amp.use_pure_fp16 = level in ["o2", "o3"]
+        amp.level = level
        amp.use_optimizer_fp16 = level == "o3"
        print("amp level: ", level)
    return strategy
...
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py (new file, mode 0 → 100644) — View file @ 6f3c9643

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys
import tempfile
import unittest


class TestAMPO2(unittest.TestCase):
    def test_bf16(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py — View file @ 6f3c9643
...
@@ -13,13 +13,13 @@
 # limitations under the License.

 import os

 # import yaml
 import unittest

 from paddle.distributed.fleet import auto


 class TestStrategy(unittest.TestCase):
     def test_default_config(self):
         strategy = auto.Strategy()
...
@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
         amp = strategy.amp
         self.assertEqual(amp.enable, False)
+        self.assertAlmostEqual(amp.dtype, "float16")
+        self.assertAlmostEqual(amp.level, "o1")
         self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
         self.assertEqual(amp.incr_every_n_steps, 1000)
         self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
...
@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.custom_black_list, [])
         self.assertEqual(amp.custom_white_list, [])
         self.assertEqual(amp.custom_black_varnames, [])
-        self.assertEqual(amp.use_pure_fp16, False)
-        self.assertEqual(amp.use_fp16_guard, True)
+        self.assertEqual(amp.use_fp16_guard, False)
         self.assertEqual(amp.use_optimizer_fp16, False)

         sharding = strategy.sharding
...
@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
         amp.custom_white_list = ["x"]
         amp.custom_black_list = ["y"]
         amp.custom_black_varnames = ["z"]
-        amp.use_pure_fp16 = True
         amp.use_fp16_guard = False
         amp.use_optimizer_fp16 = True
         self.assertEqual(amp.enable, True)
...
@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.custom_white_list, ["x"])
         self.assertEqual(amp.custom_black_list, ["y"])
         self.assertEqual(amp.custom_black_varnames, ["z"])
-        self.assertEqual(amp.use_pure_fp16, True)
         self.assertEqual(amp.use_fp16_guard, False)
         self.assertEqual(amp.use_optimizer_fp16, True)
...
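A minimal sketch of the defaults the updated test now asserts, grounded in the new amp.dtype and amp.level fields added by this commit (the use_pure_fp16 assertions are dropped in favour of amp.level):

from paddle.distributed.fleet import auto

amp = auto.Strategy().amp
assert not amp.enable
assert amp.dtype == "float16"   # new field
assert amp.level == "o1"        # new field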