Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cbc5ca0f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cbc5ca0f
编写于
5月 13, 2022
作者:
T
Tao CHANG
提交者:
GitHub
5月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add communication cost for cost model (#42727)
上级
3052f36c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
347 addition
and
9 deletion
+347
-9
python/paddle/distributed/auto_parallel/cost/__init__.py
python/paddle/distributed/auto_parallel/cost/__init__.py
+11
-2
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+17
-1
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+136
-4
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
...dle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
+158
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
...luid/tests/unittests/auto_parallel/test_new_cost_model.py
+24
-2
未找到文件。
python/paddle/distributed/auto_parallel/cost/__init__.py
浏览文件 @
cbc5ca0f
...
@@ -14,7 +14,16 @@
...
@@ -14,7 +14,16 @@
from
.base_cost
import
_g_op_cost_factory
from
.base_cost
import
_g_op_cost_factory
from
.base_cost
import
Cost
from
.base_cost
import
Cost
from
.
comm_op_cost
import
AllreduceSumCos
t
from
.
base_cost
import
CommContex
t
from
.
comp_op_cost
import
MatmulV2OpCost
from
.
base_cost
import
build_comm_desc
from
.tensor_cost
import
TensorCost
from
.tensor_cost
import
TensorCost
from
.estimate_cost
import
CostEstimator
from
.estimate_cost
import
CostEstimator
from
.comp_op_cost
import
MatmulV2OpCost
from
.comm_op_cost
import
SendOpCost
from
.comm_op_cost
import
RecvOpCost
from
.comm_op_cost
import
IdentityOpCost
from
.comm_op_cost
import
BroadcastOpCost
from
.comm_op_cost
import
AllgatherOpCost
from
.comm_op_cost
import
AllreduceSumOpCost
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
cbc5ca0f
...
@@ -13,15 +13,31 @@
...
@@ -13,15 +13,31 @@
# limitations under the License
# limitations under the License
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
functools
import
reduce
import
paddle
import
paddle
from
..cluster
import
LinkType
from
..process_group
import
get_process_group
COMM_OP_TYPE
=
[
COMM_OP_TYPE
=
[
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
,
"c_identity"
]
]
NON_COMP_TYPE
=
[
"while"
]
+
COMM_OP_TYPE
NON_COMP_TYPE
=
[
"while"
]
+
COMM_OP_TYPE
_g_op_cost_factory
=
{}
_g_op_cost_factory
=
{}
def
build_comm_desc
(
op_type
,
group_ranks
,
dtype
,
shape
,
attrs
=
None
):
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"group_ranks"
]
=
group_ranks
desc
[
"inputs"
]
=
{
"X"
:
[(
dtype
,
shape
)]}
if
attrs
is
not
None
:
desc
[
"attrs"
]
=
attrs
return
desc
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
desc
=
{}
desc
=
{}
desc
[
"op"
]
=
op
.
type
desc
[
"op"
]
=
op
.
type
...
...
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
浏览文件 @
cbc5ca0f
...
@@ -12,17 +12,149 @@
...
@@ -12,17 +12,149 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License
# limitations under the License
from
.base_cost
import
register_op_cost
,
CommOpCost
import
math
from
.base_cost
import
register_op_cost
,
CommOpCost
,
_g_op_cost_factory
@
register_op_cost
@
register_op_cost
class
AllreduceSumCost
(
CommOpCost
):
class
AllreduceSum
Op
Cost
(
CommOpCost
):
OP_TYPE
=
"c_allreduce_sum"
OP_TYPE
=
"c_allreduce_sum"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
AllreduceSumCost
,
self
).
__init__
(
super
(
AllreduceSumOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
# use tree if cross machine and use ring if in a single machine
time
=
None
cluster
=
self
.
comm_context
.
cluster
if
not
cluster
.
cross_machine
(
self
.
group_ranks
):
time
=
self
.
calc_time_ring
()
else
:
time
=
self
.
calc_time_tree
()
return
time
def
calc_time_ring
(
self
):
alpha
=
self
.
comm_context
.
base_ring
alpha
+=
2
*
(
self
.
rank_count
-
self
.
machine_count
)
*
self
.
comm_context
.
intra_ring
alpha
+=
2
*
(
self
.
machine_count
-
1
)
*
(
self
.
comm_context
.
inter_ring
+
self
.
hops
*
self
.
comm_context
.
switch
)
beta
=
self
.
comm_context
.
get_max_beta
(
self
.
group_ranks
)
time
=
alpha
+
2
*
(
self
.
rank_count
-
1
)
/
self
.
rank_count
*
self
.
comm_count
*
beta
return
time
def
calc_time_tree
(
self
):
alpha
=
self
.
comm_context
.
base_tree
alpha
+=
2
*
(
self
.
rank_count
/
self
.
machine_count
-
1
)
*
self
.
comm_context
.
intra_tree
alpha
+=
math
.
log2
(
self
.
machine_count
)
*
(
self
.
comm_context
.
inter_tree
+
self
.
hops
*
self
.
comm_context
.
switch
)
beta
=
self
.
comm_context
.
get_max_beta
(
self
.
group_ranks
)
time
=
alpha
+
2
*
self
.
comm_count
*
beta
return
time
@
register_op_cost
class
AllgatherOpCost
(
CommOpCost
):
OP_TYPE
=
"c_allgather"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
AllgatherOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
time
=
self
.
calc_time_ring
()
return
time
def
calc_time_ring
(
self
):
alpha
=
self
.
comm_context
.
base_ring
alpha
+=
(
self
.
rank_count
-
self
.
machine_count
)
*
self
.
comm_context
.
intra_ring
alpha
+=
(
self
.
machine_count
-
1
)
*
(
self
.
comm_context
.
inter_ring
+
self
.
hops
*
self
.
comm_context
.
switch
)
beta
=
self
.
comm_context
.
get_max_beta
(
self
.
group_ranks
)
time
=
alpha
+
(
self
.
rank_count
-
1
)
/
self
.
rank_count
*
self
.
comm_count
*
beta
return
time
@
register_op_cost
class
BroadcastOpCost
(
CommOpCost
):
OP_TYPE
=
"c_broadcast"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
BroadcastOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
time
=
self
.
calc_time_ring
()
return
time
def
calc_time_ring
(
self
):
alpha
=
self
.
comm_context
.
base_ring
if
self
.
machine_count
>
1
:
alpha
+=
self
.
comm_context
.
inter_ring
+
self
.
hops
*
self
.
comm_context
.
switch
else
:
alpha
+=
self
.
comm_context
.
intra_ring
beta
=
self
.
comm_context
.
get_max_beta
(
self
.
group_ranks
)
time
=
alpha
+
self
.
comm_count
*
beta
return
time
@
register_op_cost
class
IdentityOpCost
(
CommOpCost
):
OP_TYPE
=
"c_identity"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
IdentityOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future.
return
0
return
0
@
register_op_cost
class
RecvOpCost
(
CommOpCost
):
OP_TYPE
=
"recv_v2"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
RecvOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
alpha
=
self
.
comm_context
.
base_ring
if
self
.
machine_count
>
1
:
alpha
+=
self
.
comm_context
.
inter_ring
+
self
.
hops
*
self
.
comm_context
.
switch
else
:
alpha
+=
self
.
comm_context
.
intra_ring
beta
=
self
.
comm_context
.
get_max_beta
(
self
.
group_ranks
)
time
=
alpha
+
self
.
comm_count
*
beta
return
time
@
register_op_cost
class
SendOpCost
(
CommOpCost
):
OP_TYPE
=
"send_v2"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
SendOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
alpha
=
self
.
comm_context
.
base_ring
if
self
.
machine_count
>
1
:
alpha
+=
self
.
comm_context
.
inter_ring
+
self
.
hops
*
self
.
comm_context
.
switch
else
:
alpha
+=
self
.
comm_context
.
intra_ring
beta
=
self
.
comm_context
.
get_max_beta
(
self
.
group_ranks
)
time
=
alpha
+
self
.
comm_count
*
beta
return
time
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
cbc5ca0f
...
@@ -29,4 +29,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
...
@@ -29,4 +29,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_dist_pnorm MODULES test_dist_pnorm ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_dist_pnorm MODULES test_dist_pnorm ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_dist_slice MODULES test_dist_slice ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_dist_slice MODULES test_dist_slice ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_cluster MODULES test_cluster ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_cluster MODULES test_cluster ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_comm_cost MODULES test_comm_cost ENVS
${
dist_ENVS
}
)
endif
()
endif
()
python/paddle/fluid/tests/unittests/auto_parallel/test_comm_cost.py
0 → 100644
浏览文件 @
cbc5ca0f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
os
import
json
import
paddle
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.cost
import
CommContext
from
paddle.distributed.auto_parallel.cost
import
build_comm_desc
from
paddle.distributed.auto_parallel.cost
import
AllreduceSumOpCost
from
paddle.distributed.auto_parallel.cost
import
AllgatherOpCost
from
paddle.distributed.auto_parallel.cost
import
BroadcastOpCost
from
paddle.distributed.auto_parallel.cost
import
SendOpCost
from
paddle.distributed.auto_parallel.cost
import
RecvOpCost
from
paddle.distributed.auto_parallel.cost
import
IdentityOpCost
from
test_cluster
import
cluster_json
,
multi_cluster_json
class
TestCommOpCost
(
unittest
.
TestCase
):
def
test_comm_cost
(
self
):
# Build cluster
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
cluster_json_path
=
os
.
path
.
join
(
file_dir
,
"auto_parallel_cluster.json"
)
cluster_json_object
=
json
.
loads
(
cluster_json
)
with
open
(
cluster_json_path
,
"w"
)
as
cluster_json_file
:
json
.
dump
(
cluster_json_object
,
cluster_json_file
)
cluster
=
Cluster
()
cluster
.
build_from_file
(
cluster_json_path
)
# Build CommConetxt
CommContext
.
_has_instance
=
None
CommContext
.
_instance
=
None
comm_context
=
CommContext
(
cluster
)
# Check AllreduceSumCost 128MB ring cost
allreduce_sum_op_desc
=
build_comm_desc
(
"c_allreduce_sum"
,
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
allreduce_sum_op_cost
=
AllreduceSumOpCost
(
op_desc
=
allreduce_sum_op_desc
,
comm_context
=
comm_context
)
# Check AllgatherOpCost cost
allgather_op_desc
=
build_comm_desc
(
"c_allgather"
,
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
allgather_op_cost
=
AllgatherOpCost
(
op_desc
=
allgather_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
allgather_op_cost
.
time
>
0
)
# Check BroadcastOpCost cost
broadcast_op_desc
=
build_comm_desc
(
"c_broadcast"
,
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
broadcast_op_cost
=
BroadcastOpCost
(
op_desc
=
broadcast_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
broadcast_op_cost
.
time
>
0
)
# Check SendOpCost cost
send_op_desc
=
build_comm_desc
(
"send_v2"
,
[
0
,
1
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
send_op_cost
=
SendOpCost
(
op_desc
=
send_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
send_op_cost
.
time
>
0
)
# Check RecvOpCost cost
recv_op_desc
=
build_comm_desc
(
"recv_v2"
,
[
0
,
1
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
recv_op_cost
=
RecvOpCost
(
op_desc
=
recv_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
recv_op_cost
.
time
>
0
)
# Check IdentityOpCost cost
identity_op_desc
=
build_comm_desc
(
"c_identity"
,
[
0
,
1
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
identity_op_cost
=
IdentityOpCost
(
op_desc
=
identity_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
identity_op_cost
.
time
>=
0
)
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
def
test_cross_machine_comm_cost
(
self
):
# Build cluster
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
cluster_json_path
=
os
.
path
.
join
(
file_dir
,
"auto_parallel_cluster.json"
)
cluster_json_object
=
json
.
loads
(
multi_cluster_json
)
with
open
(
cluster_json_path
,
"w"
)
as
cluster_json_file
:
json
.
dump
(
cluster_json_object
,
cluster_json_file
)
cluster
=
Cluster
()
cluster
.
build_from_file
(
cluster_json_path
)
# Build CommConetxt
CommContext
.
_has_instance
=
None
CommContext
.
_instance
=
None
comm_context
=
CommContext
(
cluster
)
# Check AllreduceSumCost 128MB ring cost
allreduce_sum_op_desc
=
build_comm_desc
(
"c_allreduce_sum"
,
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
allreduce_sum_op_cost
=
AllreduceSumOpCost
(
op_desc
=
allreduce_sum_op_desc
,
comm_context
=
comm_context
)
# Check AllgatherOpCost cost
allgather_op_desc
=
build_comm_desc
(
"c_allgather"
,
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
allgather_op_cost
=
AllgatherOpCost
(
op_desc
=
allgather_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
allgather_op_cost
.
time
>
0
)
# Check BroadcastOpCost cost
broadcast_op_desc
=
build_comm_desc
(
"c_broadcast"
,
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
broadcast_op_cost
=
BroadcastOpCost
(
op_desc
=
broadcast_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
broadcast_op_cost
.
time
>
0
)
# Check SendOpCost cost
send_op_desc
=
build_comm_desc
(
"send_v2"
,
[
0
,
1
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
send_op_cost
=
SendOpCost
(
op_desc
=
send_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
send_op_cost
.
time
>
0
)
# Check RecvOpCost cost
recv_op_desc
=
build_comm_desc
(
"recv_v2"
,
[
0
,
1
],
paddle
.
float32
,
[
1
,
32
*
(
10
**
6
)])
recv_op_cost
=
RecvOpCost
(
op_desc
=
recv_op_desc
,
comm_context
=
comm_context
)
self
.
assertTrue
(
recv_op_cost
.
time
>
0
)
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
浏览文件 @
cbc5ca0f
...
@@ -13,12 +13,17 @@
...
@@ -13,12 +13,17 @@
# limitations under the License.
# limitations under the License.
import
unittest
import
unittest
import
os
import
json
import
paddle
import
paddle
import
paddle.distributed.auto_parallel.cost
as
cost_model
import
paddle.distributed.auto_parallel.cost
as
cost_model
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_by_modeling
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_by_modeling
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.cost
import
CommContext
from
test_cluster
import
cluster_json
,
multi_cluster_json
paddle
.
enable_static
()
paddle
.
enable_static
()
...
@@ -58,14 +63,31 @@ class TestCost(unittest.TestCase):
...
@@ -58,14 +63,31 @@ class TestCost(unittest.TestCase):
self
.
assertEqual
(
tensor_cost
.
cost
.
memory
,
1600
)
self
.
assertEqual
(
tensor_cost
.
cost
.
memory
,
1600
)
def
test_comm_cost
(
self
):
def
test_comm_cost
(
self
):
# Build cluster
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
cluster_json_path
=
os
.
path
.
join
(
file_dir
,
"auto_parallel_cluster.json"
)
cluster_json_object
=
json
.
loads
(
cluster_json
)
with
open
(
cluster_json_path
,
"w"
)
as
cluster_json_file
:
json
.
dump
(
cluster_json_object
,
cluster_json_file
)
cluster
=
Cluster
()
cluster
.
build_from_file
(
cluster_json_path
)
# Build CommConetxt
CommContext
.
_has_instance
=
None
CommContext
.
_instance
=
None
comm_context
=
CommContext
(
cluster
)
desc
=
{}
desc
=
{}
desc
[
"op"
]
=
"c_allreduce_sum"
desc
[
"op"
]
=
"c_allreduce_sum"
desc
[
"inputs"
]
=
{
"X"
:
[(
[
100
,
200
],
paddle
.
float32
)]}
desc
[
"inputs"
]
=
{
"X"
:
[(
paddle
.
float32
,
[
100
,
200
]
)]}
desc
[
"group_ranks"
]
=
[
0
,
1
]
desc
[
"group_ranks"
]
=
[
0
,
1
]
allreduce_cost
=
cost_model
.
_g_op_cost_factory
[
"c_allreduce_sum"
](
allreduce_cost
=
cost_model
.
_g_op_cost_factory
[
"c_allreduce_sum"
](
op_desc
=
desc
)
op_desc
=
desc
,
comm_context
=
CommContext
(
cluster
)
)
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
# Remove unnecessary files
if
os
.
path
.
exists
(
cluster_json_path
):
os
.
remove
(
cluster_json_path
)
def
test_cost_estimator
(
self
):
def
test_cost_estimator
(
self
):
train_program
=
paddle
.
static
.
Program
()
train_program
=
paddle
.
static
.
Program
()
cost_estimator
=
cost_model
.
CostEstimator
(
train_program
)
cost_estimator
=
cost_model
.
CostEstimator
(
train_program
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录