PaddlePaddle / Paddle
Commit c1c9368f (unverified)
Authored Mar 24, 2022 by caozhou; committed by GitHub on Mar 24, 2022
[Auto Parallel] Update cost model (#40457)
* refactor cost model
Parent: 1b491818
Showing 9 changed files with 679 additions and 0 deletions (+679, -0)
python/paddle/distributed/auto_parallel/cost/__init__.py (+20, -0)
python/paddle/distributed/auto_parallel/cost/base_cost.py (+342, -0)
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py (+28, -0)
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py (+33, -0)
python/paddle/distributed/auto_parallel/cost/estimate_cost.py (+69, -0)
python/paddle/distributed/auto_parallel/cost/tensor_cost.py (+110, -0)
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt (+1, -0)
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py (+75, -0)
python/setup.py.in (+1, -0)
python/paddle/distributed/auto_parallel/cost/__init__.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import OP_COST_FACTORY
from .base_cost import Cost
from .comm_op_cost import AllreduceSumCost
from .comp_op_cost import MatmulV2OpCost
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator
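
# These re-exports define the package surface; the new unit test consumes it
# as, for example (a usage sketch, not part of this file):
#
#   import paddle.distributed.auto_parallel.cost as cost_model
#   cost = cost_model.Cost(memory=100, flops=200, time=0.5)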
python/paddle/distributed/auto_parallel/cost/base_cost.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from collections import OrderedDict

import paddle

COMM_OP_TYPE = [
    "send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum"
]
NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
OP_COST_FACTORY = {}


def _parse_op_to_desc(op, dist_context=None):
    desc = {}
    desc["op"] = op.type
    vars = op.block.vars

    input_desc = OrderedDict()
    for input_name in op.input_names:
        var_name_list = op.input(input_name)
        var_desc = []
        for var_name in var_name_list:
            var = vars[var_name]
            shape = None
            if dist_context is not None:
                dist_tensor = dist_context.get_dist_tensor_for_program(var)
                shape = dist_tensor.local_sizes()
            else:
                shape = var.shape
            assert shape is not None
            var_desc.append((var.dtype, shape))
        input_desc[input_name] = var_desc
    desc["inputs"] = input_desc

    output_desc = OrderedDict()
    for out_name in op.output_names:
        var_name_list = op.output(out_name)
        var_desc = []
        for var_name in var_name_list:
            var = vars[var_name]
            shape = None
            if dist_context is not None:
                dist_tensor = dist_context.get_dist_tensor_for_program(var)
                shape = dist_tensor.local_sizes()
            else:
                shape = var.shape
            assert shape is not None
            var_desc.append((var.dtype, shape))
        output_desc[out_name] = var_desc
    desc["outputs"] = output_desc

    attr_desc = op.all_attrs()
    desc["attrs"] = attr_desc

    return desc


def parse_to_desc(op=None, dist_op=None, dist_context=None):
    desc = None
    if op is None and dist_op is not None and dist_context is not None:
        desc = _parse_op_to_desc(
            op=dist_op.serial_op, dist_context=dist_context)
    elif op is not None and dist_op is None and dist_context is None:
        desc = _parse_op_to_desc(op)

    return desc


def parse_desc_to_str(desc):
    def _parse_dtype(dtype):
        dtype_str = ""
        if dtype == paddle.float32:
            dtype_str = "float32"
        elif dtype == paddle.float16:
            dtype_str = "float16"
        elif dtype == paddle.int32:
            dtype_str = "int32"
        elif dtype == paddle.int64:
            dtype_str = "int64"
        elif dtype == paddle.uint8:
            dtype_str = "uint8"
        else:
            raise TypeError("Unsupported dtype {}".format(dtype))
        return dtype_str

    assert isinstance(desc, dict)
    desc_str_list = []
    desc_str = None
    dtype_str_list = []
    dims_list = []
    shape_list = []

    desc_str_list.append(desc["op"])
    inputs = desc["inputs"]
    for key, item in inputs.items():
        for dtype, shape in item:
            dtype_str_list.append(_parse_dtype(dtype))
            shape_list += list(shape)
            dims = len(shape)
            dims_list.append(dims)

    dtype_str = "*".join(dtype_str_list)
    dims_list = [str(item) for item in dims_list]
    dims_str = "*".join(dims_list)

    shape_list = [str(item) for item in shape_list]
    shape_str = "[" + ",".join(shape_list) + "]"
    desc_str_list += [dtype_str, dims_str, shape_str]
    desc_str = "_".join(desc_str_list)

    return desc_str
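
# A worked sketch of the signature string produced above (hypothetical desc
# for a matmul_v2 op with two float32 [20, 20] inputs; not part of this file):
#
#   desc = {"op": "matmul_v2",
#           "inputs": {"X": [(paddle.float32, [20, 20])],
#                      "Y": [(paddle.float32, [20, 20])]}}
#   parse_desc_to_str(desc)  # -> "matmul_v2_float32*float32_2*2_[20,20,20,20]"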


class CommContext:
    _instance = None
    _has_instance = False

    def __init__(self, cluster):
        if CommContext._has_instance:
            return
        self.cluster = cluster
        self._alpha_base_ring = 8.4
        self._alpha_base_tree = 0
        self._alpha_inter = None
        self._alpha_intra = None
        self._beta = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._has_instance = True
        return cls._instance

    @property
    def alpha_inter(self):
        if self._alpha_inter is None:
            if self.cluster.alpha.inter == "NVL":
                self._alpha_inter = 3.4
            elif self.cluster.alpha.inter == "PHB":
                self._alpha_inter = 5.7
        return self._alpha_inter

    @property
    def alpha_intra(self):
        if self._alpha_intra is None:
            if self.cluster.alpha.intra == "NVL":
                self._alpha_intra = 28
            elif self.cluster.alpha.intra == "PHB":
                self._alpha_intra = 28
        return self._alpha_intra

    @property
    def alpha_base_ring(self):
        return self._alpha_base_ring

    @property
    def alpha_base_tree(self):
        return self._alpha_base_tree

    def get_beta(self, ranks):
        key = ','.join(map(str, sorted(ranks)))
        max_beta = None
        if key in self._beta:
            max_beta = self._beta[key]
        else:
            for i in range(len(ranks)):
                for j in range(i + 1, len(ranks)):
                    beta = self.cluster.get_beta(ranks[i], ranks[j])
                    if max_beta is None or beta > max_beta:
                        max_beta = beta
            self._beta[key] = max_beta

        return max_beta
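
# CommContext is a process-wide singleton: the alpha_* values appear to model
# per-message latency terms, and get_beta caches the largest pairwise
# bandwidth term among the given ranks. A minimal sketch, assuming a `cluster`
# object exists (not part of this file):
#
#   ctx_a = CommContext(cluster)
#   ctx_b = CommContext(cluster)  # __init__ returns early the second time
#   assert ctx_a is ctx_b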


class Cost:
    def __init__(self, time=0, memory=0, flops=0):
        self.time = time
        self.memory = memory
        self.flops = flops

    def _check_time(self, val):
        assert val >= 0, "Time must be greater than or equal to 0."

    def _check_memory(self, val):
        assert isinstance(val, int) and val >= 0, \
            "Memory must be an int greater than or equal to 0."

    def _check_flops(self, val):
        assert isinstance(val, int) and val >= 0, \
            "FLOPs must be an int greater than or equal to 0."

    @property
    def time(self):
        return self._time

    @time.setter
    def time(self, val):
        self._check_time(val)
        self._time = val

    @property
    def memory(self):
        return self._memory

    @memory.setter
    def memory(self, val):
        self._check_memory(val)
        self._memory = val

    @property
    def flops(self):
        return self._flops

    @flops.setter
    def flops(self, val):
        self._check_flops(val)
        self._flops = val

    def __add__(self, rhs):
        assert isinstance(rhs, Cost)
        time = self.time + rhs.time
        memory = self.memory + rhs.memory
        flops = self.flops + rhs.flops
        assert (time >= 0 and memory >= 0 and flops >= 0)
        return Cost(time, memory, flops)

    def __sub__(self, rhs):
        assert isinstance(rhs, Cost)
        time = self.time - rhs.time
        memory = self.memory - rhs.memory
        flops = self.flops - rhs.flops
        assert (time >= 0 and memory >= 0 and flops >= 0)
        return Cost(time, memory, flops)
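
# Cost objects compose arithmetically; both __add__ and __sub__ assert that
# no component goes negative. A minimal example (not part of this file):
#
#   total = Cost(time=0.5, memory=100, flops=200) \
#       + Cost(time=0.5, memory=50, flops=100)
#   # total.time == 1.0, total.memory == 150, total.flops == 300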


class OpCost:
    def __init__(self, op=None, op_desc=None):
        assert (op is not None and op_desc is None) or (
            op is None and op_desc is not None)
        self._op = op
        self._op_desc = op_desc
        self._cost = self.calc_cost()

    @property
    def op(self):
        return self._op

    @property
    def op_desc(self):
        return self._op_desc

    @property
    def cost(self):
        return self._cost

    def calc_time(self):
        return 0

    def calc_memory(self):
        return 0

    def calc_flops(self):
        return 0

    def calc_cost(self):
        time = self.calc_time()
        memory = self.calc_memory()
        flops = self.calc_flops()
        cost = Cost(time, memory, flops)
        return cost


class CommOpCost(OpCost):
    OP_TYPE = "COMM"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(CommOpCost, self).__init__(op=op, op_desc=op_desc)
        self._check_comm_op_type()
        self._comm_context = comm_context

    @property
    def comm_context(self):
        return self._comm_context

    @classmethod
    def _check_comm_op_type(cls):
        if cls.OP_TYPE != "COMM":
            if cls.OP_TYPE not in COMM_OP_TYPE:
                raise TypeError(
                    "Please check op type in {}, but got {}.".format(
                        COMM_OP_TYPE, cls.OP_TYPE))


class CompOpCost(OpCost):
    OP_TYPE = "COMP"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(CompOpCost, self).__init__(op=op, op_desc=op_desc)
        self._check_comp_op_type()
        self.cluster = cluster

    @classmethod
    def _check_comp_op_type(cls):
        if cls.OP_TYPE != "COMP":
            if cls.OP_TYPE in NON_COMP_TYPE:
                raise TypeError(
                    "Please check op type not in {}, but got {}.".format(
                        NON_COMP_TYPE, cls.OP_TYPE))


def register_op_cost(cls):
    op_type = cls.OP_TYPE

    def register(op_type):
        OP_COST_FACTORY[op_type] = cls
        return cls

    return register(op_type)
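
# Registration sketch for a future op cost (hypothetical ReluOpCost; the
# concrete subclasses live in comp_op_cost.py and comm_op_cost.py):
#
#   @register_op_cost
#   class ReluOpCost(CompOpCost):
#       OP_TYPE = "relu"
#
#   assert OP_COST_FACTORY["relu"] is ReluOpCost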


def calc_time_from_model(op=None, desc=None, cluster=None, comm_context=None):
    op_type = op.type if op is not None else desc["op"]
    if op_type in COMM_OP_TYPE:
        op_cost = OP_COST_FACTORY[op_type](
            op=op, op_desc=desc, comm_context=comm_context)
    elif op_type not in NON_COMP_TYPE:
        op_cost = OP_COST_FACTORY[op_type](
            op=op, op_desc=desc, cluster=cluster)
    time = op_cost.calc_time()
    return time
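
# Desc-based usage sketch for calc_time_from_model (hypothetical desc; mirrors
# test_new_cost_model.py):
#
#   desc = {"op": "c_allreduce_sum",
#           "inputs": {"X": [(paddle.float32, [100, 200])]}}
#   calc_time_from_model(desc=desc)  # -> 0 until the real formulas land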
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import register_op_cost, CommOpCost, OP_COST_FACTORY


@register_op_cost
class AllreduceSumCost(CommOpCost):
    OP_TYPE = "c_allreduce_sum"

    def __init__(self, op=None, op_desc=None, comm_context=None):
        super(OP_COST_FACTORY["c_allreduce_sum"], self).__init__(
            op=op, op_desc=op_desc, comm_context=comm_context)

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future.
        return 0
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import Cost, register_op_cost, CompOpCost, OP_COST_FACTORY


@register_op_cost
class MatmulV2OpCost(CompOpCost):
    OP_TYPE = "matmul_v2"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(OP_COST_FACTORY["matmul_v2"], self).__init__(
            op=op, op_desc=op_desc, cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to
    # be overridden.
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future.
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future.
        return 0
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
class CostEstimator:
    def __init__(self,
                 program,
                 cluster=None,
                 dist_context=None,
                 mode="modeling"):
        self._program = program
        self._cluster = cluster
        self._dist_context = dist_context
        self._check_mode(mode)
        self._mode = mode
        self._global_cost = None
        self._local_cost = {}

    @property
    def program(self):
        return self._program

    @property
    def dist_context(self):
        return self._dist_context

    @property
    def cluster(self):
        return self._cluster

    @property
    def mode(self):
        return self._mode

    @property
    def global_cost(self):
        return self._global_cost

    @property
    def local_cost(self):
        return self._local_cost

    def get_op_cost(self):
        return 0

    def get_tensor_cost(self):
        return 0

    def get_global_cost(self):
        return 0

    def get_local_cost(self, rank=None):
        return 0

    def _check_mode(self, mode):
        if mode not in ["modeling", "profiling"]:
            raise ValueError(
                "Only modeling and profiling are supported, but got {}".
                format(mode))
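
# Construction sketch: only the "modeling" (default) and "profiling" modes are
# accepted, and the get_*_cost methods are placeholders that return 0 for now
# (not part of this file):
#
#   estimator = CostEstimator(paddle.static.Program(), mode="profiling")
#   CostEstimator(paddle.static.Program(), mode="eager")  # raises ValueError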
python/paddle/distributed/auto_parallel/cost/tensor_cost.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from functools import reduce

import paddle
from paddle.fluid.framework import Variable
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor

from .base_cost import Cost


class TensorCost:
    def __init__(self, tensor=None, dist_tensor=None, shape=None, dtype=None):
        self._check_args(tensor, dist_tensor, shape, dtype)
        self._tensor = tensor
        self._dist_tensor = dist_tensor
        self._shape = shape
        self._dtype = dtype
        self._cost = self.calc_cost()

    @property
    def tensor(self):
        return self._tensor

    @property
    def dist_tensor(self):
        return self._dist_tensor

    @property
    def shape(self):
        return self._shape

    @property
    def dtype(self):
        return self._dtype

    def _check_args(self, tensor, dist_tensor, shape, dtype):
        if tensor is not None:
            assert (shape is None and dist_tensor is None and dtype is None)
            if not isinstance(tensor, Variable):
                raise TypeError(
                    "Please check tensor type is Variable, but got {}".format(
                        type(tensor)))
        elif dist_tensor is not None:
            assert (tensor is None and shape is None)
            if not isinstance(dist_tensor, DistributedTensor):
                raise TypeError(
                    "Please check dist_tensor type is DistributedTensor, but got {}".
                    format(type(dist_tensor)))
        elif shape is not None:
            assert (tensor is None and dist_tensor is None and
                    dtype is not None)
            if not isinstance(shape, (list, set)):
                raise TypeError(
                    "Please check shape type is list or set, but got {}".
                    format(type(shape)))
        elif dtype is not None:
            assert (tensor is None and dist_tensor is None and
                    shape is not None)

    @property
    def cost(self):
        return self._cost

    def calc_cost(self):
        dtype = None
        shape = None

        if self.dist_tensor:
            shape = self.dist_tensor.local_sizes()
            dtype = self.dist_tensor.serial_tensor.dtype
        elif self.tensor:
            shape = self.tensor.shape
            dtype = self.tensor.dtype
        elif self.shape and self.dtype:
            shape = self.shape
            dtype = self.dtype

        total_count = reduce(lambda x, y: x * y, shape)

        if dtype == paddle.float32 or dtype == paddle.int32:
            dtype_factor = 4
        elif dtype == paddle.int64:
            dtype_factor = 8
        elif dtype == paddle.uint8:
            dtype_factor = 1
        else:
            dtype_factor = 2

        memory = total_count * dtype_factor
        assert memory >= 0
        cost = Cost(memory=memory)

        return cost
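
# Worked memory example (matches the expectation in test_new_cost_model.py):
# a float32 tensor of shape [20, 20] costs 20 * 20 * 4 = 1600 bytes:
#
#   TensorCost(shape=[20, 20], dtype=paddle.float32).cost.memory  # -> 1600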
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -17,4 +17,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
     py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
     py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
     py_test_modules(test_trial MODULES test_trial ENVS ${dist_ENVS})
+    py_test_modules(test_new_cost_model MODULES test_new_cost_model ENVS ${dist_ENVS})
 endif()
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle
import paddle.distributed.auto_parallel.cost as cost_model
from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc
from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_from_model

paddle.enable_static()


def check_cost(cost):
    if cost.memory >= 0 and cost.flops >= 0 and cost.time >= 0:
        return True
    return False


class TestCost(unittest.TestCase):
    def test_base_cost(self):
        cost = cost_model.Cost(memory=100, flops=200, time=0.5)
        self.assertTrue(check_cost(cost))

    def test_comp_cost(self):
        x = paddle.static.data(name="x", shape=[20, 20], dtype='float32')
        y = paddle.static.data(name="y", shape=[20, 20], dtype='float32')

        z = paddle.matmul(x, y)

        matmul_v2_op = None
        ops = paddle.static.default_main_program().global_block().ops
        for op in ops:
            if op.type == "matmul_v2":
                matmul_v2_op = op
                break
        matmul_v2_cost = cost_model.OP_COST_FACTORY["matmul_v2"](
            op=matmul_v2_op)
        desc = parse_to_desc(op=matmul_v2_op)
        desc_str = parse_desc_to_str(desc)
        self.assertIsNotNone(desc_str)
        self.assertTrue(check_cost(matmul_v2_cost.cost))
        time = calc_time_from_model(op=matmul_v2_op)
        self.assertEqual(time, matmul_v2_cost.cost.time)
        tensor_cost = cost_model.TensorCost(tensor=x)
        # check memory
        self.assertEqual(tensor_cost.cost.memory, 1600)

    def test_comm_cost(self):
        desc = {}
        desc["op"] = "c_allreduce_sum"
        desc["inputs"] = {"X": [([100, 200], paddle.float32)]}
        allreduce_cost = cost_model.OP_COST_FACTORY["c_allreduce_sum"](
            op_desc=desc)
        self.assertTrue(check_cost(allreduce_cost.cost))

    def test_cost_estimator(self):
        train_program = paddle.static.Program()
        cost_estimator = cost_model.CostEstimator(train_program)
        self.assertIsNotNone(cost_estimator)


if __name__ == "__main__":
    unittest.main()
python/setup.py.in
@@ -307,6 +307,7 @@ packages=['paddle',
              'paddle.distributed.auto_parallel',
              'paddle.distributed.auto_parallel.operators',
              'paddle.distributed.auto_parallel.tuner',
+             'paddle.distributed.auto_parallel.cost',
              'paddle.distributed.passes',
              'paddle.framework',
              'paddle.jit',
...