Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6ac08db5
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6ac08db5
编写于
5月 10, 2022
作者:
C
caozhou
提交者:
GitHub
5月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update base of cost model (#42601)
上级
cc077693
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
262 addition
and
66 deletion
+262
-66
python/paddle/distributed/auto_parallel/cost/__init__.py
python/paddle/distributed/auto_parallel/cost/__init__.py
+1
-1
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+250
-55
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+3
-3
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+3
-3
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
...luid/tests/unittests/auto_parallel/test_new_cost_model.py
+5
-4
未找到文件。
python/paddle/distributed/auto_parallel/cost/__init__.py
浏览文件 @
6ac08db5
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License
# limitations under the License
from
.base_cost
import
OP_COST_FACTORY
from
.base_cost
import
_g_op_cost_factory
from
.base_cost
import
Cost
from
.base_cost
import
Cost
from
.comm_op_cost
import
AllreduceSumCost
from
.comm_op_cost
import
AllreduceSumCost
from
.comp_op_cost
import
MatmulV2OpCost
from
.comp_op_cost
import
MatmulV2OpCost
...
...
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
6ac08db5
...
@@ -19,7 +19,7 @@ COMM_OP_TYPE = [
...
@@ -19,7 +19,7 @@ COMM_OP_TYPE = [
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
]
]
NON_COMP_TYPE
=
[
"while"
]
+
COMM_OP_TYPE
NON_COMP_TYPE
=
[
"while"
]
+
COMM_OP_TYPE
OP_COST_FACTORY
=
{}
_g_op_cost_factory
=
{}
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
...
@@ -126,66 +126,136 @@ class CommContext:
...
@@ -126,66 +126,136 @@ class CommContext:
_instance
=
None
_instance
=
None
_has_instance
=
False
_has_instance
=
False
def
__init__
(
self
,
cluster
):
if
CommContext
.
_has_instance
:
return
self
.
cluster
=
cluster
self
.
_alpha_base_ring
=
8.4
self
.
_alpha_base_tree
=
0
self
.
_alpha_inter
=
None
self
.
_alpha_intra
self
.
_beta
=
{}
def
__new__
(
cls
,
*
args
,
**
kwargs
):
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
().
__new__
(
cls
,
*
args
,
**
kwargs
)
cls
.
_instance
=
super
().
__new__
(
cls
)
_has_instance
=
True
_has_instance
=
True
return
cls
.
_instance
return
cls
.
_instance
@
property
def
__init__
(
self
,
cluster
):
def
alpha_inter
(
self
):
if
CommContext
.
_has_instance
:
if
self
.
_alpha_inter
is
None
:
return
if
cluster
.
alpha
.
inter
==
"NVL"
:
self
.
beta
=
{}
self
.
_alpha_inter
=
3.4
self
.
hops
=
{}
elif
cluster
.
alpha
.
inter
==
"PHB"
:
self
.
cluster
=
cluster
self
.
_alpha_inter
=
5.7
# if cluster has no info about those vars, it will be set by default
return
self
.
_alpha_inter
self
.
base_ring
=
None
self
.
base_tree
=
None
@
property
# self.base_inter_ring = None
def
alpha_intra
(
self
):
# self.base_inter_tree = None
if
self
.
_alpha_intra
is
None
:
self
.
intra_ring
=
None
if
cluster
.
alpha
.
intra
==
"NVL"
:
self
.
intra_tree
=
None
self
.
_alpha_intra
=
28
self
.
inter_ring
=
None
elif
cluster
.
alpha
.
intra
==
"PHB"
:
self
.
inter_tree
=
None
self
.
_alpha_intra
=
28
self
.
switch
=
None
return
self
.
_alpha_intra
self
.
_post_init
()
@
property
def
_post_init
(
self
):
def
alpha_base_ring
(
self
):
alpha_latency
=
self
.
cluster
.
alpha_latency
return
self
.
_alpha_base_ring
if
alpha_latency
is
None
:
# set default
@
property
self
.
base_ring
=
8.4
def
alpha_base_tree
(
self
):
self
.
base_tree
=
0.
return
self
.
_alpha_base_tree
# NVL in default
self
.
intra_ring
=
3.4
def
get_beta
(
self
,
ranks
):
self
.
intra_tree
=
28
# NET in default
self
.
inter_ring
=
9.6
self
.
inter_tree
=
28
self
.
switch
=
10.0
else
:
base_ring
=
alpha_latency
.
base_ring
self
.
base_ring
=
base_ring
if
base_ring
is
not
None
else
8.4
base_tree
=
alpha_latency
.
base_tree
self
.
base_tree
=
base_tree
if
base_tree
is
not
None
else
0.
intra_ring
=
alpha_latency
.
intra_ring
if
intra_ring
==
LinkType
.
NVL
:
self
.
intra_ring
=
3.4
elif
intra_ring
==
LinkType
.
PHB
:
self
.
intra_ring
=
5.7
elif
intra_ring
is
not
None
:
self
.
intra_ring
=
intra_ring
else
:
# NVL Default
self
.
intra_ring
=
3.4
intra_tree
=
alpha_latency
.
intra_tree
if
intra_tree
==
LinkType
.
NVL
:
self
.
intra_tree
=
28
elif
intra_tree
==
LinkType
.
PHB
:
self
.
intra_tree
=
28
elif
intra_tree
is
not
None
:
self
.
intra_tree
=
intra_tree
else
:
# NVL Default
self
.
intra_tree
=
28
inter_ring
=
alpha_latency
.
inter_ring
if
inter_ring
==
LinkType
.
NET
:
self
.
inter_ring
=
9.6
elif
inter_ring
is
not
None
:
self
.
inter_ring
=
inter_ring
else
:
# NET Default
self
.
inter_ring
=
9.6
inter_tree
=
alpha_latency
.
inter_tree
if
inter_tree
==
LinkType
.
NET
:
self
.
inter_tree
=
28
elif
inter_tree
is
not
None
:
self
.
inter_tree
=
inter_tree
else
:
# NET Default
self
.
inter_tree
=
28
switch
=
alpha_latency
.
switch
self
.
switch
=
switch
if
switch
is
not
None
else
10
assert
self
.
base_ring
is
not
None
assert
self
.
base_tree
is
not
None
assert
self
.
intra_ring
is
not
None
assert
self
.
intra_tree
is
not
None
assert
self
.
inter_ring
is
not
None
assert
self
.
inter_tree
is
not
None
assert
self
.
switch
is
not
None
def
get_max_beta
(
self
,
ranks
):
# NOTE: Get beta by ring, even in the case of tree such as tree broadcast
ranks
=
self
.
cluster
.
convert_rank_to_device_id
(
ranks
)
key
=
','
.
join
(
map
(
str
,
sorted
(
ranks
)))
key
=
','
.
join
(
map
(
str
,
sorted
(
ranks
)))
max_beta
=
None
max_beta
=
None
if
key
in
self
.
_beta
.
keys
:
if
key
in
self
.
beta
:
max_beta
=
self
.
_
beta
[
key
]
max_beta
=
self
.
beta
[
key
]
else
:
else
:
for
i
in
range
(
len
(
ranks
)):
for
i
in
range
(
len
(
ranks
)):
for
j
in
range
(
i
+
1
,
len
(
ranks
)):
for
j
in
range
(
i
+
1
,
len
(
ranks
)):
if
min_beta
==
None
:
forward_order_beta
=
self
.
cluster
.
get_beta
(
ranks
[
i
],
min_beta
=
cluster
.
get_beta
(
ranks
[
i
],
ranks
[
j
])
ranks
[
j
])
backward_order_beta
=
self
.
cluster
.
get_beta
(
ranks
[
j
],
ranks
[
i
])
beta
=
forward_order_beta
if
forward_order_beta
>
backward_order_beta
else
backward_order_beta
if
max_beta
==
None
:
max_beta
=
beta
else
:
else
:
beta
=
cluster
.
get_beta
(
ranks
[
i
],
ranks
[
j
])
if
beta
>
max_beta
:
if
beta
>
max_beta
:
max_beta
=
beta
max_beta
=
beta
self
.
_
beta
[
key
]
=
max_beta
self
.
beta
[
key
]
=
max_beta
return
max_beta
return
max_beta
def
get_hops
(
self
,
ranks
):
key
=
','
.
join
(
map
(
str
,
sorted
(
ranks
)))
hops
=
0
for
i
in
range
(
len
(
ranks
)):
for
j
in
range
(
i
+
1
,
len
(
ranks
)):
hop
=
self
.
cluster
.
get_hop
(
ranks
[
i
],
ranks
[
j
])
hops
+=
hop
self
.
hops
[
key
]
=
hops
return
hops
class
Cost
:
class
Cost
:
def
__init__
(
self
,
time
=
0
,
memory
=
0
,
flops
=
0
):
def
__init__
(
self
,
time
=
0
,
memory
=
0
,
flops
=
0
):
...
@@ -198,11 +268,13 @@ class Cost:
...
@@ -198,11 +268,13 @@ class Cost:
def
_check_memory
(
self
,
val
):
def
_check_memory
(
self
,
val
):
assert
isinstance
(
assert
isinstance
(
val
,
int
)
and
val
>=
0
,
"Memory must be int and greater than 0."
val
,
int
)
and
val
>=
0
,
"Memory must be int and greater than equal to 0."
def
_check_flops
(
self
,
val
):
def
_check_flops
(
self
,
val
):
assert
isinstance
(
assert
isinstance
(
val
,
int
)
and
val
>=
0
,
"FLOPs must be int and greater than 0."
val
,
int
)
and
val
>=
0
,
"FLOPs must be int and greater than equal to 0."
@
property
@
property
def
time
(
self
):
def
time
(
self
):
...
@@ -254,7 +326,7 @@ class OpCost:
...
@@ -254,7 +326,7 @@ class OpCost:
op_desc
is
not
None
)
op_desc
is
not
None
)
self
.
_op
=
op
self
.
_op
=
op
self
.
_op_desc
=
op_desc
self
.
_op_desc
=
op_desc
self
.
_cost
=
self
.
calc_cost
()
self
.
_cost
=
None
@
property
@
property
def
op
(
self
):
def
op
(
self
):
...
@@ -264,6 +336,18 @@ class OpCost:
...
@@ -264,6 +336,18 @@ class OpCost:
def
op_desc
(
self
):
def
op_desc
(
self
):
return
self
.
_op_desc
return
self
.
_op_desc
@
property
def
time
(
self
):
return
self
.
cost
.
time
@
property
def
memory
(
self
):
return
self
.
cost
.
memory
@
property
def
flops
(
self
):
return
self
.
cost
.
flops
@
property
@
property
def
cost
(
self
):
def
cost
(
self
):
return
self
.
_cost
return
self
.
_cost
...
@@ -284,6 +368,40 @@ class OpCost:
...
@@ -284,6 +368,40 @@ class OpCost:
cost
=
Cost
(
time
,
memory
,
flops
)
cost
=
Cost
(
time
,
memory
,
flops
)
return
cost
return
cost
def
__add__
(
self
,
rhs
):
assert
isinstance
(
rhs
,
(
OpCost
,
Cost
))
time
=
0
memory
=
0
flops
=
0
if
isinstance
(
rhs
,
OpCost
):
time
=
self
.
cost
.
time
+
rhs
.
cost
.
time
memory
=
self
.
cost
.
memory
+
rhs
.
cost
.
memory
flops
=
self
.
cost
.
flops
+
rhs
.
cost
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
elif
isinstance
(
rhs
,
Cost
):
time
=
self
.
time
+
rhs
.
time
memory
=
self
.
memory
+
rhs
.
memory
flops
=
self
.
flops
+
rhs
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
return
Cost
(
time
,
memory
,
flops
)
def
__sub__
(
self
,
rhs
):
assert
isinstance
(
rhs
,
(
OpCost
,
Cost
))
time
=
0
memory
=
0
flops
=
0
if
isinstance
(
rhs
,
OpCost
):
time
=
self
.
cost
.
time
-
rhs
.
cost
.
time
memory
=
self
.
cost
.
memory
-
rhs
.
cost
.
memory
flops
=
self
.
cost
.
flops
-
rhs
.
cost
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
elif
isinstance
(
rhs
,
Cost
):
time
=
self
.
time
-
rhs
.
time
memory
=
self
.
memory
-
rhs
.
memory
flops
=
self
.
flops
-
rhs
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
return
Cost
(
time
,
memory
,
flops
)
class
CommOpCost
(
OpCost
):
class
CommOpCost
(
OpCost
):
OP_TYPE
=
"COMM"
OP_TYPE
=
"COMM"
...
@@ -292,11 +410,83 @@ class CommOpCost(OpCost):
...
@@ -292,11 +410,83 @@ class CommOpCost(OpCost):
super
(
CommOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
super
(
CommOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
self
.
_check_comm_op_type
()
self
.
_check_comm_op_type
()
self
.
_comm_context
=
comm_context
self
.
_comm_context
=
comm_context
self
.
_group_ranks
=
None
self
.
_comm_count
=
None
self
.
_hops
=
None
self
.
_rank_count
=
len
(
self
.
group_ranks
)
self
.
_machine_count
=
None
self
.
_cost
=
self
.
calc_cost
()
@
property
@
property
def
comm_context
(
self
):
def
comm_context
(
self
):
return
self
.
_comm_context
return
self
.
_comm_context
@
property
def
comm_count
(
self
):
if
self
.
_comm_count
is
None
:
dtype
=
None
shape
=
None
if
self
.
op
is
not
None
:
vars
=
self
.
op
.
block
.
vars
# NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overrided
var_name
=
self
.
op
.
input
(
"X"
)[
0
]
var
=
vars
[
var_name
]
dtype
=
var
.
dtype
shape
=
var
.
shape
elif
self
.
op_desc
is
not
None
:
dtype
=
self
.
op_desc
[
"inputs"
][
"X"
][
0
][
0
]
shape
=
self
.
op_desc
[
"inputs"
][
"X"
][
0
][
1
]
factor
=
None
if
dtype
==
paddle
.
float32
or
dtype
==
paddle
.
int32
:
factor
=
4
elif
dtype
==
paddle
.
int64
:
factor
=
8
elif
dtype
==
paddle
.
uint8
:
factor
=
1
elif
dtype
==
paddle
.
float16
:
factor
=
2
else
:
raise
TypeError
(
"This dtype {} is not supported now"
.
format
(
dtype
))
comm_count
=
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
*
factor
self
.
_comm_count
=
comm_count
return
self
.
_comm_count
@
property
def
rank_count
(
self
):
return
self
.
_rank_count
@
property
def
machine_count
(
self
):
if
self
.
_machine_count
is
None
:
cluster
=
self
.
_comm_context
.
cluster
self
.
_machine_count
=
cluster
.
get_involved_machine_count
(
self
.
group_ranks
)
return
self
.
_machine_count
@
property
def
hops
(
self
):
if
self
.
_hops
is
None
:
self
.
_hops
=
self
.
comm_context
.
get_hops
(
self
.
group_ranks
)
return
self
.
_hops
@
property
def
group_ranks
(
self
):
if
self
.
_group_ranks
is
None
:
if
self
.
op_desc
is
not
None
:
self
.
_group_ranks
=
self
.
op_desc
[
"group_ranks"
]
elif
self
.
op
is
not
None
:
ring_id
=
op
.
attrs
(
"ring_id"
)
process_group
=
get_process_group
(
ring_id
)
if
process_group
is
None
:
raise
ValueError
(
"There not exists process group whose ring_id is {}."
.
format
(
ring_id
))
self
.
_group_ranks
=
process_group
.
ranks
return
self
.
_group_ranks
@
classmethod
@
classmethod
def
_check_comm_op_type
(
cls
):
def
_check_comm_op_type
(
cls
):
if
cls
.
OP_TYPE
!=
"COMM"
:
if
cls
.
OP_TYPE
!=
"COMM"
:
...
@@ -311,6 +501,7 @@ class CompOpCost(OpCost):
...
@@ -311,6 +501,7 @@ class CompOpCost(OpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
(
CompOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
super
(
CompOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
self
.
_check_comp_op_type
()
self
.
_check_comp_op_type
()
self
.
_cost
=
self
.
calc_cost
()
self
.
cluster
=
cluster
self
.
cluster
=
cluster
@
classmethod
@
classmethod
...
@@ -325,18 +516,22 @@ def register_op_cost(cls):
...
@@ -325,18 +516,22 @@ def register_op_cost(cls):
op_type
=
cls
.
OP_TYPE
op_type
=
cls
.
OP_TYPE
def
register
(
op_type
):
def
register
(
op_type
):
OP_COST_FACTORY
[
op_type
]
=
cls
global
_g_op_cost_factory
_g_op_cost_factory
[
op_type
]
=
cls
return
register
(
op_type
)
register
(
op_type
)
return
cls
def
calc_time_
from_model
(
op
=
None
,
desc
=
None
,
cluster
=
None
,
comm_context
=
None
):
def
calc_time_
by_modeling
(
op
=
None
,
desc
=
None
,
cluster
=
None
):
op_type
=
op
.
type
if
op
is
not
None
else
desc
[
"op"
]
op_type
=
op
.
type
if
op
is
not
None
else
desc
[
"op"
]
if
op_type
in
COMM_OP_TYPE
:
if
op_type
in
COMM_OP_TYPE
:
op_cost
=
OP_COST_FACTORY
[
op_type
](
op
=
op
,
op_cost
=
_g_op_cost_factory
[
op_type
](
op
=
op
,
op_desc
=
desc
,
op_desc
=
desc
,
comm_context
=
comm_context
)
comm_context
=
CommContext
(
cluster
)
)
elif
op_type
not
in
NON_COMP_TYPE
:
elif
op_type
not
in
NON_COMP_TYPE
:
op_cost
=
OP_COST_FACTORY
[
op_type
](
op
=
op
,
op_desc
=
desc
,
cluster
=
cluster
)
op_cost
=
_g_op_cost_factory
[
op_type
](
op
=
op
,
op_desc
=
desc
,
cluster
=
cluster
)
time
=
op_cost
.
calc_time
()
time
=
op_cost
.
calc_time
()
return
time
return
time
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
浏览文件 @
6ac08db5
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License
# limitations under the License
from
.base_cost
import
register_op_cost
,
CommOpCost
,
OP_COST_FACTORY
from
.base_cost
import
register_op_cost
,
CommOpCost
@
register_op_cost
@
register_op_cost
...
@@ -20,7 +20,7 @@ class AllreduceSumCost(CommOpCost):
...
@@ -20,7 +20,7 @@ class AllreduceSumCost(CommOpCost):
OP_TYPE
=
"c_allreduce_sum"
OP_TYPE
=
"c_allreduce_sum"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
OP_COST_FACTORY
[
"c_allreduce_sum"
]
,
self
).
__init__
(
super
(
AllreduceSumCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
def
calc_time
(
self
):
...
...
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
浏览文件 @
6ac08db5
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License
# limitations under the License
from
.base_cost
import
Cost
,
register_op_cost
,
CompOpCost
,
OP_COST_FACTORY
from
.base_cost
import
Cost
,
register_op_cost
,
CompOpCost
@
register_op_cost
@
register_op_cost
...
@@ -20,7 +20,7 @@ class MatmulV2OpCost(CompOpCost):
...
@@ -20,7 +20,7 @@ class MatmulV2OpCost(CompOpCost):
OP_TYPE
=
"matmul_v2"
OP_TYPE
=
"matmul_v2"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
(
OP_COST_FACTORY
[
"matmul_v2"
]
,
self
).
__init__
(
super
(
MatmulV2OpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function needs to be overrided
# For a concrete COMP OP, the calc_time and calc_flops function needs to be overrided
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
浏览文件 @
6ac08db5
...
@@ -18,7 +18,7 @@ import paddle
...
@@ -18,7 +18,7 @@ import paddle
import
paddle.distributed.auto_parallel.cost
as
cost_model
import
paddle.distributed.auto_parallel.cost
as
cost_model
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_
from_model
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_
by_modeling
paddle
.
enable_static
()
paddle
.
enable_static
()
...
@@ -45,13 +45,13 @@ class TestCost(unittest.TestCase):
...
@@ -45,13 +45,13 @@ class TestCost(unittest.TestCase):
if
op
.
type
==
"matmul_v2"
:
if
op
.
type
==
"matmul_v2"
:
matmul_v2_op
=
op
matmul_v2_op
=
op
break
break
matmul_v2_cost
=
cost_model
.
OP_COST_FACTORY
[
"matmul_v2"
](
matmul_v2_cost
=
cost_model
.
_g_op_cost_factory
[
"matmul_v2"
](
op
=
matmul_v2_op
)
op
=
matmul_v2_op
)
desc
=
parse_to_desc
(
op
=
matmul_v2_op
)
desc
=
parse_to_desc
(
op
=
matmul_v2_op
)
desc_str
=
parse_desc_to_str
(
desc
)
desc_str
=
parse_desc_to_str
(
desc
)
self
.
assertIsNotNone
(
desc_str
)
self
.
assertIsNotNone
(
desc_str
)
self
.
assertTrue
(
check_cost
(
matmul_v2_cost
.
cost
))
self
.
assertTrue
(
check_cost
(
matmul_v2_cost
.
cost
))
time
=
calc_time_
from_model
(
op
=
matmul_v2_op
)
time
=
calc_time_
by_modeling
(
op
=
matmul_v2_op
)
self
.
assertEqual
(
time
,
matmul_v2_cost
.
cost
.
time
)
self
.
assertEqual
(
time
,
matmul_v2_cost
.
cost
.
time
)
tensor_cost
=
cost_model
.
TensorCost
(
tensor
=
x
)
tensor_cost
=
cost_model
.
TensorCost
(
tensor
=
x
)
# check memory
# check memory
...
@@ -61,7 +61,8 @@ class TestCost(unittest.TestCase):
...
@@ -61,7 +61,8 @@ class TestCost(unittest.TestCase):
desc
=
{}
desc
=
{}
desc
[
"op"
]
=
"c_allreduce_sum"
desc
[
"op"
]
=
"c_allreduce_sum"
desc
[
"inputs"
]
=
{
"X"
:
[([
100
,
200
],
paddle
.
float32
)]}
desc
[
"inputs"
]
=
{
"X"
:
[([
100
,
200
],
paddle
.
float32
)]}
allreduce_cost
=
cost_model
.
OP_COST_FACTORY
[
"c_allreduce_sum"
](
desc
[
"group_ranks"
]
=
[
0
,
1
]
allreduce_cost
=
cost_model
.
_g_op_cost_factory
[
"c_allreduce_sum"
](
op_desc
=
desc
)
op_desc
=
desc
)
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录