Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ec1e0d5a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ec1e0d5a
编写于
7月 29, 2022
作者:
C
caozhou
提交者:
GitHub
7月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add dist op costs (#44701)
上级
fecbc958
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
902 addition
and
22 deletion
+902
-22
python/paddle/distributed/auto_parallel/cost/__init__.py
python/paddle/distributed/auto_parallel/cost/__init__.py
+25
-5
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+19
-0
python/paddle/distributed/auto_parallel/dist_context.py
python/paddle/distributed/auto_parallel/dist_context.py
+4
-2
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
...dle/distributed/auto_parallel/operators/dist_embedding.py
+92
-0
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+289
-11
python/paddle/distributed/auto_parallel/operators/dist_reshape.py
...addle/distributed/auto_parallel/operators/dist_reshape.py
+240
-1
python/paddle/distributed/auto_parallel/operators/dist_softmax.py
...addle/distributed/auto_parallel/operators/dist_softmax.py
+62
-0
python/paddle/distributed/auto_parallel/operators/dist_transpose.py
...dle/distributed/auto_parallel/operators/dist_transpose.py
+62
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
.../fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
+109
-3
未找到文件。
python/paddle/distributed/auto_parallel/cost/__init__.py
浏览文件 @
ec1e0d5a
...
...
@@ -12,20 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License
from
.base_cost
import
_g_op_cost_factory
from
.base_cost
import
Cost
from
.base_cost
import
CommContext
from
.base_cost
import
_g_op_cost_factory
from
.base_cost
import
build_comm_desc
from
.base_cost
import
build_comp_desc_from_op
from
.base_cost
import
build_comp_desc_from_dist_op
from
.base_cost
import
build_dp_costs
from
.base_cost
import
build_comp_desc_str_for_predict
from
.base_cost
import
build_comp_desc_from_dist_op
from
.base_cost
import
build_comm_desc_from_dist_op
from
.base_cost
import
build_comm_costs_from_descs
from
.base_cost
import
build_comp_costs_from_descs
from
.tensor_cost
import
TensorCost
from
.estimate_cost
import
CostEstimator
from
.comp_op_cost
import
EmbeddingOpCost
from
.comp_op_cost
import
EmbeddingGradOpCost
from
.comp_op_cost
import
ConcatOpCost
from
.comp_op_cost
import
MatmulOpCost
from
.comp_op_cost
import
MatmulGradOpCost
from
.comp_op_cost
import
MatmulV2OpCost
from
.comp_op_cost
import
MatmulV2GradOpCost
from
.comp_op_cost
import
MulOpCost
from
.comp_op_cost
import
MulGradOpCost
from
.comp_op_cost
import
Reshape2OpCost
from
.comp_op_cost
import
Reshape2GradOpCost
from
.comp_op_cost
import
SliceOpCost
from
.comp_op_cost
import
SplitOpCost
from
.comp_op_cost
import
SoftmaxOpCost
from
.comp_op_cost
import
SoftmaxGradOpCost
from
.comp_op_cost
import
Transpose2OpCost
from
.comp_op_cost
import
Transpose2GradOpCost
from
.comp_op_cost
import
FillConstantBatchSizeLikeOpCost
from
.tensor_cost
import
TensorCost
from
.estimate_cost
import
CostEstimator
from
.comm_op_cost
import
SendOpCost
from
.comm_op_cost
import
RecvOpCost
from
.comm_op_cost
import
IdentityOpCost
...
...
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
浏览文件 @
ec1e0d5a
...
...
@@ -15,6 +15,25 @@
from
.base_cost
import
Cost
,
register_op_cost
,
CompOpCost
,
_g_op_cost_factory
@
register_op_cost
class
AdamOpCost
(
CompOpCost
):
OP_TYPE
=
"adam"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
(
AdamOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
AssignOpCost
(
CompOpCost
):
OP_TYPE
=
"assign"
...
...
python/paddle/distributed/auto_parallel/dist_context.py
浏览文件 @
ec1e0d5a
...
...
@@ -831,8 +831,10 @@ class DistributedContext:
if
(
dist_tensor
is
not
None
)
and
(
not
dist_tensor
.
validate_dist_attr
()):
assert
False
,
"Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}."
.
format
(
dist_tensor
.
serial_tensor
.
name
,
dist_tensor
.
desc
.
id
(),
dist_tensor
.
desc
.
original_id
(),
dist_tensor
.
dist_attr
)
dist_tensor
.
serial_tensor
.
name
,
dist_tensor
.
serial_tensor
.
desc
.
id
(),
dist_tensor
.
serial_tensor
.
desc
.
original_id
(),
dist_tensor
.
dist_attr
)
for
op
in
block
.
ops
:
dist_op
=
self
.
get_dist_op_for_program
(
op
)
assert
dist_op
is
not
None
,
\
...
...
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
浏览文件 @
ec1e0d5a
...
...
@@ -31,6 +31,9 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
from
..process_group
import
new_process_group
from
..utils
import
_get_comm_group
,
_get_idx_in_axis
,
_get_corresponding_rank
from
..cost
import
build_comp_desc_from_dist_op
,
build_comm_desc_from_dist_op
from
..cost
import
build_comm_costs_from_descs
,
build_comp_costs_from_descs
,
build_dp_costs
from
..cost
import
EmbeddingOpCost
,
EmbeddingGradOpCost
,
AllreduceSumOpCost
,
IdentityOpCost
class
DistributedEmbedding
(
DistributedOperatorImplContainer
):
...
...
@@ -53,6 +56,95 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
self
.
_forward_implemented
=
True
self
.
_backward_implemented
=
True
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
"""Calculate the cost by the op role."""
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Forward
):
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
elif
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
# embedding need start_index
cost_mapping
=
build_comp_costs_from_descs
(
EmbeddingOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
serial_op
=
dist_op
.
serial_op
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"W"
)[
0
])[
0
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
output
(
"Out"
)
c_allreduce_sum_desc_mapping
=
build_comm_desc_from_dist_op
(
"c_allreduce_sum"
,
dist_op
,
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
)
res_cost
=
[
cost_mapping
,
comm_op_cost_list
]
return
res_cost
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# by now the backward function only insert the gradient allreduce for dist op itself
res
=
[]
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
dist_attr
=
dist_op
.
dist_attr
embedding_row_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"W"
)[
0
])[
0
]
parallel_axis
=
embedding_row_dim_mapping
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
[
backward_op
.
input
(
"Out@GRAD"
)[
0
]]
c_identity_desc_mapping
=
build_comm_desc_from_dist_op
(
"c_identity"
,
dist_op
,
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res
.
append
(
comm_op_cost_list
)
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
cost_mapping
=
build_comp_costs_from_descs
(
EmbeddingGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Ids"
)[
0
])
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'W@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
ec1e0d5a
...
...
@@ -13,6 +13,7 @@
# limitations under the License
import
copy
from
.common
import
infer_shape
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
...
...
@@ -35,6 +36,10 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY,
from
..process_group
import
new_process_group
from
..utils
import
_get_comm_group
,
_get_corresponding_rank
from
.dist_default
import
DistributedDefaultImpl0
from
..cost
import
build_comp_desc_from_dist_op
,
build_comm_desc_from_dist_op
,
build_dp_costs
from
..cost
import
build_comm_costs_from_descs
,
build_comp_costs_from_descs
from
..cost
import
MatmulV2OpCost
,
MatmulOpCost
,
MulOpCost
,
IdentityOpCost
,
AllreduceSumOpCost
from
..cost
import
MatmulV2GradOpCost
,
MatmulGradOpCost
,
MulGradOpCost
def
copy_op_with_new_input_output
(
ctx
,
block
,
src_op
,
**
kwargs
):
...
...
@@ -58,6 +63,14 @@ def _update_dims_mapping_for_matmul(dist_op):
x_name
=
op_desc
.
input
(
'X'
)[
0
]
y_name
=
op_desc
.
input
(
'Y'
)[
0
]
out_name
=
op_desc
.
output
(
'Out'
)[
0
]
trans_x
=
None
trans_y
=
None
if
op_desc
.
type
()
==
"matmul_v2"
:
trans_x
=
op_desc
.
attr
(
'trans_x'
)
trans_y
=
op_desc
.
attr
(
'trans_y'
)
elif
op_desc
.
type
()
==
"matmul"
:
trans_x
=
op_desc
.
attr
(
'transpose_X'
)
trans_y
=
op_desc
.
attr
(
'transpose_Y'
)
x_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
x_name
)
y_dims_mapping
=
op_dist_attr
.
get_input_dims_mapping
(
y_name
)
out_dims_mapping
=
op_dist_attr
.
get_output_dims_mapping
(
out_name
)
...
...
@@ -67,27 +80,34 @@ def _update_dims_mapping_for_matmul(dist_op):
# Add dim mapping to Make sure the length dims_mapping be at least 2
if
x_dims_mapping_len
==
1
:
assert
trans_x
is
False
x_dims_mapping
.
insert
(
0
,
-
1
)
out_dims_mapping
.
insert
(
out_dims_mapping_len
-
1
,
0
)
if
y_dims_mapping_len
==
1
:
assert
trans_y
is
False
y_dims_mapping
.
insert
(
1
,
-
1
)
out_dims_mapping
.
insert
(
out_dims_mapping_len
,
0
)
new_x_dims_mapping_len
=
len
(
x_dims_mapping
)
new_y_dims_mapping_len
=
len
(
y_dims_mapping
)
new_out_dims_mapping_len
=
len
(
out_dims_mapping
)
# Deal with dim > 2 and take care of broadcasting
if
out_dims_mapping_len
>
2
:
if
new_
out_dims_mapping_len
>
2
:
broadcast_x_dims_mapping
=
[]
broadcast_y_dims_mapping
=
[]
broadcast_out_dims_mapping
=
[]
for
i
in
range
(
out_dims_mapping_len
-
x_dims_mapping_len
):
for
i
in
range
(
new_out_dims_mapping_len
-
new_
x_dims_mapping_len
):
broadcast_x_dims_mapping
.
append
(
out_dims_mapping
[
i
])
for
i
in
range
(
x_dims_mapping_len
-
2
):
for
i
in
range
(
new_
x_dims_mapping_len
-
2
):
broadcast_x_dims_mapping
.
append
(
x_dims_mapping
[
i
])
for
i
in
range
(
out_dims_mapping_len
-
y_dims_mapping_len
):
for
i
in
range
(
new_out_dims_mapping_len
-
new_
y_dims_mapping_len
):
broadcast_y_dims_mapping
.
append
(
out_dims_mapping
[
i
])
for
i
in
range
(
y_dims_mapping_len
-
2
):
for
i
in
range
(
new_
y_dims_mapping_len
-
2
):
broadcast_y_dims_mapping
.
append
(
y_dims_mapping
[
i
])
for
i
in
range
(
out_dims_mapping_len
-
2
):
for
i
in
range
(
new_
out_dims_mapping_len
-
2
):
broadcast_out_dims_mapping
.
append
(
out_dims_mapping
[
i
])
compatible_dims_mapping
=
compute_compatible_dims_mapping
([
...
...
@@ -97,23 +117,30 @@ def _update_dims_mapping_for_matmul(dist_op):
if
compatible_dims_mapping
is
None
:
return
False
for
i
in
range
(
x_dims_mapping_len
-
2
):
new_idx
=
i
+
(
out_dims_mapping_len
-
x_dims_mapping_len
)
for
i
in
range
(
new_
x_dims_mapping_len
-
2
):
new_idx
=
i
+
(
out_dims_mapping_len
-
new_
x_dims_mapping_len
)
if
x_dims_mapping
[
i
]
!=
compatible_dims_mapping
[
new_idx
]:
x_dims_mapping
[
i
]
=
compatible_dims_mapping
[
new_idx
]
changed
=
True
for
i
in
range
(
y_dims_mapping_len
-
2
):
new_idx
=
i
+
(
out_dims_mapping_len
-
y_dims_mapping_len
)
for
i
in
range
(
new_
y_dims_mapping_len
-
2
):
new_idx
=
i
+
(
out_dims_mapping_len
-
new_
y_dims_mapping_len
)
if
y_dims_mapping
[
i
]
!=
compatible_dims_mapping
[
new_idx
]:
y_dims_mapping
[
i
]
=
compatible_dims_mapping
[
new_idx
]
changed
=
True
for
i
in
range
(
out_dims_mapping_len
-
2
):
for
i
in
range
(
new_
out_dims_mapping_len
-
2
):
if
out_dims_mapping
[
i
]
!=
compatible_dims_mapping
[
i
]:
out_dims_mapping
[
i
]
=
compatible_dims_mapping
[
i
]
changed
=
True
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
# The following which uses negative index can be work
# when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2
dim_changed
=
compute_compatible_and_update_dim_mapping
(
...
...
@@ -131,11 +158,20 @@ def _update_dims_mapping_for_matmul(dist_op):
if
dim_changed
:
changed
=
True
if
trans_x
:
x_dims_mapping
[
-
1
],
x_dims_mapping
[
-
2
]
=
x_dims_mapping
[
-
2
],
x_dims_mapping
[
-
1
]
if
trans_y
:
y_dims_mapping
[
-
1
],
y_dims_mapping
[
-
2
]
=
y_dims_mapping
[
-
2
],
y_dims_mapping
[
-
1
]
# Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor
if
x_dims_mapping_len
==
1
:
x_dims_mapping
.
pop
(
0
)
out_dims_mapping
.
pop
(
out_dims_mapping_len
-
1
)
if
y_dims_mapping_len
==
1
:
y_dims_mapping
.
pop
(
1
)
out_dims_mapping
.
pop
(
out_dims_mapping_len
)
assert
len
(
x_dims_mapping
)
==
x_dims_mapping_len
assert
len
(
y_dims_mapping
)
==
y_dims_mapping_len
...
...
@@ -484,6 +520,102 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
self
.
_forward_implemented
=
True
self
.
_backward_implemented
=
True
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Forward
):
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
elif
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# by now the backward function only insert the gradient allreduce for dist op itself
res
=
[]
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
])
# col parallel: matmul + allreduce
assert
Y_var_dim_mapping
[
0
]
<
0
parallel_axis
=
Y_var_dim_mapping
[
1
]
has_x_grad
=
len
(
backward_op
.
output
(
"X@GRAD"
))
>
0
if
has_x_grad
:
assert
len
(
backward_op
.
output
(
"X@GRAD"
))
==
1
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# calc comm op cost
if
has_x_grad
:
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
backward_op
.
output
(
"X@GRAD"
)
c_allreduce_sum_desc_mapping
=
build_comm_desc_from_dist_op
(
"c_allreduce_sum"
,
dist_op
,
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
)
res
.
append
(
comm_op_cost_list
)
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
# calc comm op cost
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
])[
-
1
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
input
(
"X"
)
c_identity_desc_mapping
=
build_comm_desc_from_dist_op
(
"c_identity"
,
dist_op
,
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res_cost
=
[
comm_op_cost_list
,
cost_mapping
]
return
res_cost
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
@@ -710,6 +842,99 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
self
.
_forward_implemented
=
True
self
.
_backward_implemented
=
True
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Forward
):
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
elif
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# by now the backward function only insert the gradient allreduce for dist op itself
res
=
[]
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
Y_var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"Y"
)[
0
])
assert
Y_var_dim_mapping
[
1
]
<
0
parallel_axis
=
Y_var_dim_mapping
[
0
]
# calc comm op cost
var_names
=
[
backward_op
.
input
(
"Out@GRAD"
)[
0
]]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
c_identity_desc_mapping
=
build_comm_desc_from_dist_op
(
"c_identity"
,
dist_op
,
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
comm_op_cost_list
=
build_comm_costs_from_descs
(
IdentityOpCost
,
ctx
,
processes
,
c_identity_desc_mapping
,
cluster
)
res
.
append
(
comm_op_cost_list
)
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
# calc comm op cost
serial_op
=
dist_op
.
serial_op
vars
=
serial_op
.
block
.
vars
parallel_axis
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
serial_op
.
input
(
"Y"
)[
0
])[
-
2
]
attrs
=
{
"use_calc_stream"
:
True
,
"use_model_parallel"
:
True
}
var_names
=
serial_op
.
output
(
"Out"
)
c_allreduce_sum_desc_mapping
=
build_comm_desc_from_dist_op
(
"c_allreduce_sum"
,
dist_op
,
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
comm_op_cost_list
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
ctx
,
processes
,
c_allreduce_sum_desc_mapping
,
cluster
)
res_cost
=
[
cost_mapping
,
comm_op_cost_list
]
return
res_cost
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
@@ -920,6 +1145,59 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
def
__init__
(
self
,
name
):
super
(
DistributedMatmulImpl2
,
self
).
__init__
(
name
)
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Forward
):
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
elif
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
res
=
[]
backward_op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
main_block
=
backward_op
.
block
vars
=
main_block
.
vars
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
# need gradient allreduce
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
backward_op
.
input
(
"X"
)[
0
])
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
and
is_parameter_related
(
backward_op
.
input
(
"Y"
)[
0
],
main_block
):
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
backward_op
.
output
(
'Y@GRAD'
)[
0
]]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
MatmulOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res_cost
=
[
cost_mapping
]
return
res_cost
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
python/paddle/distributed/auto_parallel/operators/dist_reshape.py
浏览文件 @
ec1e0d5a
...
...
@@ -15,7 +15,7 @@
from
.common
import
DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl
from
.common
import
register_distributed_operator_impl
,
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_valid_list_index
...
...
@@ -28,6 +28,11 @@ from paddle.fluid.framework import _non_static_mode
from
paddle.fluid.framework
import
Program
,
Parameter
,
Variable
,
program_guard
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
,
check_dtype
from
.dist_default
import
DistributedDefaultImpl0
from
..cost
import
build_comp_desc_from_dist_op
,
build_comp_costs_from_descs
from
..cost
import
build_comm_costs_from_descs
from
..cost
import
Reshape2OpCost
from
..cost
import
Reshape2GradOpCost
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
class
DistributedReshape2
(
DistributedOperatorImplContainer
):
...
...
@@ -46,6 +51,84 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
self
.
_forward_implemented
=
True
self
.
_backward_implemented
=
False
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
else
:
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
res
=
[]
op
=
dist_op
.
serial_op
vars
=
op
.
block
.
vars
dist_attr
=
dist_op
.
dist_attr
shape_list
=
op
.
desc
.
attr
(
"shape"
)
# got dist attribute info
dim_mapping
=
dist_attr
.
get_output_dims_mapping
(
op
.
output
(
"Out"
)[
0
])
process_mesh_shape
=
dist_attr
.
process_mesh
.
topology
# modify target shape
for
idx
,
axis
in
enumerate
(
dim_mapping
):
if
axis
>=
0
:
if
len
(
shape_list
)
>
idx
:
shape_list
[
idx
]
=
shape_list
[
idx
]
//
process_mesh_shape
[
axis
]
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_attr
.
process_mesh
.
processes
for
key
in
desc_mapping
:
desc_mapping
[
key
][
"shape"
]
=
shape_list
cost_mapping
=
build_comp_costs_from_descs
(
Reshape2OpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
return
res
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
res
=
[]
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
op_type
=
dist_op
.
serial_op
.
type
cost_mapping
=
build_comp_costs_from_descs
(
Reshape2GradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
varname
,
main_block
):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
varname
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
varname
+
"@GRAD"
]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
@@ -199,6 +282,84 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
self
.
_forward_implemented
=
True
self
.
_backward_implemented
=
False
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
else
:
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
res
=
[]
op
=
dist_op
.
serial_op
vars
=
op
.
block
.
vars
dist_attr
=
dist_op
.
dist_attr
shape_list
=
op
.
desc
.
attr
(
"shape"
)
# got dist attribute info
dim_mapping
=
dist_attr
.
get_output_dims_mapping
(
op
.
output
(
"Out"
)[
0
])
process_mesh_shape
=
dist_attr
.
process_mesh
.
topology
# modify target shape
for
idx
,
axis
in
enumerate
(
dim_mapping
):
if
axis
>=
0
:
if
len
(
shape_list
)
>
idx
:
shape_list
[
idx
]
=
shape_list
[
idx
]
//
process_mesh_shape
[
axis
]
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_attr
.
process_mesh
.
processes
for
key
in
desc_mapping
:
desc_mapping
[
key
][
"shape"
]
=
shape_list
cost_mapping
=
build_comp_costs_from_descs
(
Reshape2OpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
return
res
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
res
=
[]
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
op_type
=
dist_op
.
serial_op
.
type
cost_mapping
=
build_comp_costs_from_descs
(
Reshape2GradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
varname
,
main_block
):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
varname
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
varname
+
"@GRAD"
]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
@@ -355,6 +516,84 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
self
.
_forward_implemented
=
True
self
.
_backward_implemented
=
False
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
else
:
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
res
=
[]
op
=
dist_op
.
serial_op
vars
=
op
.
block
.
vars
dist_attr
=
dist_op
.
dist_attr
shape_list
=
op
.
desc
.
attr
(
"shape"
)
# got dist attribute info
dim_mapping
=
dist_attr
.
get_output_dims_mapping
(
op
.
output
(
"Out"
)[
0
])
process_mesh_shape
=
dist_attr
.
process_mesh
.
topology
# modify target shape
for
idx
,
axis
in
enumerate
(
dim_mapping
):
if
axis
>=
0
:
if
len
(
shape_list
)
>
idx
:
shape_list
[
idx
]
=
shape_list
[
idx
]
//
process_mesh_shape
[
axis
]
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_attr
.
process_mesh
.
processes
for
key
in
desc_mapping
:
desc_mapping
[
key
][
"shape"
]
=
shape_list
cost_mapping
=
build_comp_costs_from_descs
(
Reshape2OpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
return
res
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
res
=
[]
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
op_type
=
dist_op
.
serial_op
.
type
cost_mapping
=
build_comp_costs_from_descs
(
Reshape2GradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
not
is_parameter_related
(
varname
,
main_block
):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
varname
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
varname
+
"@GRAD"
]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
python/paddle/distributed/auto_parallel/operators/dist_softmax.py
浏览文件 @
ec1e0d5a
...
...
@@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl
from
.common
import
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_valid_list_index
...
...
@@ -23,6 +24,11 @@ from ..utils import compute_compatible_dim_mapping
from
..utils
import
compute_compatible_dims_mapping
from
..utils
import
compute_compatible_and_update_dim_mapping
from
.dist_default
import
DistributedDefaultImpl0
from
..cost
import
AllreduceSumOpCost
,
_g_op_cost_factory
from
..cost
import
build_comp_desc_from_dist_op
,
build_dp_costs
from
..cost
import
build_comp_costs_from_descs
from
..cost
import
SoftmaxOpCost
,
SoftmaxGradOpCost
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
class
DistributedSoftmax
(
DistributedOperatorImplContainer
):
...
...
@@ -41,6 +47,62 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
self
.
_forward_implemented
=
False
self
.
_backward_implemented
=
False
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
else
:
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
SoftmaxOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res_cost
=
[
cost_mapping
]
return
res_cost
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
res
=
[]
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
cost_mapping
=
build_comp_costs_from_descs
(
SoftmaxGradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
varname
,
main_block
):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
varname
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
varname
+
"@GRAD"
]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
def
is_input_compatible
(
self
,
dist_op
):
op_desc
=
dist_op
.
serial_op
.
desc
op_dist_attr
=
dist_op
.
dist_attr
...
...
python/paddle/distributed/auto_parallel/operators/dist_transpose.py
浏览文件 @
ec1e0d5a
...
...
@@ -16,6 +16,7 @@ from .common import DistributedOperatorImplContainer
from
.common
import
DistributedOperatorImpl
from
.common
import
register_distributed_operator_impl_container
from
.common
import
register_distributed_operator_impl
from
.common
import
is_parameter_related
from
..utils
import
is_dim_shard
from
..utils
import
is_dim_replicate
from
..utils
import
is_valid_list_index
...
...
@@ -23,6 +24,10 @@ from ..utils import compute_compatible_dim_mapping
from
..utils
import
compute_compatible_dims_mapping
from
..utils
import
compute_compatible_and_update_dim_mapping
from
.dist_default
import
DistributedDefaultImpl0
from
..cost
import
AllreduceSumOpCost
,
Transpose2OpCost
,
Transpose2GradOpCost
from
..cost
import
build_comp_desc_from_dist_op
,
build_comm_desc_from_dist_op
,
build_dp_costs
from
..cost
import
build_comp_costs_from_descs
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
class
DistributedTranspose2
(
DistributedOperatorImplContainer
):
...
...
@@ -116,6 +121,63 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
return
changed
def
calc_cost
(
self
,
op_role
,
dist_op
,
ctx
,
cluster
):
cost
=
None
if
int
(
op_role
)
==
int
(
OpRole
.
Backward
):
cost
=
self
.
calc_bwd_cost
(
dist_op
,
ctx
,
cluster
)
else
:
cost
=
self
.
calc_fwd_cost
(
dist_op
,
ctx
,
cluster
)
assert
cost
is
not
None
return
cost
def
calc_fwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
op_type
=
dist_op
.
serial_op
.
type
cost_mapping
=
build_comp_costs_from_descs
(
Transpose2OpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res_cost
=
[
cost_mapping
]
return
res_cost
def
calc_bwd_cost
(
self
,
dist_op
,
ctx
,
cluster
):
# calc comp op cost
res
=
[]
desc_mapping
=
build_comp_desc_from_dist_op
(
dist_op
=
dist_op
,
dist_context
=
ctx
)
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
op_type
=
dist_op
.
serial_op
.
type
cost_mapping
=
build_comp_costs_from_descs
(
Transpose2GradOpCost
,
ctx
,
processes
,
desc_mapping
,
cluster
)
res
.
append
(
cost_mapping
)
backward_op
=
dist_op
.
serial_op
main_block
=
backward_op
.
block
need_gradient_allreduce
=
False
vars
=
main_block
.
vars
for
input_name
in
backward_op
.
desc
.
input_names
():
for
varname
in
backward_op
.
desc
.
input
(
input_name
):
if
"@GRAD"
not
in
varname
and
is_parameter_related
(
varname
,
main_block
):
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
var_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
varname
)
mesh_shape
=
process_mesh
.
topology
batch_size_axis
=
var_dim_mapping
[
0
]
if
batch_size_axis
>
-
1
and
mesh_shape
[
batch_size_axis
]
>
1
:
parallel_axis
=
batch_size_axis
attrs
=
{
"use_calc_stream"
:
True
}
var_names
=
[
varname
+
"@GRAD"
]
build_dp_costs
(
res
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
)
return
res
@
staticmethod
def
forward
(
ctx
,
*
args
,
**
kwargs
):
DistributedDefaultImpl0
.
forward
(
ctx
,
*
args
,
**
kwargs
)
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py
浏览文件 @
ec1e0d5a
...
...
@@ -47,7 +47,7 @@ def parallelizer(program_func, rank):
completer
.
complete_backward_annotation
(
main_program
)
dist_context
.
block_state
.
parse_backward_blocks
(
main_program
)
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
# generate opt and complete opt
with
program_guard
(
main_program
,
startup_program
):
optimize_ops
=
copy
.
deepcopy
(
optimizer
).
apply_gradients
(
params_grads
)
...
...
@@ -59,7 +59,7 @@ def parallelizer(program_func, rank):
class
TestDistOpCost
(
unittest
.
TestCase
):
def
test_dist_
fill_constatnt_batch_size_like_op_cost
(
self
):
def
test_dist_
op_cost_part1
(
self
):
def
make_program
():
main_program
=
paddle
.
static
.
Program
()
...
...
@@ -79,7 +79,7 @@ class TestDistOpCost(unittest.TestCase):
tmp
=
paddle
.
fluid
.
layers
.
fill_constant_batch_size_like
(
input
=
x
,
shape
=
[
2
,
8
],
value
=
1
,
dtype
=
'float32'
)
weight_attr
=
paddle
.
ParamAttr
()
linear
=
paddle
.
nn
.
Linear
(
8
,
8
,
weight_attr
=
weight_attr
)
linear
=
paddle
.
nn
.
Linear
(
8
,
1
,
weight_attr
=
weight_attr
)
linear_out
=
linear
(
x
)
gelu_out
=
paddle
.
nn
.
functional
.
gelu
(
linear_out
)
# default op with dp
...
...
@@ -109,6 +109,112 @@ class TestDistOpCost(unittest.TestCase):
dist_context
,
cluster
)
self
.
assertTrue
(
dist_op_cost
)
def
test_dist_op_cost_part2
(
self
):
def
make_program
():
main_program
=
paddle
.
static
.
Program
()
start_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
,
start_program
):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
4
],
dtype
=
'float32'
)
x
.
stop_gradient
=
True
label
=
paddle
.
static
.
data
(
name
=
"label"
,
shape
=
[
8
,
1
],
dtype
=
'float32'
)
label
.
stop_gradient
=
True
auto
.
shard_tensor
(
x
,
dist_attr
=
{
"process_mesh"
:
auto
.
ProcessMesh
([
0
,
1
]),
"dims_mapping"
:
[
0
]
})
auto
.
shard_tensor
(
label
,
dist_attr
=
{
"process_mesh"
:
auto
.
ProcessMesh
([
0
,
1
]),
"dims_mapping"
:
[
0
,
-
1
]
})
# embedding
tmp
=
paddle
.
fluid
.
layers
.
fill_constant_batch_size_like
(
input
=
x
,
shape
=
[
4
],
value
=
1
,
dtype
=
'int32'
)
embedding
=
paddle
.
nn
.
Embedding
(
10
,
8
)
out
=
embedding
(
tmp
)
# row parallel embedding
for
op
in
main_program
.
global_block
().
ops
:
if
op
.
type
==
"lookup_table_v2"
:
W
=
main_program
.
global_block
().
vars
[
op
.
input
(
"W"
)[
0
]]
auto
.
shard_tensor
(
W
,
dist_attr
=
{
"process_mesh"
:
auto
.
ProcessMesh
([
0
,
1
]),
"dims_mapping"
:
[
0
,
-
1
]
})
out
=
paddle
.
fluid
.
layers
.
transpose
(
out
,
[
1
,
0
])
# [8, 2] [-1, 0]
# matmul
param1
=
paddle
.
fluid
.
layers
.
create_parameter
(
[
4
,
8
],
paddle
.
float32
)
# [2, 8] [0, -1]
auto
.
shard_tensor
(
param1
,
dist_attr
=
{
"process_mesh"
:
auto
.
ProcessMesh
([
0
,
1
]),
"dims_mapping"
:
[
0
,
-
1
]
})
param2
=
paddle
.
fluid
.
layers
.
create_parameter
(
[
8
,
8
],
paddle
.
float32
)
# [8, 4] [-1, 0]
auto
.
shard_tensor
(
param2
,
dist_attr
=
{
"process_mesh"
:
auto
.
ProcessMesh
([
0
,
1
]),
"dims_mapping"
:
[
-
1
,
0
]
})
out1
=
paddle
.
fluid
.
layers
.
matmul
(
out
,
param1
)
# [8, 8] [-1, -1]
tmp_param
=
paddle
.
fluid
.
layers
.
create_parameter
(
[
8
,
8
],
paddle
.
float32
)
# [8, 8] [-1, -1]
auto
.
shard_tensor
(
param2
,
dist_attr
=
{
"process_mesh"
:
auto
.
ProcessMesh
([
0
,
1
]),
"dims_mapping"
:
[
-
1
,
-
1
]
})
tmp_out
=
paddle
.
fluid
.
layers
.
matmul
(
out1
,
tmp_param
)
out2
=
paddle
.
fluid
.
layers
.
matmul
(
tmp_out
,
param2
)
# [8, 4] [-1, 0]
out8
=
paddle
.
fluid
.
layers
.
transpose
(
out2
,
[
1
,
0
])
# [4, 8] [0, -1]
# reshape
out9
=
paddle
.
reshape
(
out8
,
[
8
,
2
,
4
])
# [4, 2, 4] [0, -1, -1]
tmp_reshape_out
=
paddle
.
reshape
(
out9
,
[
8
,
4
,
2
])
out10
=
paddle
.
reshape
(
tmp_reshape_out
,
[
8
,
8
])
# [4, 8] [0, -1]
# softmax
softmax
=
paddle
.
nn
.
Softmax
()
out11
=
softmax
(
out10
)
error_cost
=
paddle
.
nn
.
functional
.
square_error_cost
(
out11
,
label
)
loss
=
paddle
.
mean
(
error_cost
)
return
main_program
,
start_program
,
loss
main_program
,
dist_context
=
parallelizer
(
make_program
,
0
)
ops
=
main_program
.
global_block
().
ops
cluster
=
Cluster
()
cluster
.
gen_default_config_cluster
(
device_count
=
2
)
for
idx
,
op
in
enumerate
(
ops
):
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
op_dist_attr
=
dist_op
.
dist_attr
processes
=
op_dist_attr
.
process_mesh
.
processes
if
is_elementwise_op
(
op
.
type
):
container
=
get_distributed_operator_impl_container
(
"elementwise"
)
else
:
container
=
get_distributed_operator_impl_container
(
op_dist_attr
.
impl_type
)
dist_impl
=
container
.
impls
[
op_dist_attr
.
impl_idx
]
dist_op_cost
=
dist_impl
.
calc_cost
(
op
.
attr
(
'op_role'
),
dist_op
,
dist_context
,
cluster
)
self
.
assertTrue
(
dist_op_cost
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录