Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c1c9368f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c1c9368f
编写于
3月 24, 2022
作者:
C
caozhou
提交者:
GitHub
3月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Update cost model (#40457)
* refactor cost model
上级
1b491818
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
679 addition
and
0 deletion
+679
-0
python/paddle/distributed/auto_parallel/cost/__init__.py
python/paddle/distributed/auto_parallel/cost/__init__.py
+20
-0
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+342
-0
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+28
-0
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+33
-0
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
...on/paddle/distributed/auto_parallel/cost/estimate_cost.py
+69
-0
python/paddle/distributed/auto_parallel/cost/tensor_cost.py
python/paddle/distributed/auto_parallel/cost/tensor_cost.py
+110
-0
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
...luid/tests/unittests/auto_parallel/test_new_cost_model.py
+75
-0
python/setup.py.in
python/setup.py.in
+1
-0
未找到文件。
python/paddle/distributed/auto_parallel/cost/__init__.py
0 → 100644
浏览文件 @
c1c9368f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from
.base_cost
import
OP_COST_FACTORY
from
.base_cost
import
Cost
from
.comm_op_cost
import
AllreduceSumCost
from
.comp_op_cost
import
MatmulV2OpCost
from
.tensor_cost
import
TensorCost
from
.estimate_cost
import
CostEstimator
python/paddle/distributed/auto_parallel/cost/base_cost.py
0 → 100644
浏览文件 @
c1c9368f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from
collections
import
OrderedDict
import
paddle
COMM_OP_TYPE
=
[
"send_v2"
,
"recv_v2"
,
"c_broadcast"
,
"c_allgather"
,
"c_allreduce_sum"
]
NON_COMP_TYPE
=
[
"while"
]
+
COMM_OP_TYPE
OP_COST_FACTORY
=
{}
def
_parse_op_to_desc
(
op
,
dist_context
=
None
):
desc
=
{}
desc
[
"op"
]
=
op
.
type
vars
=
op
.
block
.
vars
input_desc
=
OrderedDict
()
for
input_name
in
op
.
input_names
:
var_name_list
=
op
.
input
(
input_name
)
var_desc
=
[]
for
var_name
in
var_name_list
:
var
=
vars
[
var_name
]
shape
=
None
if
dist_context
is
not
None
:
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
shape
=
dist_tensor
.
local_sizes
()
else
:
shape
=
var
.
shape
assert
shape
is
not
None
var_desc
.
append
((
var
.
dtype
,
shape
))
input_desc
[
input_name
]
=
var_desc
desc
[
"inputs"
]
=
input_desc
output_desc
=
OrderedDict
()
for
out_name
in
op
.
output_names
:
var_name_list
=
op
.
output
(
out_name
)
var_desc
=
[]
for
var_name
in
var_name_list
:
var
=
vars
[
var_name
]
shape
=
None
if
dist_context
is
not
None
:
dist_tensor
=
dist_context
.
get_dist_tensor_for_program
(
var
)
shape
=
dist_tensor
.
local_sizes
()
else
:
shape
=
var
.
shape
assert
shape
is
not
None
var_desc
.
append
((
var
.
dtype
,
shape
))
output_desc
[
out_name
]
=
var_desc
desc
[
"outputs"
]
=
output_desc
attr_desc
=
op
.
all_attrs
desc
[
"attrs"
]
=
attr_desc
return
desc
def
parse_to_desc
(
op
=
None
,
dist_op
=
None
,
dist_context
=
None
):
desc
=
None
if
op
is
None
and
dist_op
is
not
None
and
dist_context
is
not
None
:
desc
=
_parse_op_to_desc
(
op
=
dist_op
.
serial_op
,
dist_context
=
dist_context
)
elif
op
is
not
None
and
dist_op
is
None
and
dist_context
is
None
:
desc
=
_parse_op_to_desc
(
op
)
return
desc
def
parse_desc_to_str
(
desc
):
def
_parse_dtype
(
dtype
):
dtype_str
=
""
if
dtype
==
paddle
.
float32
:
dtype_str
=
"float32"
elif
dtype
==
paddle
.
float16
:
dtype_str
=
"float16"
elif
dtype
==
paddle
.
int32
:
dtype_str
=
"int32"
elif
dtype
==
paddle
.
int64
:
dtype_str
=
"int64"
elif
dtype
==
paddle
.
unit8
:
dtype_str
=
"unit8"
else
:
raise
TypeError
(
"Unsupported dtype {}"
.
format
(
dtype
))
return
dtype_str
assert
isinstance
(
desc
,
dict
)
desc_str_list
=
[]
desc_str
=
None
dtype_str_list
=
[]
dims_list
=
[]
shape_list
=
[]
desc_str_list
.
append
(
desc
[
"op"
])
inputs
=
desc
[
"inputs"
]
for
key
,
item
in
inputs
.
items
():
for
dtype
,
shape
in
item
:
dtype_str_list
.
append
(
_parse_dtype
(
dtype
))
shape_list
+=
list
(
shape
)
dims
=
len
(
shape
)
dims_list
.
append
(
dims
)
dtype_str
=
"*"
.
join
(
dtype_str_list
)
dims_list
=
[
str
(
item
)
for
item
in
dims_list
]
dims_str
=
"*"
.
join
(
dims_list
)
shape_list
=
[
str
(
item
)
for
item
in
shape_list
]
shape_str
=
"["
+
","
.
join
(
shape_list
)
+
"]"
desc_str_list
+=
[
dtype_str
,
dims_str
,
shape_str
]
desc_str
=
"_"
.
join
(
desc_str_list
)
return
desc_str
class
CommContext
:
_instance
=
None
_has_instance
=
False
def
__init__
(
self
,
cluster
):
if
CommContext
.
_has_instance
:
return
self
.
cluster
=
cluster
self
.
_alpha_base_ring
=
8.4
self
.
_alpha_base_tree
=
0
self
.
_alpha_inter
=
None
self
.
_alpha_intra
self
.
_beta
=
{}
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
().
__new__
(
cls
,
*
args
,
**
kwargs
)
_has_instance
=
True
return
cls
.
_instance
@
property
def
alpha_inter
(
self
):
if
self
.
_alpha_inter
is
None
:
if
cluster
.
alpha
.
inter
==
"NVL"
:
self
.
_alpha_inter
=
3.4
elif
cluster
.
alpha
.
inter
==
"PHB"
:
self
.
_alpha_inter
=
5.7
return
self
.
_alpha_inter
@
property
def
alpha_intra
(
self
):
if
self
.
_alpha_intra
is
None
:
if
cluster
.
alpha
.
intra
==
"NVL"
:
self
.
_alpha_intra
=
28
elif
cluster
.
alpha
.
intra
==
"PHB"
:
self
.
_alpha_intra
=
28
return
self
.
_alpha_intra
@
property
def
alpha_base_ring
(
self
):
return
self
.
_alpha_base_ring
@
property
def
alpha_base_tree
(
self
):
return
self
.
_alpha_base_tree
def
get_beta
(
self
,
ranks
):
key
=
','
.
join
(
map
(
str
,
sorted
(
ranks
)))
max_beta
=
None
if
key
in
self
.
_beta
.
keys
:
max_beta
=
self
.
_beta
[
key
]
else
:
for
i
in
range
(
len
(
ranks
)):
for
j
in
range
(
i
+
1
,
len
(
ranks
)):
if
min_beta
==
None
:
min_beta
=
cluster
.
get_beta
(
ranks
[
i
],
ranks
[
j
])
else
:
beta
=
cluster
.
get_beta
(
ranks
[
i
],
ranks
[
j
])
if
beta
>
max_beta
:
max_beta
=
beta
self
.
_beta
[
key
]
=
max_beta
return
max_beta
class
Cost
:
def
__init__
(
self
,
time
=
0
,
memory
=
0
,
flops
=
0
):
self
.
time
=
time
self
.
memory
=
memory
self
.
flops
=
flops
def
_check_time
(
self
,
val
):
assert
val
>=
0
,
"Time must be greater than or equal to 0."
def
_check_memory
(
self
,
val
):
assert
isinstance
(
val
,
int
)
and
val
>=
0
,
"Memory must be int and greater than 0."
def
_check_flops
(
self
,
val
):
assert
isinstance
(
val
,
int
)
and
val
>=
0
,
"FLOPs must be int and greater than 0."
@
property
def
time
(
self
):
return
self
.
_time
@
time
.
setter
def
time
(
self
,
val
):
self
.
_check_time
(
val
)
self
.
_time
=
val
@
property
def
memory
(
self
):
return
self
.
_memory
@
memory
.
setter
def
memory
(
self
,
val
):
self
.
_check_memory
(
val
)
self
.
_memory
=
val
@
property
def
flops
(
self
):
return
self
.
_flops
@
flops
.
setter
def
flops
(
self
,
val
):
self
.
_check_flops
(
val
)
self
.
_flops
=
val
def
__add__
(
self
,
rhs
):
assert
isinstance
(
rhs
,
Cost
)
time
=
self
.
time
+
rhs
.
time
memory
=
self
.
memory
+
rhs
.
memory
flops
=
self
.
flops
+
rhs
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
return
Cost
(
time
,
memory
,
flops
)
def
__sub__
(
self
,
rhs
):
assert
isinstance
(
rhs
,
Cost
)
time
=
self
.
time
-
rhs
.
time
memory
=
self
.
memory
-
rhs
.
memory
flops
=
self
.
flops
-
rhs
.
flops
assert
(
time
>=
0
and
memory
>=
0
and
flops
>=
0
)
return
Cost
(
time
,
memory
,
flops
)
class
OpCost
:
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
):
assert
(
op
is
not
None
and
op_desc
is
None
)
or
(
op
is
None
and
op_desc
is
not
None
)
self
.
_op
=
op
self
.
_op_desc
=
op_desc
self
.
_cost
=
self
.
calc_cost
()
@
property
def
op
(
self
):
return
self
.
_op
@
property
def
op_desc
(
self
):
return
self
.
_op_desc
@
property
def
cost
(
self
):
return
self
.
_cost
def
calc_time
(
self
):
return
0
def
calc_memory
(
self
):
return
0
def
calc_flops
(
self
):
return
0
def
calc_cost
(
self
):
time
=
self
.
calc_time
()
memory
=
self
.
calc_memory
()
flops
=
self
.
calc_flops
()
cost
=
Cost
(
time
,
memory
,
flops
)
return
cost
class
CommOpCost
(
OpCost
):
OP_TYPE
=
"COMM"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
CommOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
self
.
_check_comm_op_type
()
self
.
_comm_context
=
comm_context
@
property
def
comm_context
(
self
):
return
self
.
_comm_context
@
classmethod
def
_check_comm_op_type
(
cls
):
if
cls
.
OP_TYPE
!=
"COMM"
:
if
cls
.
OP_TYPE
not
in
COMM_OP_TYPE
:
raise
TypeError
(
"Please Check op type in {}, but got {}."
.
format
(
COMM_OP_TYPE
,
cls
.
OP_TYPE
))
class
CompOpCost
(
OpCost
):
OP_TYPE
=
"COMP"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
(
CompOpCost
,
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
)
self
.
_check_comp_op_type
()
self
.
cluster
=
cluster
@
classmethod
def
_check_comp_op_type
(
cls
):
if
cls
.
OP_TYPE
!=
"COMP"
:
if
cls
.
OP_TYPE
in
NON_COMP_TYPE
:
raise
TypeError
(
"Please Check op type not in {}, but got {}."
.
format
(
NON_COMP_TYPE
,
cls
.
OP_TYPE
))
def
register_op_cost
(
cls
):
op_type
=
cls
.
OP_TYPE
def
register
(
op_type
):
OP_COST_FACTORY
[
op_type
]
=
cls
return
register
(
op_type
)
def
calc_time_from_model
(
op
=
None
,
desc
=
None
,
cluster
=
None
,
comm_context
=
None
):
op_type
=
op
.
type
if
op
is
not
None
else
desc
[
"op"
]
if
op_type
in
COMM_OP_TYPE
:
op_cost
=
OP_COST_FACTORY
[
op_type
](
op
=
op
,
op_desc
=
desc
,
comm_context
=
comm_context
)
elif
op_type
not
in
NON_COMP_TYPE
:
op_cost
=
OP_COST_FACTORY
[
op_type
](
op
=
op
,
op_desc
=
desc
,
cluster
=
cluster
)
time
=
op_cost
.
calc_time
()
return
time
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
0 → 100644
浏览文件 @
c1c9368f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from
.base_cost
import
register_op_cost
,
CommOpCost
,
OP_COST_FACTORY
@
register_op_cost
class
AllreduceSumCost
(
CommOpCost
):
OP_TYPE
=
"c_allreduce_sum"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
comm_context
=
None
):
super
(
OP_COST_FACTORY
[
"c_allreduce_sum"
],
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future.
return
0
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
0 → 100644
浏览文件 @
c1c9368f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from
.base_cost
import
Cost
,
register_op_cost
,
CompOpCost
,
OP_COST_FACTORY
@
register_op_cost
class
MatmulV2OpCost
(
CompOpCost
):
OP_TYPE
=
"matmul_v2"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
(
OP_COST_FACTORY
[
"matmul_v2"
],
self
).
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function needs to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
0 → 100644
浏览文件 @
c1c9368f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
class
CostEstimator
:
def
__init__
(
self
,
program
,
cluster
=
None
,
dist_context
=
None
,
mode
=
"modeling"
):
self
.
_program
=
program
self
.
_cluster
=
cluster
self
.
_dist_context
=
dist_context
self
.
_check_mode
(
mode
)
self
.
_mode
=
mode
self
.
_global_cost
=
None
self
.
_local_cost
=
{}
@
property
def
program
(
self
):
return
self
.
_program
@
property
def
dist_context
(
self
):
return
self
.
_dist_context
@
property
def
cluster
(
self
):
return
self
.
_cluster
@
property
def
mode
(
self
):
return
self
.
_mode
@
property
def
global_cost
(
self
):
return
self
.
_global_cost
@
property
def
local_cost
(
self
):
return
self
.
_local_cost
def
get_op_cost
(
self
):
return
0
def
get_tensor_cost
(
self
):
return
0
def
get_global_cost
(
self
):
return
0
def
get_local_cost
(
self
,
rank
=
None
):
return
0
def
_check_mode
(
self
,
mode
):
if
mode
not
in
[
"modeling"
,
"profiling"
]:
raise
ValueError
(
"Just support modeling and profiling, but got {}"
.
format
(
mode
))
python/paddle/distributed/auto_parallel/cost/tensor_cost.py
0 → 100644
浏览文件 @
c1c9368f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from
functools
import
reduce
import
paddle
from
paddle.fluid.framework
import
Variable
from
paddle.distributed.auto_parallel.dist_tensor
import
DistributedTensor
from
.base_cost
import
Cost
class
TensorCost
:
def
__init__
(
self
,
tensor
=
None
,
dist_tensor
=
None
,
shape
=
None
,
dtype
=
None
):
self
.
_check_args
(
tensor
,
dist_tensor
,
shape
,
dtype
)
self
.
_tensor
=
tensor
self
.
_dist_tensor
=
dist_tensor
self
.
_shape
=
shape
self
.
_dtype
=
dtype
self
.
_cost
=
self
.
calc_cost
()
@
property
def
tensor
(
self
):
return
self
.
_tensor
@
property
def
dist_tensor
(
self
):
return
self
.
_dist_tensor
@
property
def
shape
(
self
):
return
self
.
_shape
@
property
def
dtype
(
self
):
return
self
.
_dtype
def
_check_args
(
self
,
tensor
,
dist_tensor
,
shape
,
dtype
):
if
tensor
is
not
None
:
assert
(
shape
is
None
and
dist_tensor
is
None
and
dtype
is
None
)
if
not
isinstance
(
tensor
,
Variable
):
raise
TypeError
(
"Please check tensor type is Variable, but got {}"
.
format
(
type
(
tensor
)))
elif
dist_tensor
is
not
None
:
assert
(
tensor
is
None
and
shape
is
None
)
if
not
isinstance
(
dist_tensor
,
DistributedTensor
):
raise
TypeError
(
"Please check dist_tensor type is DistributedTensor, but got {}"
.
format
(
type
(
dist_tensor
)))
elif
shape
is
not
None
:
assert
(
tensor
is
None
and
dist_tensor
is
None
and
dtype
is
not
None
)
if
not
isinstance
(
shape
,
(
list
,
set
)):
raise
TypeError
(
"Please check shape type is list or set, but got {}"
.
format
(
type
(
shape
)))
elif
dtype
is
not
None
:
assert
(
tensor
is
None
and
dist_tensor
is
None
and
shape
is
not
None
)
@
property
def
cost
(
self
):
return
self
.
_cost
def
calc_cost
(
self
):
dtype
=
None
shape
=
None
if
self
.
dist_tensor
:
shape
=
self
.
dist_tensor
.
local_sizes
()
dtype
=
self
.
dist_tensor
.
serial_tensor
.
dtype
elif
self
.
tensor
:
shape
=
self
.
tensor
.
shape
dtype
=
self
.
tensor
.
dtype
elif
self
.
shape
and
self
.
dtype
:
shape
=
self
.
shape
dtype
=
self
.
dtype
total_count
=
reduce
(
lambda
x
,
y
:
x
*
y
,
shape
)
if
dtype
==
paddle
.
float32
or
dtype
==
paddle
.
int32
:
dtype_factor
=
4
elif
node
.
dtype
==
paddle
.
int64
:
dtype_factor
=
8
elif
node
.
dtype
==
paddle
.
uint8
:
dtype_factor
=
1
else
:
dtype_factor
=
2
memory
=
total_count
*
dtype_factor
assert
memory
>=
0
cost
=
Cost
(
memory
=
memory
)
return
cost
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
c1c9368f
...
@@ -17,4 +17,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
...
@@ -17,4 +17,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_tunable_space MODULES test_tunable_space ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_tunable_space MODULES test_tunable_space ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_recorder MODULES test_recorder ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_recorder MODULES test_recorder ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_trial MODULES test_trial ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_trial MODULES test_trial ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_new_cost_model MODULES test_new_cost_model ENVS
${
dist_ENVS
}
)
endif
()
endif
()
python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py
0 → 100644
浏览文件 @
c1c9368f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
paddle.distributed.auto_parallel.cost
as
cost_model
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_to_desc
from
paddle.distributed.auto_parallel.cost.base_cost
import
parse_desc_to_str
from
paddle.distributed.auto_parallel.cost.base_cost
import
calc_time_from_model
paddle
.
enable_static
()
def
check_cost
(
cost
):
if
cost
.
memory
>=
0
and
cost
.
flops
>=
0
and
cost
.
time
>=
0
:
return
True
return
False
class
TestCost
(
unittest
.
TestCase
):
def
test_base_cost
(
self
):
cost
=
cost_model
.
Cost
(
memory
=
100
,
flops
=
200
,
time
=
0.5
)
self
.
assertTrue
(
check_cost
(
cost
))
def
test_comp_cost
(
self
):
x
=
paddle
.
static
.
data
(
name
=
"x"
,
shape
=
[
20
,
20
],
dtype
=
'float32'
)
y
=
paddle
.
static
.
data
(
name
=
"y"
,
shape
=
[
20
,
20
],
dtype
=
'float32'
)
z
=
paddle
.
matmul
(
x
,
y
)
matmul_v2_op
=
None
ops
=
paddle
.
static
.
default_main_program
().
global_block
().
ops
for
op
in
ops
:
if
op
.
type
==
"matmul_v2"
:
matmul_v2_op
=
op
break
matmul_v2_cost
=
cost_model
.
OP_COST_FACTORY
[
"matmul_v2"
](
op
=
matmul_v2_op
)
desc
=
parse_to_desc
(
op
=
matmul_v2_op
)
desc_str
=
parse_desc_to_str
(
desc
)
self
.
assertIsNotNone
(
desc_str
)
self
.
assertTrue
(
check_cost
(
matmul_v2_cost
.
cost
))
time
=
calc_time_from_model
(
op
=
matmul_v2_op
)
self
.
assertEqual
(
time
,
matmul_v2_cost
.
cost
.
time
)
tensor_cost
=
cost_model
.
TensorCost
(
tensor
=
x
)
# check memory
self
.
assertEqual
(
tensor_cost
.
cost
.
memory
,
1600
)
def
test_comm_cost
(
self
):
desc
=
{}
desc
[
"op"
]
=
"c_allreduce_sum"
desc
[
"inputs"
]
=
{
"X"
:
[([
100
,
200
],
paddle
.
float32
)]}
allreduce_cost
=
cost_model
.
OP_COST_FACTORY
[
"c_allreduce_sum"
](
op_desc
=
desc
)
self
.
assertTrue
(
check_cost
(
allreduce_cost
.
cost
))
def
test_cost_estimator
(
self
):
train_program
=
paddle
.
static
.
Program
()
cost_estimator
=
cost_model
.
CostEstimator
(
train_program
)
self
.
assertIsNotNone
(
cost_estimator
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/setup.py.in
浏览文件 @
c1c9368f
...
@@ -307,6 +307,7 @@ packages=['paddle',
...
@@ -307,6 +307,7 @@ packages=['paddle',
'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel.operators',
'paddle.distributed.auto_parallel.operators',
'paddle.distributed.auto_parallel.tuner',
'paddle.distributed.auto_parallel.tuner',
'paddle.distributed.auto_parallel.cost',
'paddle.distributed.passes',
'paddle.distributed.passes',
'paddle.framework',
'paddle.framework',
'paddle.jit',
'paddle.jit',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录