Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e379455a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e379455a
编写于
7月 12, 2022
作者:
C
caozhou
提交者:
GitHub
7月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Auto Parallel】update base cost (#44095)
* update base cost * update unittest of cost model * add unittest
上级
3333a439
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
586 addition
and
50 deletion
+586
-50
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+320
-45
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py
...dle/fluid/tests/unittests/auto_parallel/test_base_cost.py
+234
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_cluster.py
...addle/fluid/tests/unittests/auto_parallel/test_cluster.py
+4
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
...dle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
+4
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
...luid/tests/unittests/auto_parallel/test_new_cost_model.py
+23
-5
未找到文件。
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
e379455a
...
...
@@ -17,8 +17,12 @@ from functools import reduce
import
paddle
from
..
cluster
import
LinkType
from
..
utils
import
_get_comm_group
,
_get_corresponding_rank
from
..process_group
import
get_process_group
from
..cluster
import
LinkType
from
..dist_tensor
import
DistributedTensor
from
..utils
import
_get_idx_in_axis
from
..dist_tensor
import
DistributedTensor
COMM_OP_TYPE
=
[
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
,
...
...
@@ -28,33 +32,22 @@ NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
_g_op_cost_factory
=
{}
def
build_comm_desc
(
op_type
,
group_ranks
,
dtype
,
shape
,
attrs
=
None
):
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"group_ranks"
]
=
group_ranks
desc
[
"inputs"
]
=
{
"X"
:
[(
dtype
,
shape
)]}
if
attrs
is
not
None
:
desc
[
"attrs"
]
=
attrs
return
desc
def
build_comp_desc_from_op
(
op
):
"""Build the description of computation op."""
# NOTE: The desc is for serial op.
from
..reshard
import
get_var_with_recursion
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
desc
=
{}
desc
[
"op"
]
=
op
.
type
# The desc of concat op is {"op": "concat", "inputs": {"X": [(paddle.float32, [20, 20]), (paddle.float32, [20, 20])]}, "outputs": {"Out": [(paddle.float32, [20, 40])], "attrs": {"axis": -1}}}
vars
=
op
.
block
.
vars
desc
[
"op"
]
=
op
.
type
input_desc
=
OrderedDict
()
for
input_name
in
op
.
input_names
:
var_name_list
=
op
.
input
(
input_name
)
var_desc
=
[]
for
var_name
in
var_name_list
:
var
=
vars
[
var_name
]
shape
=
None
if
dist_context
is
not
None
:
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
shape
=
dist_tensor
.
local_sizes
()
else
:
shape
=
var
.
shape
assert
shape
is
not
None
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
)
shape
=
var
.
shape
var_desc
.
append
((
var
.
dtype
,
shape
))
input_desc
[
input_name
]
=
var_desc
desc
[
"inputs"
]
=
input_desc
...
...
@@ -64,14 +57,8 @@ def _parse_op_to_desc(op, dist_context=None):
var_name_list
=
op
.
output
(
out_name
)
var_desc
=
[]
for
var_name
in
var_name_list
:
var
=
vars
[
var_name
]
shape
=
None
if
dist_context
is
not
None
:
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
shape
=
dist_tensor
.
local_sizes
()
else
:
shape
=
var
.
shape
assert
shape
is
not
None
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
)
shape
=
var
.
shape
var_desc
.
append
((
var
.
dtype
,
shape
))
output_desc
[
out_name
]
=
var_desc
desc
[
"outputs"
]
=
output_desc
...
...
@@ -82,19 +69,101 @@ def _parse_op_to_desc(op, dist_context=None):
return
desc
def
parse_to_desc
(
op
=
None
,
dist_op
=
None
,
dist_context
=
None
):
desc
=
None
if
op
is
None
and
dist_op
is
not
None
and
dist_context
is
not
None
:
desc
=
_parse_op_to_desc
(
op
=
dist_op
.
serial_op
,
dist_context
=
dist_context
)
elif
op
is
not
None
and
dist_op
is
None
and
dist_context
is
None
:
desc
=
_parse_op_to_desc
(
op
)
return
desc
def
parse_desc_to_str
(
desc
):
def
build_comp_desc_from_dist_op
(
dist_op
,
dist_context
):
"""Build descriptions of computation op distributed on the processes."""
from
..reshard
import
get_var_with_recursion
op_descs
=
{}
op
=
dist_op
.
serial_op
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
assert
process_mesh
,
"Process mesh must not be None."
processes
=
process_mesh
.
processes
for
process
in
processes
:
desc
=
{}
desc
[
"op"
]
=
op
.
type
attr_desc
=
op
.
all_attrs
()
# NOTE: The attrs of desc is replica of serial op, there may be a bug if shape need to be partitioned involved in attrs.
desc
[
"attrs"
]
=
attr_desc
input_desc
=
OrderedDict
()
output_desc
=
OrderedDict
()
# Get partitioned shape of input
for
input_name
in
op
.
input_names
:
var_name_list
=
op
.
input
(
input_name
)
var_desc
=
[]
for
var_name
in
var_name_list
:
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
)
# Use op input_dims_mapping
dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_name
)
global_sizes
=
var
.
shape
# NOTE: When support uneven partition, the shard_sizes will be got from dist_attr.
shard_sizes
=
None
topology
=
process_mesh
.
topology
shape
=
DistributedTensor
.
get_local_sizes
(
global_sizes
,
dims_mapping
,
topology
,
processes
,
process
,
shard_sizes
)
var_desc
.
append
((
var
.
dtype
,
shape
))
# For special op such as embedding and its grad op
if
op
.
type
==
"c_embedding"
or
op
.
type
==
"lookup_table_v2"
or
op
.
type
==
"c_embedding_grad"
or
op
.
type
==
"lookup_table_v2_grad"
:
if
input_name
==
"W"
:
embedding_row_dim_mapping
=
dist_attr
.
get_input_dims_mapping
(
op
.
input
(
input_name
)[
0
])[
0
]
relative_idx
=
_get_idx_in_axis
(
processes
,
dist_attr
.
process_mesh
.
topology
,
embedding_row_dim_mapping
,
process
)
per_part_size
=
shape
[
0
]
relative_idx
=
relative_idx
*
per_part_size
desc
[
"attrs"
][
"start_index"
]
=
relative_idx
input_desc
[
input_name
]
=
var_desc
desc
[
"inputs"
]
=
input_desc
for
out_name
in
op
.
output_names
:
var_name_list
=
op
.
output
(
out_name
)
var_desc
=
[]
for
var_name
in
var_name_list
:
# Use op output_dims_mapping
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
)
dist_attr
=
dist_op
.
dist_attr
dims_mapping
=
dist_attr
.
get_output_dims_mapping
(
var_name
)
process_mesh
=
dist_attr
.
process_mesh
global_sizes
=
var
.
shape
shard_sizes
=
None
processes
=
process_mesh
.
processes
topology
=
process_mesh
.
topology
shape
=
DistributedTensor
.
get_local_sizes
(
global_sizes
,
dims_mapping
,
topology
,
processes
,
process
,
shard_sizes
)
var_desc
.
append
((
var
.
dtype
,
shape
))
# For special op such as fill_constant_batch_size_like
if
op
.
type
==
"fill_constant_batch_size_like"
:
# Modify shape attr according to how output are partitioned
out_name
=
var_name_list
[
0
]
dims_mapping
=
dist_attr
.
get_output_dims_mapping
(
out_name
)
process_mesh_shape
=
dist_attr
.
process_mesh
.
topology
shape_list
=
op
.
attr
(
"shape"
)
# Modify target shape
for
idx
,
axis
in
enumerate
(
dims_mapping
):
if
axis
>=
0
:
shape_list
[
idx
]
=
shape_list
[
idx
]
//
process_mesh_shape
[
axis
]
desc
[
"attrs"
][
"shape"
]
=
shape_list
output_desc
[
out_name
]
=
var_desc
desc
[
"outputs"
]
=
output_desc
op_descs
[
process
]
=
desc
return
op_descs
def
build_comp_desc_str_for_predict
(
desc
):
# NOTE: The description format may change in the future.
def
_parse_dtype
(
dtype
):
dtype_str
=
""
if
dtype
==
paddle
.
float32
:
...
...
@@ -135,8 +204,208 @@ def parse_desc_to_str(desc):
shape_str
=
"["
+
","
.
join
(
shape_list
)
+
"]"
desc_str_list
+=
[
dtype_str
,
dims_str
,
shape_str
]
desc_str
=
"_"
.
join
(
desc_str_list
)
attrs
=
desc
[
"attrs"
]
parse_result
=
(
desc_str
,
attrs
)
return
parse_result
def
build_comm_desc_from_dist_op
(
op_type
,
dist_op
,
ctx
,
var_names
,
attrs
=
None
,
parallel_axis
=
None
,
group_ranks
=
None
):
"""Build descriptions of communication op distributed on the processes."""
from
..reshard
import
get_var_with_recursion
specific_op_type
=
[]
dist_attr
=
dist_op
.
dist_attr
assert
dist_attr
,
"Dist attr must not be None."
process_mesh
=
dist_attr
.
process_mesh
assert
process_mesh
,
"Process mesh must not be None."
processes
=
process_mesh
.
processes
op_descs
=
{}
for
process
in
processes
:
rank_id
=
process
desc
=
{}
desc
[
"op"
]
=
op_type
op_attrs
=
None
comm_group_ranks
=
None
if
op_type
not
in
specific_op_type
:
serial_op
=
dist_op
.
serial_op
input_list
=
[]
# The var_names usually contain just one item.
for
var_name
in
var_names
:
dist_attr
=
dist_op
.
dist_attr
has_found
=
False
# Find var_name in serial op input or output
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
# If a tensor is the input of multi ops, sum the grad of all ops, so the name will be varname@RENAME@block@0 and so on.
if
var_name
in
name
:
var_name
=
name
has_found
=
True
break
if
not
has_found
:
for
name
in
dist_op
.
serial_op
.
output_arg_names
:
if
var_name
in
name
:
var_name
=
name
has_found
=
True
break
assert
has_found
var
=
get_var_with_recursion
(
var_name
,
serial_op
.
block
,
serial_op
.
block
.
program
)
dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_name
)
if
var_name
in
dist_op
.
serial_op
.
input_arg_names
else
dist_attr
.
get_output_dims_mapping
(
var_name
)
global_sizes
=
var
.
shape
shard_sizes
=
None
topology
=
process_mesh
.
topology
shape
=
DistributedTensor
.
get_local_sizes
(
global_sizes
,
dims_mapping
,
topology
,
processes
,
process
,
shard_sizes
)
input_list
.
append
((
var
.
dtype
,
shape
))
# NOTE: The input_name of comm ops used usually is X.
desc
[
"inputs"
]
=
{
"X"
:
input_list
}
# Get comm group by parallel_axis or the given group_ranks.
if
parallel_axis
is
not
None
:
process_mesh_shape
=
process_mesh
.
topology
process_mesh_group
=
process_mesh
.
processes
comm_group_ranks
=
_get_comm_group
(
process_mesh_group
,
process_mesh_shape
,
parallel_axis
,
rank_id
)
elif
group_ranks
is
not
None
:
comm_group_ranks
=
group_ranks
else
:
raise
ValueError
(
"The parallel_axis and group_ranks can not be None in the same."
)
if
attrs
is
not
None
:
assert
isinstance
(
attrs
,
dict
)
op_attrs
=
attrs
else
:
op_attrs
=
{}
desc
[
"attrs"
]
=
op_attrs
desc
[
"group_ranks"
]
=
comm_group_ranks
op_descs
[
rank_id
]
=
desc
return
op_descs
def
build_comm_desc
(
op_type
,
group_ranks
,
dtype
,
shape
,
attrs
=
None
):
"""Build a comm desc directly."""
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"group_ranks"
]
=
group_ranks
desc
[
"inputs"
]
=
{
"X"
:
[(
dtype
,
shape
)]}
desc
[
"attrs"
]
=
attrs
return
desc
return
desc_str
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
):
"""Build comm costs by descriptions"""
comm_context
=
CommContext
(
cluster
)
group_ranks_list
=
[]
comm_op_cost_list
=
[]
for
process
in
processes
:
desc
=
descs
[
process
]
group_ranks
=
desc
[
"group_ranks"
]
if
group_ranks
not
in
group_ranks_list
:
group_ranks_list
.
append
(
group_ranks
)
comm_op_cost
=
op_cost_class
(
op_desc
=
desc
,
comm_context
=
comm_context
)
comm_op_cost_list
.
append
(
comm_op_cost
)
return
comm_op_cost_list
def
build_comp_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
):
"""Build comp costs by descriptions."""
costs
=
{}
for
process
in
processes
:
costs
[
process
]
=
op_cost_class
(
op_desc
=
descs
[
process
],
cluster
=
cluster
)
return
costs
def
build_dp_costs
(
result
,
dist_op
,
ctx
,
var_names
,
attrs
,
parallel_axis
,
cluster
):
"""DP cost contains a allreduce_sum op cost and a scale op cost"""
# The costs will be appended in the given result.
from
..reshard
import
get_var_with_recursion
dist_attr
=
dist_op
.
dist_attr
process_mesh
=
dist_attr
.
process_mesh
processes
=
process_mesh
.
processes
assert
len
(
var_names
)
==
1
vars
=
dist_op
.
serial_op
.
block
.
vars
var_name
=
var_names
[
0
]
has_found
=
False
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
if
var_name
in
name
:
var_name
=
name
has_found
=
True
break
if
not
has_found
:
for
name
in
dist_op
.
serial_op
.
output_arg_names
:
if
var_name
in
name
:
var_name
=
name
has_found
=
True
break
if
not
has_found
:
return
c_allreduce_sum_descs
=
build_comm_desc_from_dist_op
(
"c_allreduce_sum"
,
dist_op
,
ctx
,
var_names
,
attrs
=
attrs
,
parallel_axis
=
parallel_axis
)
comm_cost_list
=
build_comm_costs_from_descs
(
_g_op_cost_factory
[
"c_allreduce_sum"
],
ctx
,
processes
,
c_allreduce_sum_descs
,
cluster
)
result
.
append
(
comm_cost_list
)
# The scale op just on the group_ranks
for
comm_cost
in
comm_cost_list
:
group_ranks
=
comm_cost
.
group_ranks
dp_degree
=
len
(
group_ranks
)
scale_costs
=
{}
op_type
=
"scale"
for
rank
in
group_ranks
:
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"inputs"
]
=
{}
dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_name
)
if
dist_attr
.
get_input_dims_mapping
(
var_name
)
is
not
None
else
dist_attr
.
get_output_dims_mapping
(
var_name
)
var
=
get_var_with_recursion
(
var_name
,
dist_op
.
serial_op
.
block
,
dist_op
.
serial_op
.
block
.
program
)
global_sizes
=
var
.
shape
shard_sizes
=
None
topology
=
process_mesh
.
topology
shape
=
DistributedTensor
.
get_local_sizes
(
global_sizes
,
dims_mapping
,
topology
,
processes
,
rank
,
shard_sizes
)
desc
[
"inputs"
][
"X"
]
=
[(
var
.
dtype
,
shape
)]
attrs
=
{
"scale"
:
1.0
/
dp_degree
}
desc
[
"attrs"
]
=
attrs
scale_op_cost
=
_g_op_cost_factory
[
"scale"
](
op_desc
=
desc
,
cluster
=
cluster
)
scale_costs
[
rank
]
=
scale_op_cost
result
.
append
(
scale_costs
)
class
CommContext
:
...
...
@@ -174,6 +443,8 @@ class CommContext:
# set default
self
.
base_ring
=
8.4
self
.
base_tree
=
0.
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVL in default
self
.
intra_ring
=
3.4
self
.
intra_tree
=
28
...
...
@@ -441,6 +712,8 @@ class CommOpCost(OpCost):
@
property
def
comm_count
(
self
):
from
..reshard
import
get_var_with_recursion
if
self
.
_comm_count
is
None
:
dtype
=
None
shape
=
None
...
...
@@ -448,7 +721,8 @@ class CommOpCost(OpCost):
vars
=
self
.
op
.
block
.
vars
# NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overrided
var_name
=
self
.
op
.
input
(
"X"
)[
0
]
var
=
vars
[
var_name
]
var
=
get_var_with_recursion
(
var_name
,
self
.
op
.
block
,
self
.
program
)
dtype
=
var
.
dtype
shape
=
var
.
shape
elif
self
.
op_desc
is
not
None
:
...
...
@@ -464,9 +738,10 @@ class CommOpCost(OpCost):
factor
=
1
elif
dtype
==
paddle
.
float16
:
factor
=
2
elif
dtype
==
paddle
.
bool
:
factor
=
8
else
:
raise
TypeError
(
"This dtype {} is not supported now"
.
format
(
dtype
))
raise
ValueError
(
"Unsupported comm dtype {}"
.
format
(
dtype
))
comm_count
=
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
*
factor
self
.
_comm_count
=
comm_count
...
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
e379455a
...
...
@@ -51,6 +51,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_cluster MODULES test_cluster ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_comm_cost MODULES test_comm_cost ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_comp_cost MODULES test_comp_cost ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_base_cost MODULES test_base_cost ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_dist_context MODULES test_dist_context ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_prim_dist_op MODULES test_prim_dist_op ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_to_static MODULES test_to_static ENVS
${
dist_ENVS
}
)
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py
0 → 100644
浏览文件 @
e379455a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
os
import
json
import
tempfile
import
paddle
import
paddle.nn
as
nn
import
paddle.static
as
static
import
paddle.nn.functional
as
F
import
paddle.utils
as
utils
import
paddle.distributed.auto_parallel
as
auto
from
paddle.distributed.auto_parallel.completion
import
Completer
from
paddle.distributed.auto_parallel.dist_context
import
DistributedContext
from
paddle.distributed
import
fleet
from
paddle.distributed.auto_parallel.parallelizer
import
AutoParallelizer
from
paddle.distributed.auto_parallel.partitioner
import
Partitioner
from
paddle.distributed.auto_parallel.reshard
import
Resharder
from
paddle.distributed.auto_parallel.utils
import
print_program_with_dist_attr
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.cost
import
CommContext
from
paddle.distributed.auto_parallel.cost.base_cost
import
build_comp_desc_from_dist_op
from
paddle.distributed.auto_parallel.cost.base_cost
import
build_comm_desc_from_dist_op
from
paddle.distributed.auto_parallel.cost.base_cost
import
build_comm_costs_from_descs
from
paddle.distributed.auto_parallel.cost.base_cost
import
build_comp_costs_from_descs
from
paddle.distributed.auto_parallel.cost.base_cost
import
build_dp_costs
from
paddle.distributed.auto_parallel.cost
import
AllreduceSumOpCost
from
paddle.distributed.auto_parallel.cost
import
_g_op_cost_factory
from
test_cluster
import
cluster_json
paddle
.
enable_static
()
_global_parallel_strategy
=
"dp_mp_pp"
_global_process_mesh
=
auto
.
ProcessMesh
([[[
0
,
1
],
[
4
,
5
]],
[[
2
,
3
],
[
6
,
7
]]])
PP_MESH_0
=
auto
.
ProcessMesh
([[
0
,
1
],
[
4
,
5
]])
PP_MESH_1
=
auto
.
ProcessMesh
([[
2
,
3
],
[
6
,
7
]])
class
MLPLayer
(
nn
.
Layer
):
def
__init__
(
self
,
hidden_size
=
1024
,
intermediate_size
=
4
*
1024
,
initializer_range
=
0.02
):
super
(
MLPLayer
,
self
).
__init__
()
d_model
=
hidden_size
dim_feedforward
=
intermediate_size
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
Normal
(
mean
=
0.0
,
std
=
initializer_range
))
bias_attr
=
None
self
.
linear0
=
nn
.
Linear
(
d_model
,
dim_feedforward
,
weight_attr
,
bias_attr
=
bias_attr
)
self
.
linear1
=
nn
.
Linear
(
dim_feedforward
,
d_model
,
weight_attr
,
bias_attr
=
bias_attr
)
self
.
norm
=
nn
.
LayerNorm
(
d_model
,
epsilon
=
1e-5
)
def
forward
(
self
,
input
):
auto
.
shard_tensor
(
self
.
linear0
.
weight
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_0
,
"dims_mapping"
:
[
-
1
,
1
]
})
auto
.
shard_tensor
(
self
.
linear1
.
weight
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_1
,
"dims_mapping"
:
[
1
,
-
1
]
})
out
=
self
.
norm
(
input
)
out
=
self
.
linear0
(
out
)
out
=
F
.
gelu
(
out
,
approximate
=
True
)
out
=
self
.
linear1
(
out
)
return
out
def
mlp_forward
(
train_program
,
start_program
):
with
static
.
program_guard
(
train_program
,
start_program
),
utils
.
unique_name
.
guard
():
batch_size
=
4
hidden_size
=
1024
sequence_len
=
512
input
=
static
.
data
(
name
=
"input"
,
shape
=
[
batch_size
,
hidden_size
],
dtype
=
'float32'
)
label
=
static
.
data
(
name
=
"label"
,
shape
=
[
batch_size
,
1
],
dtype
=
'float32'
)
fill_constant_out
=
paddle
.
fluid
.
layers
.
fill_constant_batch_size_like
(
input
=
input
,
shape
=
[
batch_size
],
value
=
1
,
dtype
=
"int32"
)
embedding
=
paddle
.
nn
.
Embedding
(
10
,
hidden_size
,
sparse
=
True
)
embedding_out
=
embedding
(
fill_constant_out
)
auto
.
shard_tensor
(
input
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_0
,
"dims_mapping"
:
[
0
,
-
1
]
})
auto
.
shard_tensor
(
label
,
dist_attr
=
{
"process_mesh"
:
PP_MESH_1
,
"dims_mapping"
:
[
0
,
-
1
]
})
mlp
=
MLPLayer
(
hidden_size
=
hidden_size
,
intermediate_size
=
4
*
hidden_size
,
initializer_range
=
0.02
)
predict
=
mlp
(
embedding_out
)
error_cost
=
paddle
.
nn
.
functional
.
square_error_cost
(
predict
,
label
)
loss
=
paddle
.
mean
(
error_cost
)
return
loss
,
train_program
,
start_program
def
get_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
):
global
_global_process_mesh
dist_context
.
process_mesh
=
_global_process_mesh
loss
,
train_program
,
startup_program
=
mlp_forward
(
train_program
,
startup_program
)
fleet
.
_user_defined_strategy
=
fleet
.
DistributedStrategy
()
fleet
.
user_defined_optimizer
=
paddle
.
fluid
.
optimizer
.
AdamOptimizer
()
parallelizer
=
AutoParallelizer
(
fleet
)
parallelizer
.
_dist_context
=
dist_context
# serial forward & backward completion
completer
=
Completer
(
dist_context
)
complete_train_program
=
completer
.
complete_forward_annotation
(
train_program
)
dist_context
.
block_state
.
parse_forward_blocks
(
complete_train_program
)
params_grads
=
parallelizer
.
_generate_backward
(
complete_train_program
,
startup_program
,
loss
,
parameter_list
=
None
,
no_grad_set
=
None
,
callbacks
=
None
)
return
train_program
,
startup_program
,
params_grads
class
TestBaseCost
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
temp_dir
=
tempfile
.
TemporaryDirectory
()
def
tearDown
(
self
):
self
.
temp_dir
.
cleanup
()
def
test_base_cost
(
self
):
# Build cluster
cluster_json_path
=
os
.
path
.
join
(
self
.
temp_dir
.
name
,
"auto_parallel_cluster.json"
)
cluster_json_object
=
json
.
loads
(
cluster_json
)
with
open
(
cluster_json_path
,
"w"
)
as
cluster_json_file
:
json
.
dump
(
cluster_json_object
,
cluster_json_file
)
cluster
=
Cluster
()
cluster
.
build_from_file
(
cluster_json_path
)
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
dist_context
=
DistributedContext
()
rank_id
=
2
train_program
,
startup_program
,
params_grads
=
get_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
)
for
op
in
train_program
.
global_block
().
ops
:
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
dist_op
:
processes
=
dist_op
.
dist_attr
.
process_mesh
.
processes
comp_descs
=
build_comp_desc_from_dist_op
(
dist_op
,
dist_context
)
self
.
assertTrue
(
isinstance
(
comp_descs
,
dict
)
and
comp_descs
)
var_names
=
None
if
op
.
input_arg_names
:
var_names
=
op
.
input_arg_names
[
0
]
comm_descs
=
build_comm_desc_from_dist_op
(
"c_allreduce_sum"
,
dist_op
,
dist_context
,
var_names
,
attrs
=
None
,
parallel_axis
=
0
,
group_ranks
=
None
)
self
.
assertTrue
(
isinstance
(
comm_descs
,
dict
)
and
comm_descs
)
comm_descs
=
build_comm_desc_from_dist_op
(
"c_allreduce_sum"
,
dist_op
,
dist_context
,
var_names
,
attrs
=
None
,
parallel_axis
=
None
,
group_ranks
=
processes
)
self
.
assertTrue
(
isinstance
(
comm_descs
,
dict
)
and
comm_descs
)
comm_costs
=
build_comm_costs_from_descs
(
AllreduceSumOpCost
,
dist_context
,
processes
,
comm_descs
,
cluster
)
self
.
assertTrue
(
comm_costs
)
comp_costs
=
build_comp_costs_from_descs
(
_g_op_cost_factory
[
op
.
type
],
dist_context
,
processes
,
comp_descs
,
cluster
)
self
.
assertTrue
(
comp_costs
)
result
=
[]
build_dp_costs
(
result
,
dist_op
,
dist_context
,
var_names
[
0
],
None
,
0
,
cluster
)
self
.
assertTrue
(
result
)
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_cluster.py
浏览文件 @
e379455a
...
...
@@ -2018,6 +2018,10 @@ class TestCluster(unittest.TestCase):
self
.
assertTrue
(
devices
==
[
5
,
6
,
7
,
10
])
self
.
assertTrue
(
involved_machine_count
==
2
)
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
浏览文件 @
e379455a
...
...
@@ -154,6 +154,10 @@ class TestCommOpCost(unittest.TestCase):
comm_context
=
comm_context
)
self
.
assertTrue
(
recv_op_cost
.
time
>
0
)
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
浏览文件 @
e379455a
...
...
@@ -19,8 +19,8 @@ import tempfile
import
paddle
import
paddle.distributed.auto_parallel.cost
as
cost_model
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
build_comp_desc_from_op
from
paddle.distributed.auto_parallel.cost.base_cost
import
build_comp_desc_str_for_predict
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_by_modeling
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.cost
import
CommContext
...
...
@@ -60,8 +60,8 @@ class TestCost(unittest.TestCase):
break
matmul_v2_cost
=
cost_model
.
_g_op_cost_factory
[
"matmul_v2"
](
op
=
matmul_v2_op
)
desc
=
parse_to_desc
(
op
=
matmul_v2_op
)
desc_str
=
parse_desc_to_str
(
desc
)
desc
=
build_comp_desc_from_op
(
op
=
matmul_v2_op
)
desc_str
=
build_comp_desc_str_for_predict
(
desc
)
self
.
assertIsNotNone
(
desc_str
)
self
.
assertTrue
(
check_cost
(
matmul_v2_cost
.
cost
))
time
=
calc_time_by_modeling
(
op
=
matmul_v2_op
)
...
...
@@ -92,11 +92,29 @@ class TestCost(unittest.TestCase):
op_desc
=
desc
,
comm_context
=
CommContext
(
cluster
))
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
def
test_cost_estimator
(
self
):
# Build cluster
cluster_json_path
=
os
.
path
.
join
(
self
.
temp_dir
.
name
,
"auto_parallel_cluster.json"
)
cluster_json_object
=
json
.
loads
(
cluster_json
)
with
open
(
cluster_json_path
,
"w"
)
as
cluster_json_file
:
json
.
dump
(
cluster_json_object
,
cluster_json_file
)
cluster
=
Cluster
()
cluster
.
build_from_file
(
cluster_json_path
)
train_program
=
paddle
.
static
.
Program
()
cost_estimator
=
cost_model
.
CostEstimator
(
train_program
)
cost_estimator
=
cost_model
.
CostEstimator
(
train_program
,
cluster
=
cluster
)
self
.
assertIsNotNone
(
cost_estimator
)
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录