Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6ac08db5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6ac08db5
编写于
5月 10, 2022
作者:
C
caozhou
提交者:
GitHub
5月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update base of cost model (#42601)
上级
cc077693
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
262 addition
and
66 deletion
+262
-66
python/paddle/distributed/auto_parallel/cost/__init__.py
python/paddle/distributed/auto_parallel/cost/__init__.py
+1
-1
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+250
-55
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+3
-3
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+3
-3
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
...luid/tests/unittests/auto_parallel/test_new_cost_model.py
+5
-4
未找到文件。
python/paddle/distributed/auto_parallel/cost/__init__.py
浏览文件 @
6ac08db5
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License
# limitations under the License
from
.base_cost
import
OP_COST_FACTORY
from
.base_cost
import
_g_op_cost_factory
from
.base_cost
import
Cost
from
.base_cost
import
Cost
from
.comm_op_cost
import
AllreduceSumCost
from
.comm_op_cost
import
AllreduceSumCost
from
.comp_op_cost
import
MatmulV2OpCost
from
.comp_op_cost
import
MatmulV2OpCost
...
...
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
6ac08db5
...
@@ -19,7 +19,7 @@ COMM_OP_TYPE = [
...
@@ -19,7 +19,7 @@ COMM_OP_TYPE = [
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
]
]
NON_COMP_TYPE
=
[
"while"
]
+
COMM_OP_TYPE
NON_COMP_TYPE
=
[
"while"
]
+
COMM_OP_TYPE
OP_COST_FACTORY
=
{}
_g_op_cost_factory
=
{}
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
...
@@ -126,66 +126,136 @@ class CommContext:
...
@@ -126,66 +126,136 @@ class CommContext:
_instance
=
None
_instance
=
None
_has_instance
=
False
_has_instance
=
False
def
__init__
(
self
,
cluster
):
if
CommContext
.
_has_instance
:
return
self
.
cluster
=
cluster
self
.
_alpha_base_ring
=
8.4
self
.
_alpha_base_tree
=
0
self
.
_alpha_inter
=
None
self
.
_alpha_intra
self
.
_beta
=
{}
def
__new__
(
cls
,
*
args
,
**
kwargs
):
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
().
__new__
(
cls
,
*
args
,
**
kwargs
)
cls
.
_instance
=
super
().
__new__
(
cls
)
_has_instance
=
True
_has_instance
=
True
return
cls
.
_instance
return
cls
.
_instance
@
property
def
__init__
(
self
,
cluster
):
def
alpha_inter
(
self
):
if
CommContext
.
_has_instance
:
if
self
.
_alpha_inter
is
None
:
return
if
cluster
.
alpha
.
inter
==
"NVL"
:
self
.
beta
=
{}
self
.
_alpha_inter
=
3.4
self
.
hops
=
{}
elif
cluster
.
alpha
.
inter
==
"PHB"
:
self
.
cluster
=
cluster
self
.
_alpha_inter
=
5.7
# if cluster has no info about those vars, it will be set by default
return
self
.
_alpha_inter
self
.
base_ring
=
None
self
.
base_tree
=
None
@
property
# self.base_inter_ring = None
def
alpha_intra
(
self
):
# self.base_inter_tree = None
if
self
.
_alpha_intra
is
None
:
self
.
intra_ring
=
None
if
cluster
.
alpha
.
intra
==
"NVL"
:
self
.
intra_tree
=
None
self
.
_alpha_intra
=
28
self
.
inter_ring
=
None
elif
cluster
.
alpha
.
intra
==
"PHB"
:
self
.
inter_tree
=
None
self
.
_alpha_intra
=
28
self
.
switch
=
None
return
self
.
_alpha_intra
self
.
_post_init
()
@
property
def
_post_init
(
self
):
def
alpha_base_ring
(
self
):
alpha_latency
=
self
.
cluster
.
alpha_latency
return
self
.
_alpha_base_ring
if
alpha_latency
is
None
:
# set default
@
property
self
.
base_ring
=
8.4
def
alpha_base_tree
(
self
):
self
.
base_tree
=
0.
return
self
.
_alpha_base_tree
# NVL in default
self
.
intra_ring
=
3.4
def
get_beta
(
self
,
ranks
):
self
.
intra_tree
=
28
# NET in default
self
.
inter_ring
=
9.6
self
.
inter_tree
=
28
self
.
switch
=
10.0
else
:
base_ring
=
alpha_latency
.
base_ring
self
.
base_ring
=
base_ring
if
base_ring
is
not
None
else
8.4
base_tree
=
alpha_latency
.
base_tree
self
.
base_tree
=
base_tree
if
base_tree
is
not
None
else
0.
intra_ring
=
alpha_latency
.
intra_ring
if
intra_ring
==
LinkType
.
NVL
:
self
.
intra_ring
=
3.4
elif
intra_ring
==
LinkType
.
PHB
:
self
.
intra_ring
=
5.7
elif
intra_ring
is
not
None
:
self
.
intra_ring
=
intra_ring
else
:
# NVL Default
self
.
intra_ring
=
3.4
intra_tree
=
alpha_latency
.
intra_tree
if
intra_tree
==
LinkType
.
NVL
:
self
.
intra_tree
=
28
elif
intra_tree
==
LinkType
.
PHB
:
self
.
intra_tree
=
28
elif
intra_tree
is
not
None
:
self
.
intra_tree
=
intra_tree
else
:
# NVL Default
self
.
intra_tree
=
28
inter_ring
=
alpha_latency
.
inter_ring
if
inter_ring
==
LinkType
.
NET
:
self
.
inter_ring
=
9.6
elif
inter_ring
is
not
None
:
self
.
inter_ring
=
inter_ring
else
:
# NET Default
self
.
inter_ring
=
9.6
inter_tree
=
alpha_latency
.
inter_tree
if
inter_tree
==
LinkType
.
NET
:
self
.
inter_tree
=
28
elif
inter_tree
is
not
None
:
self
.
inter_tree
=
inter_tree
else
:
# NET Default
self
.
inter_tree
=
28
switch
=
alpha_latency
.
switch
self
.
switch
=
switch
if
switch
is
not
None
else
10
assert
self
.
base_ring
is
not
None
assert
self
.
base_tree
is
not
None
assert
self
.
intra_ring
is
not
None
assert
self
.
intra_tree
is
not
None
assert
self
.
inter_ring
is
not
None
assert
self
.
inter_tree
is
not
None
assert
self
.
switch
is
not
None
def
get_max_beta
(
self
,
ranks
):
# NOTE: Get beta by ring, even in the case of tree such as tree broadcast
ranks
=
self
.
cluster
.
convert_rank_to_device_id
(
ranks
)
key
=
','
.
join
(
map
(
str
,
sorted
(
ranks
)))
key
=
','
.
join
(
map
(
str
,
sorted
(
ranks
)))
max_beta
=
None
max_beta
=
None
if
key
in
self
.
_beta
.
keys
:
if
key
in
self
.
beta
:
max_beta
=
self
.
_
beta
[
key
]
max_beta
=
self
.
beta
[
key
]
else
:
else
:
for
i
in
range
(
len
(
ranks
)):
for
i
in
range
(
len
(
ranks
)):
for
j
in
range
(
i
+
1
,
len
(
ranks
)):
for
j
in
range
(
i
+
1
,
len
(
ranks
)):
if
min_beta
==
None
:
forward_order_beta
=
self
.
cluster
.
get_beta
(
ranks
[
i
],
min_beta
=
cluster
.
get_beta
(
ranks
[
i
],
ranks
[
j
])
ranks
[
j
])
backward_order_beta
=
self
.
cluster
.
get_beta
(
ranks
[
j
],
ranks
[
i
])
beta
=
forward_order_beta
if
forward_order_beta
>
backward_order_beta
else
backward_order_beta
if
max_beta
==
None
:
max_beta
=
beta
else
:
else
:
beta
=
cluster
.
get_beta
(
ranks
[
i
],
ranks
[
j
])
if
beta
>
max_beta
:
if
beta
>
max_beta
:
max_beta
=
beta
max_beta
=
beta
self
.
_
beta
[
key
]
=
max_beta
self
.
beta
[
key
]
=
max_beta
return
max_beta
return
max_beta
def
get_hops
(
self
,
ranks
):
key
=
','
.
join
(
map
(
str
,
sorted
(
ranks
)))
hops
=
0
for
i
in
range
(
len
(
ranks
)):
for
j
in
range
(
i
+
1
,
len
(
ranks
)):
hop
=
self
.
cluster
.
get_hop
(
ranks
[
i
],
ranks
[
j
])
hops
+=
hop
self
.
hops
[
key
]
=
hops
return
hops
class
Cost
:
class
Cost
:
def
__init__
(
self
,
time
=
0
,
memory
=
0
,
flops
=
0
):
def
__init__
(
self
,
time
=
0
,
memory
=
0
,
flops
=
0
):
...
@@ -198,11 +268,13 @@ class Cost:
...
@@ -198,11 +268,13 @@ class Cost:
def
_check_memory
(
self
,
val
):
def
_check_memory
(
self
,
val
):
assert
isinstance
(
assert
isinstance
(
val
,
int
)
and
val
>=
0
,
"Memory must be int and greater than 0."
val
,
int
)
and
val
>=
0
,
"Memory must be int and greater than equal to 0."
def
_check_flops
(
self
,
val
):
def
_check_flops
(
self
,
val
):
assert
isinstance
(
assert
isinstance
(
val
,
int
)
and
val
>=
0
,
"FLOPs must be int and greater than 0."
val
,
int
)
and
val
>=
0
,
"FLOPs must be int and greater than equal to 0."
@
property
@
property
def
time
(
self
):
def
time
(
self
):
...
@@ -254,7 +326,7 @@ class OpCost:
...
@@ -254,7 +326,7 @@ class OpCost:
op_desc
is
not
None
)
op_desc
is
not
None
)
self
.
_op
=
op
self
.
_op
=
op
self
.
_op_desc
=
op_desc
self
.
_op_desc
=
op_desc
self
.
_cost
=
self
.
calc_cost
()
self
.
_cost
=
None
@
property
@
property
def
op
(
self
):
def
op
(
self
):
...
@@ -264,6 +336,18 @@ class OpCost:
...
@@ -264,6 +336,18 @@ class OpCost:
def
op_desc
(
self
):
def
op_desc
(
self
):
return
self
.
_op_desc
return
self
.
_op_desc
@
property
def
time
(
self
):
return
self
.
cost
.
time
@
property
def
memory
(
self
):
return
self
.
cost
.
memory
@
property
def
flops
(
self
):
return
self
.
cost
.
flops
@
property
@
property
def
cost
(
self
):
def
cost
(
self
):
return
self
.
_cost
return
self
.
_cost
...
@@ -284,6 +368,40 @@ class OpCost:
...
@@ -284,6 +368,40 @@ class OpCost:
cost
=
Cost
(
time
,
memory
,
flops
)
cost
=
Cost
(
time
,
memory
,
flops
)
return
cost
return
cost
def
__add__
(
self
,
rhs
):
assert
isinstance
(
rhs
,
(
OpCost
,
Cost
))
time
=
0
memory
=
0
flops
=
0
if
isinstance
(
rhs
,
OpCost
):
time
=
self
.
cost
.
time
+
rhs
.
cost
.
time
memory
=
self
.
cost
.
memory
+
rhs
.
cost
.
memory
flops
=
self
.
cost
.
flops
+
rhs
.
cost
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
elif
isinstance
(
rhs
,
Cost
):
time
=
self
.
time
+
rhs
.
time
memory
=
self
.
memory
+
rhs
.
memory
flops
=
self
.
flops
+
rhs
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
return
Cost
(
time
,
memory
,
flops
)
def
__sub__
(
self
,
rhs
):
assert
isinstance
(
rhs
,
(
OpCost
,
Cost
))
time
=
0
memory
=
0
flops
=
0
if
isinstance
(
rhs
,
OpCost
):
time
=
self
.
cost
.
time
-
rhs
.
cost
.
time
memory
=
self
.
cost
.
memory
-
rhs
.
cost
.
memory
flops
=
self
.
cost
.
flops
-
rhs
.
cost
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
elif
isinstance
(
rhs
,
Cost
):
time
=
self
.
time
-
rhs
.
time
memory
=
self
.
memory
-
rhs
.
memory
flops
=
self
.
flops
-
rhs
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
return
Cost
(
time
,
memory
,
flops
)
class
CommOpCost
(
OpCost
):
class
CommOpCost
(
OpCost
):
OP_TYPE
=
"COMM"
OP_TYPE
=
"COMM"
...
@@ -292,11 +410,83 @@ class CommOpCost(OpCost):
...
@@ -292,11 +410,83 @@ class CommOpCost(OpCost):
super
(
CommOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
super
(
CommOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
self
.
_check_comm_op_type
()
self
.
_check_comm_op_type
()
self
.
_comm_context
=
comm_context
self
.
_comm_context
=
comm_context
self
.
_group_ranks
=
None
self
.
_comm_count
=
None
self
.
_hops
=
None
self
.
_rank_count
=
len
(
self
.
group_ranks
)
self
.
_machine_count
=
None
self
.
_cost
=
self
.
calc_cost
()
@
property
@
property
def
comm_context
(
self
):
def
comm_context
(
self
):
return
self
.
_comm_context
return
self
.
_comm_context
@
property
def
comm_count
(
self
):
if
self
.
_comm_count
is
None
:
dtype
=
None
shape
=
None
if
self
.
op
is
not
None
:
vars
=
self
.
op
.
block
.
vars
# NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overrided
var_name
=
self
.
op
.
input
(
"X"
)[
0
]
var
=
vars
[
var_name
]
dtype
=
var
.
dtype
shape
=
var
.
shape
elif
self
.
op_desc
is
not
None
:
dtype
=
self
.
op_desc
[
"inputs"
][
"X"
][
0
][
0
]
shape
=
self
.
op_desc
[
"inputs"
][
"X"
][
0
][
1
]
factor
=
None
if
dtype
==
paddle
.
float32
or
dtype
==
paddle
.
int32
:
factor
=
4
elif
dtype
==
paddle
.
int64
:
factor
=
8
elif
dtype
==
paddle
.
uint8
:
factor
=
1
elif
dtype
==
paddle
.
float16
:
factor
=
2
else
:
raise
TypeError
(
"This dtype {} is not supported now"
.
format
(
dtype
))
comm_count
=
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
*
factor
self
.
_comm_count
=
comm_count
return
self
.
_comm_count
@
property
def
rank_count
(
self
):
return
self
.
_rank_count
@
property
def
machine_count
(
self
):
if
self
.
_machine_count
is
None
:
cluster
=
self
.
_comm_context
.
cluster
self
.
_machine_count
=
cluster
.
get_involved_machine_count
(
self
.
group_ranks
)
return
self
.
_machine_count
@
property
def
hops
(
self
):
if
self
.
_hops
is
None
:
self
.
_hops
=
self
.
comm_context
.
get_hops
(
self
.
group_ranks
)
return
self
.
_hops
@
property
def
group_ranks
(
self
):
if
self
.
_group_ranks
is
None
:
if
self
.
op_desc
is
not
None
:
self
.
_group_ranks
=
self
.
op_desc
[
"group_ranks"
]
elif
self
.
op
is
not
None
:
ring_id
=
op
.
attrs
(
"ring_id"
)
process_group
=
get_process_group
(
ring_id
)
if
process_group
is
None
:
raise
ValueError
(
"There not exists process group whose ring_id is {}."
.
format
(
ring_id
))
self
.
_group_ranks
=
process_group
.
ranks
return
self
.
_group_ranks
@
classmethod
@
classmethod
def
_check_comm_op_type
(
cls
):
def
_check_comm_op_type
(
cls
):
if
cls
.
OP_TYPE
!=
"COMM"
:
if
cls
.
OP_TYPE
!=
"COMM"
:
...
@@ -311,6 +501,7 @@ class CompOpCost(OpCost):
...
@@ -311,6 +501,7 @@ class CompOpCost(OpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
(
CompOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
super
(
CompOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
self
.
_check_comp_op_type
()
self
.
_check_comp_op_type
()
self
.
_cost
=
self
.
calc_cost
()
self
.
cluster
=
cluster
self
.
cluster
=
cluster
@
classmethod
@
classmethod
...
@@ -325,18 +516,22 @@ def register_op_cost(cls):
...
@@ -325,18 +516,22 @@ def register_op_cost(cls):
op_type
=
cls
.
OP_TYPE
op_type
=
cls
.
OP_TYPE
def
register
(
op_type
):
def
register
(
op_type
):
OP_COST_FACTORY
[
op_type
]
=
cls
global
_g_op_cost_factory
_g_op_cost_factory
[
op_type
]
=
cls
return
register
(
op_type
)
register
(
op_type
)
return
cls
def
calc_time_
from_model
(
op
=
None
,
desc
=
None
,
cluster
=
None
,
comm_context
=
None
):
def
calc_time_
by_modeling
(
op
=
None
,
desc
=
None
,
cluster
=
None
):
op_type
=
op
.
type
if
op
is
not
None
else
desc
[
"op"
]
op_type
=
op
.
type
if
op
is
not
None
else
desc
[
"op"
]
if
op_type
in
COMM_OP_TYPE
:
if
op_type
in
COMM_OP_TYPE
:
op_cost
=
OP_COST_FACTORY
[
op_type
](
op
=
op
,
op_cost
=
_g_op_cost_factory
[
op_type
](
op
=
op
,
op_desc
=
desc
,
op_desc
=
desc
,
comm_context
=
comm_context
)
comm_context
=
CommContext
(
cluster
)
)
elif
op_type
not
in
NON_COMP_TYPE
:
elif
op_type
not
in
NON_COMP_TYPE
:
op_cost
=
OP_COST_FACTORY
[
op_type
](
op
=
op
,
op_desc
=
desc
,
cluster
=
cluster
)
op_cost
=
_g_op_cost_factory
[
op_type
](
op
=
op
,
op_desc
=
desc
,
cluster
=
cluster
)
time
=
op_cost
.
calc_time
()
time
=
op_cost
.
calc_time
()
return
time
return
time
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
浏览文件 @
6ac08db5
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License
# limitations under the License
from
.base_cost
import
register_op_cost
,
CommOpCost
,
OP_COST_FACTORY
from
.base_cost
import
register_op_cost
,
CommOpCost
@
register_op_cost
@
register_op_cost
...
@@ -20,7 +20,7 @@ class AllreduceSumCost(CommOpCost):
...
@@ -20,7 +20,7 @@ class AllreduceSumCost(CommOpCost):
OP_TYPE
=
"c_allreduce_sum"
OP_TYPE
=
"c_allreduce_sum"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
OP_COST_FACTORY
[
"c_allreduce_sum"
]
,
self
).
__init__
(
super
(
AllreduceSumCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
def
calc_time
(
self
):
...
...
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
浏览文件 @
6ac08db5
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License
# limitations under the License
from
.base_cost
import
Cost
,
register_op_cost
,
CompOpCost
,
OP_COST_FACTORY
from
.base_cost
import
Cost
,
register_op_cost
,
CompOpCost
@
register_op_cost
@
register_op_cost
...
@@ -20,7 +20,7 @@ class MatmulV2OpCost(CompOpCost):
...
@@ -20,7 +20,7 @@ class MatmulV2OpCost(CompOpCost):
OP_TYPE
=
"matmul_v2"
OP_TYPE
=
"matmul_v2"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
(
OP_COST_FACTORY
[
"matmul_v2"
]
,
self
).
__init__
(
super
(
MatmulV2OpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function needs to be overrided
# For a concrete COMP OP, the calc_time and calc_flops function needs to be overrided
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
浏览文件 @
6ac08db5
...
@@ -18,7 +18,7 @@ import paddle
...
@@ -18,7 +18,7 @@ import paddle
import
paddle.distributed.auto_parallel.cost
as
cost_model
import
paddle.distributed.auto_parallel.cost
as
cost_model
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_
from_model
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_
by_modeling
paddle
.
enable_static
()
paddle
.
enable_static
()
...
@@ -45,13 +45,13 @@ class TestCost(unittest.TestCase):
...
@@ -45,13 +45,13 @@ class TestCost(unittest.TestCase):
if
op
.
type
==
"matmul_v2"
:
if
op
.
type
==
"matmul_v2"
:
matmul_v2_op
=
op
matmul_v2_op
=
op
break
break
matmul_v2_cost
=
cost_model
.
OP_COST_FACTORY
[
"matmul_v2"
](
matmul_v2_cost
=
cost_model
.
_g_op_cost_factory
[
"matmul_v2"
](
op
=
matmul_v2_op
)
op
=
matmul_v2_op
)
desc
=
parse_to_desc
(
op
=
matmul_v2_op
)
desc
=
parse_to_desc
(
op
=
matmul_v2_op
)
desc_str
=
parse_desc_to_str
(
desc
)
desc_str
=
parse_desc_to_str
(
desc
)
self
.
assertIsNotNone
(
desc_str
)
self
.
assertIsNotNone
(
desc_str
)
self
.
assertTrue
(
check_cost
(
matmul_v2_cost
.
cost
))
self
.
assertTrue
(
check_cost
(
matmul_v2_cost
.
cost
))
time
=
calc_time_
from_model
(
op
=
matmul_v2_op
)
time
=
calc_time_
by_modeling
(
op
=
matmul_v2_op
)
self
.
assertEqual
(
time
,
matmul_v2_cost
.
cost
.
time
)
self
.
assertEqual
(
time
,
matmul_v2_cost
.
cost
.
time
)
tensor_cost
=
cost_model
.
TensorCost
(
tensor
=
x
)
tensor_cost
=
cost_model
.
TensorCost
(
tensor
=
x
)
# check memory
# check memory
...
@@ -61,7 +61,8 @@ class TestCost(unittest.TestCase):
...
@@ -61,7 +61,8 @@ class TestCost(unittest.TestCase):
desc
=
{}
desc
=
{}
desc
[
"op"
]
=
"c_allreduce_sum"
desc
[
"op"
]
=
"c_allreduce_sum"
desc
[
"inputs"
]
=
{
"X"
:
[([
100
,
200
],
paddle
.
float32
)]}
desc
[
"inputs"
]
=
{
"X"
:
[([
100
,
200
],
paddle
.
float32
)]}
allreduce_cost
=
cost_model
.
OP_COST_FACTORY
[
"c_allreduce_sum"
](
desc
[
"group_ranks"
]
=
[
0
,
1
]
allreduce_cost
=
cost_model
.
_g_op_cost_factory
[
"c_allreduce_sum"
](
op_desc
=
desc
)
op_desc
=
desc
)
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录