Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d6011cb6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d6011cb6
编写于
3月 28, 2023
作者:
C
caozhou
提交者:
GitHub
3月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Add o1 level tune (#52041)
* add tune o1 level * add unittest
上级
418b983c
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
820 addition
and
721 deletion
+820
-721
python/paddle/distributed/auto_parallel/cluster.py
python/paddle/distributed/auto_parallel/cluster.py
+76
-14
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+38
-25
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+1
-1
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+0
-640
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
...on/paddle/distributed/auto_parallel/cost/estimate_cost.py
+82
-32
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+2
-0
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
...addle/distributed/auto_parallel/tuner/rule_based_tuner.py
+614
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
...id/tests/unittests/auto_parallel/test_rule_based_tuner.py
+7
-9
未找到文件。
python/paddle/distributed/auto_parallel/cluster.py
浏览文件 @
d6011cb6
...
...
@@ -14,6 +14,7 @@
import
json
import
os
import
re
from
enum
import
IntEnum
,
unique
import
paddle
...
...
@@ -449,7 +450,6 @@ class Cluster:
npu_models
=
[
"NPU"
]
dcu_models
=
[
"DCU"
]
all_gpu_models
=
gpu_models
+
xpu_models
+
npu_models
+
dcu_models
assert
gpu_model
in
all_gpu_models
self
.
_num_devices_per_machine
=
device_count
def
_convert_to_type
(
gpu_model
):
...
...
@@ -462,6 +462,8 @@ class Cluster:
type
=
"NPU"
elif
gpu_model
in
dcu_models
:
type
=
"DCU"
else
:
type
=
"GPU"
assert
type
is
not
None
return
type
...
...
@@ -470,6 +472,12 @@ class Cluster:
model
=
None
if
gpu_model
==
"V100"
:
model
=
"Tesla V100-SXM2-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A100"
:
model
=
"Tesla A100-SXM-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A30"
:
model
=
"Tesla A30-SXM-"
+
str
(
gpu_memory
)
+
"GB"
else
:
model
=
gpu_model
+
str
(
gpu_memory
)
+
"GB"
assert
model
is
not
None
return
model
...
...
@@ -527,6 +535,8 @@ class Cluster:
device
[
"memory"
]
=
memory
device
[
"sp_gflops"
]
=
sp_gflops
device
[
"dp_gflops"
]
=
dp_gflops
# hard code
device
[
"type"
]
=
"GPU"
global_id_to_device_type
[
global_id
]
=
type
global_id_to_node
[
global_id
]
=
i
devices
.
append
(
device
)
...
...
@@ -820,13 +830,45 @@ class Cluster:
return
self
.
__str__
()
def
get_default_cluster
():
def
get_default_cluster
(
json_config
=
None
):
def
is_by_json_config
(
json_config
):
if
not
json_config
:
return
False
if
"cluster"
not
in
json_config
:
return
False
else
:
if
"path"
not
in
json_config
[
"cluster"
]:
if
"num_nodes"
not
in
json_config
[
"cluster"
]:
return
False
if
"num_gpus"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_model"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_memory"
not
in
json_config
[
"cluster"
]:
return
False
return
True
else
:
return
True
cluster
=
Cluster
()
if
json_config
and
is_by_json_config
(
json_config
):
# Get GPU info by json config
if
"path"
in
json_config
[
"cluster"
]:
cluster
.
build_from_file
(
json_config
[
"cluster"
][
"path"
])
return
cluster
else
:
node_count
=
json_config
[
"cluster"
][
"num_nodes"
]
local_device_count
=
json_config
[
"cluster"
][
"num_gpus"
]
gpu_model
=
json_config
[
"cluster"
][
"gpu_model"
]
memory
=
json_config
[
"cluster"
][
"gpu_memory"
]
else
:
# Get GPU info by get_device_properties
local_device_count
=
os
.
getenv
(
"PADDLE_LOCAL_SIZE"
)
if
local_device_count
is
None
:
local_device_count
=
1
else
:
local_device_count
=
int
(
local_device_count
)
global_device_count
=
os
.
getenv
(
"PADDLE_GLOBAL_SIZE"
)
if
global_device_count
is
None
:
node_count
=
1
...
...
@@ -834,16 +876,36 @@ def get_default_cluster():
global_device_count
=
int
(
global_device_count
)
assert
global_device_count
%
local_device_count
==
0
node_count
=
int
(
global_device_count
)
//
local_device_count
gpu_info
=
paddle
.
device
.
cuda
.
get_device_properties
()
assert
gpu_info
,
"Auto parallel just runs on gpu now."
gpu_name
=
gpu_info
.
name
try
:
re_result
=
re
.
split
(
r
'[ , -]'
,
gpu_name
)
gpu_model
=
re_result
[
1
]
memory
=
int
(
re_result
[
-
1
][:
-
2
])
except
:
memory
=
int
(
gpu_info
.
total_memory
)
//
(
1000
**
3
)
gpu_model
=
gpu_name
print
(
"Node Count: "
,
node_count
,
"Local Device Size: "
,
local_device_count
,
"GPU Model: "
,
gpu_model
,
"GPU Memory: "
,
memory
,
"World size: "
,
paddle
.
distributed
.
get_world_size
(),
flush
=
True
,
)
cluster
.
gen_default_config_cluster
(
node_count
=
node_count
,
device_count
=
local_device_count
node_count
=
node_count
,
device_count
=
local_device_count
,
gpu_model
=
gpu_model
,
gpu_memory
=
memory
,
)
return
cluster
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
d6011cb6
...
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
from
functools
import
reduce
import
paddle
from
paddle.utils.flops
import
flops
from
..cluster
import
LinkType
from
..dist_tensor
import
DistributedTensor
...
...
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
output_desc
=
OrderedDict
()
# Get partitioned shape of input
input_var_desc
=
{}
for
input_name
in
op
.
input_names
:
var_name_list
=
op
.
input
(
input_name
)
var_desc
=
[]
input_var_desc
[
input_name
]
=
[]
for
var_name
in
var_name_list
:
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
...
...
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process
,
shard_sizes
,
)
var_desc
.
append
((
var
.
dtype
,
shape
)
)
input_var_desc
[
input_name
].
append
(
shape
)
# For special op such as embedding and its grad op
if
(
...
...
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
relative_idx
=
relative_idx
*
per_part_size
desc
[
"attrs"
][
"start_index"
]
=
relative_idx
input_desc
[
input_name
]
=
var_desc
desc
[
"inputs"
]
=
input_desc
desc
[
"inputs"
]
=
input_var_desc
for
out_name
in
op
.
output_names
:
var_name_list
=
op
.
output
(
out_name
)
...
...
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
return
desc
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
):
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
,
is_dp
=
False
):
"""Build comm costs by descriptions"""
comm_context
=
CommContext
(
cluster
)
group_ranks_list
=
[]
...
...
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
comm_op_cost
=
op_cost_class
(
op_desc
=
desc
,
comm_context
=
comm_context
)
if
is_dp
:
comm_op_cost
.
cost
.
time
*=
0.9
comm_op_cost_list
.
append
(
comm_op_cost
)
return
comm_op_cost_list
...
...
@@ -389,6 +394,7 @@ def build_dp_costs(
vars
=
dist_op
.
serial_op
.
block
.
vars
var_name
=
var_names
[
0
]
has_found
=
False
is_input
=
True
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
if
var_name
in
name
:
var_name
=
name
...
...
@@ -400,6 +406,7 @@ def build_dp_costs(
if
var_name
in
name
:
var_name
=
name
has_found
=
True
is_input
=
False
break
if
not
has_found
:
return
...
...
@@ -418,6 +425,7 @@ def build_dp_costs(
processes
,
c_allreduce_sum_descs
,
cluster
,
is_dp
=
True
,
)
result
.
append
(
comm_cost_list
)
...
...
@@ -431,22 +439,11 @@ def build_dp_costs(
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"inputs"
]
=
{}
if
var_name
in
dist_attr
.
inputs_dist_attrs
:
dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_name
)
elif
var_name
in
dist_attr
.
outputs_dist_attrs
:
dims_mapping
=
dist_attr
.
get_output_dims_mapping
(
var_name
)
else
:
raise
AssertionError
(
"cannot find dims_mapping for {} in {}"
.
format
(
var_name
,
dist_attr
)
dims_mapping
=
(
dist_attr
.
get_input_dims_mapping
(
var_name
)
if
is_input
else
dist_attr
.
get_output_dims_mapping
(
var_name
)
)
# dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name)
# if dist_attr.get_input_dims_mapping(var_name) is not None
# else dist_attr.get_output_dims_mapping(var_name)
# )
var
=
get_var_with_recursion
(
var_name
,
dist_op
.
serial_op
.
block
,
...
...
@@ -493,8 +490,6 @@ class CommContext:
# if cluster has no info about those vars, it will be set by default
self
.
base_ring
=
None
self
.
base_tree
=
None
# self.base_inter_ring = None
# self.base_inter_tree = None
self
.
intra_ring
=
None
self
.
intra_tree
=
None
self
.
inter_ring
=
None
...
...
@@ -508,8 +503,6 @@ class CommContext:
# set default
self
.
base_ring
=
8.4
self
.
base_tree
=
0.0
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVL in default
self
.
intra_ring
=
3.4
self
.
intra_tree
=
28
...
...
@@ -681,6 +674,8 @@ class Cost:
class
OpCost
:
OP_TYPE
=
"op"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
):
self
.
_op
=
op
self
.
_op_desc
=
op_desc
...
...
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
)
)
def
calc_flops
(
self
):
if
not
self
.
op_desc
:
return
0
if
"_grad"
in
self
.
__class__
.
OP_TYPE
:
op_type
=
self
.
__class__
.
OP_TYPE
[:
len
(
self
.
__class__
.
OP_TYPE
)
-
5
]
return
2
*
flops
(
op_type
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
]
)
return
flops
(
self
.
__class__
.
OP_TYPE
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
],
)
def
calc_time
(
self
):
flops_count
=
self
.
calc_flops
()
return
flops_count
*
2.9e-7
def
register_op_cost
(
cls
):
op_type
=
cls
.
OP_TYPE
...
...
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
浏览文件 @
d6011cb6
...
...
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
return
0
return
self
.
comm_count
*
1
/
(
144
*
1e3
)
@
register_op_cost
...
...
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
浏览文件 @
d6011cb6
...
...
@@ -22,15 +22,6 @@ class AdamOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ArgsortOpCost
(
CompOpCost
):
...
...
@@ -39,15 +30,6 @@ class ArgsortOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
AssignOpCost
(
CompOpCost
):
...
...
@@ -56,15 +38,6 @@ class AssignOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
AssignValueOpCost
(
CompOpCost
):
...
...
@@ -73,15 +46,6 @@ class AssignValueOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
BeamSearchOpCost
(
CompOpCost
):
...
...
@@ -90,15 +54,6 @@ class BeamSearchOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
BeamSearchDecodeOpCost
(
CompOpCost
):
...
...
@@ -107,15 +62,6 @@ class BeamSearchDecodeOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
CastOpCost
(
CompOpCost
):
...
...
@@ -124,15 +70,6 @@ class CastOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ConcatOpCost
(
CompOpCost
):
...
...
@@ -141,15 +78,6 @@ class ConcatOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
DropoutOpCost
(
CompOpCost
):
...
...
@@ -158,15 +86,6 @@ class DropoutOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
DropoutGradOpCost
(
CompOpCost
):
...
...
@@ -175,15 +94,6 @@ class DropoutGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseAddOpCost
(
CompOpCost
):
...
...
@@ -192,15 +102,6 @@ class ElementwiseAddOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseAddGradOpCost
(
CompOpCost
):
...
...
@@ -209,15 +110,6 @@ class ElementwiseAddGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseDivOpCost
(
CompOpCost
):
...
...
@@ -226,15 +118,6 @@ class ElementwiseDivOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseDivGradOpCost
(
CompOpCost
):
...
...
@@ -243,15 +126,6 @@ class ElementwiseDivGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseMulOpCost
(
CompOpCost
):
...
...
@@ -260,15 +134,6 @@ class ElementwiseMulOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseMulGradOpCost
(
CompOpCost
):
...
...
@@ -277,15 +142,6 @@ class ElementwiseMulGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseSubOpCost
(
CompOpCost
):
...
...
@@ -294,15 +150,6 @@ class ElementwiseSubOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ElementwiseSubGradOpCost
(
CompOpCost
):
...
...
@@ -311,15 +158,6 @@ class ElementwiseSubGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
EqualOpCost
(
CompOpCost
):
...
...
@@ -328,15 +166,6 @@ class EqualOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
EmbeddingOpCost
(
CompOpCost
):
...
...
@@ -345,15 +174,6 @@ class EmbeddingOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
EmbeddingGradOpCost
(
CompOpCost
):
...
...
@@ -362,15 +182,6 @@ class EmbeddingGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
FillConstantOpCost
(
CompOpCost
):
...
...
@@ -379,15 +190,6 @@ class FillConstantOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
FillConstantBatchSizeLikeOpCost
(
CompOpCost
):
...
...
@@ -396,15 +198,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
FusedSoftmaxMaskUpperTriangleOpCost
(
CompOpCost
):
...
...
@@ -413,15 +206,6 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
FusedSoftmaxMaskUpperTriangleGradOpCost
(
CompOpCost
):
...
...
@@ -430,15 +214,6 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
GatherOpCost
(
CompOpCost
):
...
...
@@ -447,15 +222,6 @@ class GatherOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
GeluOpCost
(
CompOpCost
):
...
...
@@ -464,15 +230,6 @@ class GeluOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
GeluGradOpCost
(
CompOpCost
):
...
...
@@ -481,15 +238,6 @@ class GeluGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
GreaterEqualOpCost
(
CompOpCost
):
...
...
@@ -498,15 +246,6 @@ class GreaterEqualOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
IncrementOpCost
(
CompOpCost
):
...
...
@@ -515,11 +254,6 @@ class IncrementOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
IsEmptyOpCost
(
CompOpCost
):
...
...
@@ -528,11 +262,6 @@ class IsEmptyOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LayerNormOpCost
(
CompOpCost
):
...
...
@@ -541,15 +270,6 @@ class LayerNormOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LayerNormGradOpCost
(
CompOpCost
):
...
...
@@ -558,15 +278,6 @@ class LayerNormGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LessThanOpCost
(
CompOpCost
):
...
...
@@ -575,15 +286,6 @@ class LessThanOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LogicalNotOpCost
(
CompOpCost
):
...
...
@@ -592,15 +294,6 @@ class LogicalNotOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LogicalAndOpCost
(
CompOpCost
):
...
...
@@ -609,15 +302,6 @@ class LogicalAndOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LodResetOpCost
(
CompOpCost
):
...
...
@@ -626,15 +310,6 @@ class LodResetOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LogOpCost
(
CompOpCost
):
...
...
@@ -643,15 +318,6 @@ class LogOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LookupTableV2OpCost
(
CompOpCost
):
...
...
@@ -660,15 +326,6 @@ class LookupTableV2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
LookupTableV2GradOpCost
(
CompOpCost
):
...
...
@@ -677,15 +334,6 @@ class LookupTableV2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
MatmulOpCost
(
CompOpCost
):
...
...
@@ -694,15 +342,6 @@ class MatmulOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
MatmulGradOpCost
(
CompOpCost
):
...
...
@@ -711,15 +350,6 @@ class MatmulGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
MatmulV2OpCost
(
CompOpCost
):
...
...
@@ -728,15 +358,6 @@ class MatmulV2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
MatmulV2GradOpCost
(
CompOpCost
):
...
...
@@ -745,15 +366,6 @@ class MatmulV2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
MemcpyOpCost
(
CompOpCost
):
...
...
@@ -762,15 +374,6 @@ class MemcpyOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
MulOpCost
(
CompOpCost
):
...
...
@@ -779,15 +382,6 @@ class MulOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
MulGradOpCost
(
CompOpCost
):
...
...
@@ -796,15 +390,6 @@ class MulGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
OneHotOpCost
(
CompOpCost
):
...
...
@@ -813,15 +398,6 @@ class OneHotOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ReadFromArrayOpCost
(
CompOpCost
):
...
...
@@ -830,15 +406,6 @@ class ReadFromArrayOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ReduceSumOpCost
(
CompOpCost
):
...
...
@@ -847,15 +414,6 @@ class ReduceSumOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ReduceSumGradOpCost
(
CompOpCost
):
...
...
@@ -864,15 +422,6 @@ class ReduceSumGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
Reshape2OpCost
(
CompOpCost
):
...
...
@@ -881,15 +430,6 @@ class Reshape2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
Reshape2GradOpCost
(
CompOpCost
):
...
...
@@ -898,15 +438,6 @@ class Reshape2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ReduceMeanOpCost
(
CompOpCost
):
...
...
@@ -915,15 +446,6 @@ class ReduceMeanOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ReduceMeanGradOpCost
(
CompOpCost
):
...
...
@@ -932,15 +454,6 @@ class ReduceMeanGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SamplingIdOpCost
(
CompOpCost
):
...
...
@@ -949,15 +462,6 @@ class SamplingIdOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
ScaleOpCost
(
CompOpCost
):
...
...
@@ -966,15 +470,6 @@ class ScaleOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SliceOpCost
(
CompOpCost
):
...
...
@@ -983,15 +478,6 @@ class SliceOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SoftmaxOpCost
(
CompOpCost
):
...
...
@@ -1000,15 +486,6 @@ class SoftmaxOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SoftmaxGradOpCost
(
CompOpCost
):
...
...
@@ -1017,15 +494,6 @@ class SoftmaxGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SoftmaxWithCrossEntropyOpCost
(
CompOpCost
):
...
...
@@ -1034,15 +502,6 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SoftmaxWithCrossEntropyGradOpCost
(
CompOpCost
):
...
...
@@ -1051,15 +510,6 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SplitOpCost
(
CompOpCost
):
...
...
@@ -1068,15 +518,6 @@ class SplitOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
Squeeze2OpCost
(
CompOpCost
):
...
...
@@ -1085,15 +526,6 @@ class Squeeze2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SquareOpCost
(
CompOpCost
):
...
...
@@ -1102,15 +534,6 @@ class SquareOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SquareGradOpCost
(
CompOpCost
):
...
...
@@ -1119,15 +542,6 @@ class SquareGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
SumOpCost
(
CompOpCost
):
...
...
@@ -1136,15 +550,6 @@ class SumOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
TopKOpCost
(
CompOpCost
):
...
...
@@ -1153,15 +558,6 @@ class TopKOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
Transpose2OpCost
(
CompOpCost
):
...
...
@@ -1170,15 +566,6 @@ class Transpose2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
Transpose2GradOpCost
(
CompOpCost
):
...
...
@@ -1187,15 +574,6 @@ class Transpose2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
Unsqueeze2OpCost
(
CompOpCost
):
...
...
@@ -1204,15 +582,6 @@ class Unsqueeze2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
class
WriteToArrayOpCost
(
CompOpCost
):
...
...
@@ -1220,12 +589,3 @@ class WriteToArrayOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
浏览文件 @
d6011cb6
...
...
@@ -189,6 +189,9 @@ class CostEstimator:
# Calc dist op cost
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
op_dist_attr
=
dist_op
.
dist_attr
processes
=
op_dist_attr
.
process_mesh
.
process_ids
...
...
@@ -225,6 +228,8 @@ class CostEstimator:
for
rank
in
group_ranks
:
self
.
local_cost
(
rank
).
time
=
(
max_time
+
comm_op_cost
.
time
if
op
.
attr
(
'op_role'
)
!=
OpRole
.
Backward
else
max_time
+
0.9
*
comm_op_cost
.
time
)
if
rank
not
in
self
.
_bubble_time_mapping
:
self
.
_bubble_time_mapping
[
rank
]
=
0
...
...
@@ -290,6 +295,7 @@ class CostEstimator:
self
.
_ordered_ops
.
append
([
op
.
desc
.
id
(),
op
])
self
.
_ordered_ops
.
sort
(
key
=
lambda
x
:
x
[
0
])
parameters
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
"create_py_reader"
,
...
...
@@ -298,11 +304,14 @@ class CostEstimator:
]:
continue
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
var_name
)
if
var_name
not
in
var_info
:
var_info
[
var_name
]
=
{}
key
=
_convert_pm_and_dm_to_str
(
...
...
@@ -311,6 +320,10 @@ class CostEstimator:
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
# It is even partition now
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_input
(
var_name
)
global_sizes
=
var
.
shape
...
...
@@ -324,9 +337,16 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
var
.
persistable
:
name
=
var_name
+
key
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
...
@@ -339,6 +359,10 @@ class CostEstimator:
)
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_output
(
var_name
)
global_sizes
=
var
.
shape
...
...
@@ -352,11 +376,19 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
var
.
persistable
:
name
=
var_name
+
key
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
has_used_vars
=
set
()
not_calc_vars
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
"create_py_reader"
,
...
...
@@ -367,6 +399,8 @@ class CostEstimator:
can_free_memories
=
{}
can_free_vars
=
set
()
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
...
...
@@ -378,24 +412,30 @@ class CostEstimator:
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_input
(
var_name
)
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
if
has_used_var
in
not_calc_vars
:
continue
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
has_used_var
not
in
can_free_vars
:
if
(
has_used_var
not
in
can_free_vars
and
not
var
.
persistable
):
can_free_vars
.
add
(
has_used_var
)
if
not
var
.
persistable
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
=
0
can_free_memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
can_free_memories
[
process
]
+=
var_info
[
var_name
]
[
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
...
@@ -406,25 +446,36 @@ class CostEstimator:
)
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_output
(
var_name
)
if
(
op
.
type
==
"reshape2"
or
op
.
type
==
"transpose2"
or
op
.
type
==
"elementwise_add"
):
not_calc_vars
.
add
(
has_used_var
)
continue
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
has_used_var
not
in
can_free_vars
:
if
(
has_used_var
not
in
can_free_vars
and
not
var
.
persistable
):
can_free_vars
.
add
(
has_used_var
)
if
not
var
.
persistable
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
=
0
can_free_memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
can_free_memories
[
process
]
+=
var_info
[
var_name
]
[
key
][
"memory"
]
# Calc peak memory
for
process
in
memories
:
...
...
@@ -433,7 +484,6 @@ class CostEstimator:
else
:
if
memories
[
process
]
>
self
.
max_memories
[
process
]:
self
.
max_memories
[
process
]
=
memories
[
process
]
# Free memory
for
process
in
can_free_memories
:
if
process
in
memories
:
...
...
@@ -513,7 +563,7 @@ class CostEstimator:
# Padding automatically
max_len
=
0
header
=
[
"Execution Time(
m
s)"
,
"Max Memory(MiB)"
]
header
=
[
"Execution Time(
u
s)"
,
"Max Memory(MiB)"
]
vals
=
[
round
(
self
.
global_cost
.
time
,
3
),
int
(
self
.
max_memory
//
1e6
)]
for
memory
in
vals
+
header
:
if
len
(
str
(
memory
))
>
max_len
:
...
...
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
d6011cb6
...
...
@@ -2716,6 +2716,8 @@ class Resharder:
)
# simplified processing: ignore union process mesh and output reshard
dist_op
=
self
.
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_tensor
or
not
dist_op
:
return
reshard_op_cost
dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
tensor
.
name
)
...
...
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
浏览文件 @
d6011cb6
...
...
@@ -16,17 +16,31 @@ import copy
import
logging
import
math
import
os
import
pickle
import
sys
import
time
from
abc
import
abstractmethod
from
collections
import
OrderedDict
from
functools
import
reduce
import
numpy
as
np
import
paddle
from
paddle.distributed.auto_parallel.cluster_v2
import
DeviceMesh
from
paddle.distributed.auto_parallel.completion
import
Completer
from
paddle.distributed.auto_parallel.cost
import
CostEstimator
from
paddle.distributed.auto_parallel.dist_attribute
import
(
OperatorDistAttr
,
TensorDistAttr
,
)
from
paddle.distributed.auto_parallel.dist_context
import
DistributedContext
from
paddle.distributed.auto_parallel.dist_tensor
import
DistributedTensor
from
paddle.distributed.auto_parallel.process_mesh
import
ProcessMesh
from
paddle.distributed.auto_parallel.utils
import
(
is_gradient_clip_op
,
print_program_with_dist_attr
,
)
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
from
paddle.fluid
import
program_guard
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.framework
import
Parameter
,
unique_name
...
...
@@ -1610,3 +1624,603 @@ class RuleBasedTuner:
idx
][
parallelism
][
key
]
self
.
_complete_sub_bwd_program
(
sub_program_dist_context
)
def
_complete_sub_update_program
(
self
,
sub_program_dist_context
):
"""
Complete the opt OP according to the tensor.
Most of the logic is the same as the update completion in the completer.
"""
world_ranks
=
ProcessMesh
(
[
i
for
i
in
range
(
self
.
_cluster
.
get_num_machines
()
*
self
.
_cluster
.
_num_devices_per_machine
)
]
)
dist_tensors
=
sub_program_dist_context
.
_dist_tensors_for_program
vars
=
self
.
full_main_program
.
global_block
().
vars
ops
=
self
.
full_main_program
.
global_block
().
ops
learning_rate_completed
=
False
for
idx
in
range
(
len
(
ops
)):
op
=
ops
[
idx
]
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Optimize
):
if
is_gradient_clip_op
(
op
):
if
op
.
type
in
[
"sum"
,
"sqrt"
,
"fill_constant"
,
"elementwise_max"
,
"elementwise_div"
,
]:
op_dist_attr
=
OperatorDistAttr
()
op_dist_attr
.
process_mesh
=
world_ranks
for
in_name
in
op
.
input_arg_names
:
in_var
=
vars
[
in_name
]
if
in_var
.
desc
.
original_id
()
in
dist_tensors
:
in_dist_attr
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
in_var
)
op_dist_attr
.
set_input_dist_attr
(
in_name
,
in_dist_attr
)
else
:
in_dist_attr
=
TensorDistAttr
()
in_dist_attr
.
process_mesh
=
world_ranks
in_dist_attr
.
dims_mapping
=
[
-
1
for
_
in
range
(
len
(
in_var
.
shape
))
]
op_dist_attr
.
set_input_dist_attr
(
in_name
,
in_dist_attr
)
sub_program_dist_context
.
set_tensor_dist_attr_for_program
(
in_var
,
in_dist_attr
)
for
out_name
in
op
.
output_arg_names
:
out_var
=
vars
[
out_name
]
if
out_var
.
desc
.
original_id
()
in
dist_tensors
:
out_dist_attr
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
out_var
)
op_dist_attr
.
set_output_dist_attr
(
out_name
,
out_dist_attr
)
else
:
out_dist_attr
=
TensorDistAttr
()
out_dist_attr
.
process_mesh
=
world_ranks
out_dist_attr
.
dims_mapping
=
[
-
1
for
_
in
range
(
len
(
out_var
.
shape
))
]
sub_program_dist_context
.
set_tensor_dist_attr_for_program
(
out_var
,
out_dist_attr
)
op_dist_attr
.
set_output_dist_attr
(
out_name
,
out_dist_attr
)
sub_program_dist_context
.
set_op_dist_attr_for_program
(
op
,
op_dist_attr
)
else
:
in_var
=
vars
[
op
.
input
(
"X"
)[
0
]]
if
in_var
.
desc
.
original_id
()
in
dist_tensors
:
in_dist_attr
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
in_var
)
assert
in_dist_attr
is
not
None
ref_process_mesh
=
in_dist_attr
.
process_mesh
ref_dims_mapping
=
in_dist_attr
.
dims_mapping
if
(
op
.
type
==
"cast"
and
ops
[
idx
+
1
].
type
==
"elementwise_mul"
):
ref_var
=
vars
[
ops
[
idx
+
1
].
input
(
"X"
)[
0
]]
ref_dist_attr
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
ref_var
)
assert
ref_dist_attr
is
not
None
ref_process_mesh
=
ref_dist_attr
.
process_mesh
out_var
=
vars
[
op
.
output
(
"Out"
)[
0
]]
out_dist_attr
=
TensorDistAttr
()
out_dist_attr
.
process_mesh
=
ref_process_mesh
if
out_var
.
shape
==
in_var
.
shape
:
out_dist_attr
.
dims_mapping
=
ref_dims_mapping
else
:
assert
(
len
(
out_var
.
shape
)
==
1
and
out_var
.
shape
[
0
]
==
1
)
out_dist_attr
.
dims_mapping
=
[
-
1
]
sub_program_dist_context
.
set_tensor_dist_attr_for_program
(
out_var
,
out_dist_attr
)
op_dist_attr
=
OperatorDistAttr
()
op_dist_attr
.
process_mesh
=
ref_process_mesh
for
in_name
in
op
.
input_arg_names
:
in_var
=
vars
[
in_name
]
in_dist_attr
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
in_var
)
op_dist_attr
.
set_input_dims_mapping
(
in_name
,
in_dist_attr
.
dims_mapping
)
for
out_name
in
op
.
output_arg_names
:
out_var
=
vars
[
out_name
]
out_dist_attr
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
out_var
)
op_dist_attr
.
set_output_dims_mapping
(
out_name
,
out_dist_attr
.
dims_mapping
)
op_dist_attr
.
set_input_dist_attr
(
in_var
.
name
,
in_dist_attr
)
op_dist_attr
.
set_output_dist_attr
(
out_var
.
name
,
out_dist_attr
)
sub_program_dist_context
.
set_op_dist_attr_for_program
(
op
,
op_dist_attr
)
else
:
continue
if
"Grad"
in
op
.
input_names
and
"Param"
in
ops
[
idx
].
input_names
:
assert
(
len
(
op
.
input
(
"Param"
))
==
1
),
"Only support one-to-one now."
assert
(
len
(
op
.
input
(
"Grad"
))
==
1
),
"Only support one-to-one now."
param
=
vars
[
op
.
input
(
"Param"
)[
0
]]
grad_var
=
vars
[
op
.
input
(
"Grad"
)[
0
]]
if
param
.
desc
.
original_id
()
in
dist_tensors
:
param_dist_attr
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
param
)
assert
param_dist_attr
is
not
None
ref_process_mesh
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
param
).
process_mesh
assert
ref_process_mesh
is
not
None
ref_dims_mapping
=
sub_program_dist_context
.
get_tensor_dist_attr_for_program
(
param
).
dims_mapping
assert
ref_dims_mapping
is
not
None
op_dist_attr
=
OperatorDistAttr
()
op_dist_attr
.
process_mesh
=
ref_process_mesh
op_dist_attr
.
set_input_dims_mapping
(
grad_var
.
name
,
ref_dims_mapping
)
op_dist_attr
.
set_input_dims_mapping
(
param
.
name
,
ref_dims_mapping
)
op_dist_attr
.
set_output_dims_mapping
(
param
.
name
,
ref_dims_mapping
)
learning_var
=
vars
[
op
.
input
(
"LearningRate"
)[
0
]]
op_dist_attr
.
set_input_dims_mapping
(
learning_var
.
name
,
[
-
1
]
)
op_dist_attr
.
set_output_dims_mapping
(
learning_var
.
name
,
[
-
1
]
)
if
not
learning_rate_completed
:
learning_rate_completed
=
True
var_dist_attr
=
TensorDistAttr
()
var_dist_attr
.
process_mesh
=
world_ranks
var_dist_attr
.
dims_mapping
=
[
-
1
]
sub_program_dist_context
.
set_tensor_dist_attr_for_program
(
learning_var
,
var_dist_attr
)
for
input_name
in
op
.
desc
.
input_names
():
if
input_name
in
[
'Param'
,
'Grad'
,
'LearningRate'
,
"SkipUpdate"
,
"Beta1Tensor"
,
"Beta2Tensor"
,
"EpsilonTensor"
,
]:
continue
if
len
(
op
.
desc
.
input
(
input_name
))
==
0
:
continue
assert
len
(
op
.
desc
.
input
(
input_name
))
==
1
input_var
=
vars
[
op
.
desc
.
input
(
input_name
)[
0
]]
input_var_attr
=
TensorDistAttr
()
if
(
"Beta1Pow"
in
input_name
or
"Beta2Pow"
in
input_name
):
input_var_attr
.
dims_mapping
=
[
-
1
]
op_dist_attr
.
set_input_dims_mapping
(
input_var
.
name
,
[
-
1
]
)
op_dist_attr
.
set_output_dims_mapping
(
input_var
.
name
,
[
-
1
]
)
else
:
input_var_attr
.
dims_mapping
=
ref_dims_mapping
op_dist_attr
.
set_input_dims_mapping
(
input_var
.
name
,
ref_dims_mapping
)
op_dist_attr
.
set_output_dims_mapping
(
input_var
.
name
,
ref_dims_mapping
)
input_var_attr
.
process_mesh
=
ref_process_mesh
sub_program_dist_context
.
set_tensor_dist_attr_for_program
(
input_var
,
input_var_attr
)
sub_program_dist_context
.
set_op_dist_attr_for_program
(
op
,
op_dist_attr
)
continue
else
:
continue
def
complete_sub_update_programs
(
self
):
for
idx
in
self
.
sub_programs_dist_context
:
for
parallelism
in
self
.
sub_programs_dist_context
[
idx
]:
for
key
in
self
.
sub_programs_dist_context
[
idx
][
parallelism
]:
sub_program_dist_context
=
self
.
sub_programs_dist_context
[
idx
][
parallelism
][
key
]
self
.
_complete_sub_update_program
(
sub_program_dist_context
)
def
convert_device_mesh_to_key
(
self
,
device_mesh
):
"""Convert device mesh object to str."""
processes
=
","
.
join
([
str
(
x
)
for
x
in
device_mesh
.
device_ids
])
topology
=
","
.
join
([
str
(
x
)
for
x
in
device_mesh
.
shape
])
key
=
processes
+
";"
+
topology
return
key
def
_get_sub_program_cost
(
self
,
dist_context
):
"""Estimate the cost of dist context."""
cost_estimator
=
CostEstimator
(
self
.
full_main_program
,
self
.
_cluster
)
global_cost
=
cost_estimator
.
estimate
(
dist_context
)
max_memory
=
cost_estimator
.
_estimate_max_memory_by_dist_op
(
dist_context
)
return
global_cost
.
time
,
max_memory
def
combine_dist_contexts
(
self
,
dist_contexts
):
"""Combine the dist attr in dist contexts to one dist context."""
combined_dist_context
=
DistributedContext
()
# set dist tensor, pay attention to shared param or var as input for multi op
for
dist_context
in
dist_contexts
:
for
tensor_id
in
dist_context
.
_dist_tensors_for_program
:
dist_tensor
=
dist_context
.
_dist_tensors_for_program
[
tensor_id
]
if
(
tensor_id
not
in
combined_dist_context
.
_dist_tensors_for_program
):
combined_dist_context
.
add_dist_tensor_for_program
(
dist_tensor
)
# set dist op
for
op_id
in
dist_context
.
_dist_ops_for_program
:
dist_op
=
dist_context
.
_dist_ops_for_program
[
op_id
]
combined_dist_context
.
add_dist_op_for_program
(
dist_op
)
for
process_mesh
in
dist_context
.
process_meshes
:
combined_dist_context
.
add_process_mesh
(
process_mesh
)
return
combined_dist_context
def
prepare
(
self
):
"""Prepare the sub program, tensor dist attr setting, device meshes and so on that tuner need."""
# step1: cluster operators to layers
begin
=
time
.
time
()
self
.
layers
=
self
.
cluster_operators
()
end
=
time
.
time
()
self
.
_logger
.
info
(
"Cluster operators to {} layers in {}s."
.
format
(
len
(
self
.
layers
),
end
-
begin
)
)
# step2: generate sub program of each layer
begin
=
time
.
time
()
self
.
gen_fwd_sub_programs_by_clone
()
end
=
time
.
time
()
self
.
_logger
.
info
(
"Generate programs of every layer in {}s."
.
format
(
end
-
begin
)
)
# step3: partition devices to device meshes
begin
=
time
.
time
()
n
,
m
=
(
self
.
_cluster
.
get_num_machines
(),
self
.
_cluster
.
_num_devices_per_machine
,
)
device_meshes_list
=
ClusterPartitionUtil
.
partition_cluster
(
n
,
m
)
end
=
time
.
time
()
self
.
_logger
.
info
(
"Partition cluster in {}s."
.
format
(
end
-
begin
))
# step4: transform device mesh to process meshes
dm_idx
=
0
for
device_meshes
in
device_meshes_list
:
has_used_devices
=
0
self
.
device_meshes_list
.
append
([])
for
device_mesh
in
device_meshes
:
devices
=
reduce
(
lambda
x
,
y
:
x
*
y
,
device_mesh
)
processes
=
[
i
for
i
in
range
(
has_used_devices
,
has_used_devices
+
devices
)
]
device_mesh_shape
=
(
device_mesh
if
device_mesh
[
0
]
!=
1
else
[
device_mesh
[
i
]
for
i
in
range
(
1
,
len
(
device_mesh
))]
)
self
.
device_meshes_list
[
-
1
].
append
(
DeviceMesh
(
mesh
=
np
.
array
(
processes
)
.
reshape
(
device_mesh_shape
)
.
tolist
(),
name
=
"device_mesh_"
+
str
(
dm_idx
),
)
)
dm_idx
+=
1
has_used_devices
+=
devices
process_mesh_shapes
=
convert_to_process_meshes
(
device_mesh
)
for
process_mesh_shape
in
process_mesh_shapes
:
process_mesh
=
ProcessMesh
(
np
.
array
(
processes
).
reshape
(
process_mesh_shape
).
tolist
()
)
if
process_mesh
not
in
self
.
process_meshes
:
self
.
process_meshes
.
append
(
process_mesh
)
# step5: generate full program
begin
=
time
.
time
()
self
.
gen_full_program
()
end
=
time
.
time
()
self
.
_logger
.
info
(
"Generate full program in {}s."
.
format
(
end
-
begin
))
# step6: complete forward sub programs
begin
=
time
.
time
()
for
process_mesh
in
self
.
process_meshes
:
self
.
complete_sub_fwd_programs
(
process_mesh
)
end
=
time
.
time
()
self
.
_logger
.
info
(
"Complete all sub forward programs in {}s."
.
format
(
end
-
begin
)
)
if
self
.
mode
==
"train"
:
# step7: complete backward sub programs
begin
=
time
.
time
()
self
.
complete_sub_bwd_programs
()
end
=
time
.
time
()
self
.
_logger
.
info
(
"Complete all sub backward programs in {}s."
.
format
(
end
-
begin
)
)
# step8: complete update sub programs
begin
=
time
.
time
()
self
.
complete_sub_update_programs
()
end
=
time
.
time
()
self
.
_logger
.
info
(
"Complete all sub update programs in {}s."
.
format
(
end
-
begin
)
)
def
tune_o1
(
self
):
"""The o1 level tuning."""
best_cost
=
sys
.
maxsize
best_dist_context
=
None
for
device_meshes
in
self
.
device_meshes_list
:
pp_stages
=
len
(
device_meshes
)
average_layers
=
len
(
self
.
layers
)
//
pp_stages
device_mesh_shape
=
device_meshes
[
0
].
shape
if
len
(
device_mesh_shape
)
==
1
:
device_mesh_shape
.
insert
(
0
,
1
)
process_mesh_shapes
=
convert_to_process_meshes
(
device_mesh_shape
)
# For example, device_mesh is [1, 8] and process_mesh is [8].
# The selective parallelism is dp or mp
# Get dp8 or mp8 cost and compare them to get best sreategy.
for
parallelism
in
[
"dp"
,
"mp"
,
"dp_mp"
,
"mp_dp"
]:
for
process_mesh_shape
in
process_mesh_shapes
:
dist_context_of_device_meshes
=
None
for
idx
,
device_mesh
in
enumerate
(
device_meshes
):
device_mesh_shape
=
device_mesh
.
shape
process_mesh
=
ProcessMesh
(
np
.
array
(
device_mesh
.
device_ids
)
.
reshape
(
process_mesh_shape
)
.
tolist
()
)
selective_parallelisms
=
(
[
"dp"
,
"mp"
]
if
len
(
process_mesh
.
shape
)
==
1
else
[
"dp_mp"
,
"mp_dp"
]
)
if
parallelism
not
in
selective_parallelisms
:
total_cost_of_device_meshes
=
sys
.
maxsize
continue
key
=
self
.
convert_process_mesh_to_key
(
process_mesh
)
if
idx
==
len
(
device_meshes
)
-
1
:
start
=
idx
*
average_layers
end
=
len
(
self
.
layers
)
else
:
start
=
idx
*
average_layers
end
=
(
idx
+
1
)
*
average_layers
dist_context
=
self
.
combine_dist_contexts
(
[
self
.
sub_programs_dist_context
[
j
][
parallelism
][
key
]
for
j
in
range
(
start
,
end
)
]
)
dist_context_of_device_meshes
=
(
dist_context
if
dist_context_of_device_meshes
is
None
else
self
.
combine_dist_contexts
(
[
dist_context_of_device_meshes
,
dist_context
]
)
)
if
dist_context_of_device_meshes
is
not
None
:
cost
,
memory
=
self
.
_get_sub_program_cost
(
dist_context_of_device_meshes
)
self
.
_logger
.
info
(
"Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages."
.
format
(
memory
/
(
1024
**
3
),
cost
,
parallelism
,
process_mesh_shape
,
len
(
device_meshes
),
)
)
# 15% buffer is reserved for memory cost
if
memory
>
0.85
*
self
.
cluster
.
machines
[
0
].
devices
[
0
].
memory
*
(
1024
**
3
):
cost
=
sys
.
maxsize
if
cost
<
best_cost
:
best_cost
=
cost
best_dist_context
=
dist_context_of_device_meshes
self
.
_logger
.
info
(
"O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB."
.
format
(
parallelism
,
process_mesh_shape
,
len
(
device_meshes
),
memory
/
(
1024
**
3
),
)
)
return
best_dist_context
def
tune_o2
(
self
):
return
None
def
save_strategy
(
self
,
best_dist_context
,
path
):
dist_attrs
=
{
"tensor"
:
{},
"op"
:
{},
"process_meshes"
:
[]}
for
key
in
best_dist_context
.
_dist_tensors_for_program
:
if
key
in
self
.
_dist_context
.
_dist_tensors_for_program
:
dist_tensor
=
best_dist_context
.
_dist_tensors_for_program
[
key
]
dist_attrs
[
"tensor"
][
key
]
=
dist_tensor
.
dist_attr
.
serialize_to_string
()
assert
dist_attrs
[
"tensor"
],
"Tensor dist attrs must not be None."
for
key
in
best_dist_context
.
_dist_ops_for_program
:
if
key
in
self
.
_dist_context
.
_dist_ops_for_program
:
dist_op
=
best_dist_context
.
_dist_ops_for_program
[
key
]
dist_attrs
[
"op"
][
key
]
=
dist_op
.
dist_attr
.
serialize_to_string
()
assert
dist_attrs
[
"op"
],
"Op dist attrs must not be None."
for
process_mesh
in
best_dist_context
.
_process_meshes
:
process_ids
=
process_mesh
.
process_ids
process_shape
=
process_mesh
.
shape
dist_attrs
[
"process_meshes"
].
append
([
process_ids
,
process_shape
])
dist_attrs
[
"cluster"
]
=
self
.
_cluster
with
open
(
path
,
'wb'
)
as
f
:
pickle
.
dump
(
dist_attrs
,
f
)
self
.
_logger
.
info
(
"The strategy has been saved at {}"
.
format
(
path
))
def
run_or_quit
(
self
):
# Quit if just tune
if
not
self
.
_is_run
:
self
.
_logger
.
info
(
"The process will be quitted when just tune not run."
)
quit
()
def
tune
(
self
):
begin
=
time
.
time
()
self
.
match_program
(
self
.
_dist_context
.
serial_main_program
)
end
=
time
.
time
()
self
.
_logger
.
info
(
"Pattern match in {}s."
.
format
(
end
-
begin
))
if
self
.
_use_dp
:
completer
=
Completer
(
self
.
_dist_context
)
completer
.
complete_forward_annotation
()
print_program_with_dist_attr
(
self
.
_dist_context
.
serial_main_program
,
self
.
_dist_context
)
# Save strategy if need
path
=
self
.
_strategy_path
if
path
:
self
.
save_strategy
(
self
.
_dist_context
,
path
)
self
.
run_or_quit
()
return
# prepare
self
.
prepare
()
best_dist_context
=
None
if
self
.
level
==
"o2"
:
best_dist_context
=
self
.
tune_o2
()
elif
self
.
level
==
"o1"
:
# If level is o1, it means all layers within same parallelism.
# When in pipeline parallism, it means that place layers evenly.
use_o2_level
=
False
for
device_meshes
in
self
.
device_meshes_list
:
if
len
(
device_meshes
)
>
1
:
shape
=
None
for
device_mesh
in
device_meshes
:
if
shape
is
None
:
shape
=
device_mesh
.
shape
continue
else
:
if
shape
!=
device_mesh
.
shape
:
self
.
_logger
.
info
(
"Warning: The o1 level is not be supported when the number of machines is prime numer which greaters than 1. We will use o2 level to tune."
)
use_o2_level
=
True
break
if
use_o2_level
:
best_dist_context
=
self
.
tune_o2
()
else
:
best_dist_context
=
self
.
tune_o1
()
assert
(
best_dist_context
is
not
None
),
"can not find a parallel strategy to run, please use passes such as recompute, amp or sharding."
for
key
in
best_dist_context
.
_dist_tensors_for_program
:
if
key
in
self
.
_dist_context
.
_dist_tensors_for_program
:
self
.
_dist_context
.
_dist_tensors_for_program
[
key
]
=
best_dist_context
.
_dist_tensors_for_program
[
key
]
for
key
in
best_dist_context
.
_dist_ops_for_program
:
if
key
in
self
.
_dist_context
.
_dist_ops_for_program
:
self
.
_dist_context
.
_dist_ops_for_program
[
key
]
=
best_dist_context
.
_dist_ops_for_program
[
key
]
self
.
_dist_context
.
_process_meshes
=
best_dist_context
.
_process_meshes
end
=
time
.
time
()
self
.
_logger
.
info
(
"Rule-based tuner end in {}s."
.
format
(
end
-
begin
))
self
.
_logger
.
info
(
"The best strategy found is as follows: "
)
print_program_with_dist_attr
(
self
.
full_main_program
,
best_dist_context
)
# Save strategy if need
path
=
self
.
_strategy_path
if
path
:
self
.
save_strategy
(
best_dist_context
,
path
)
self
.
run_or_quit
()
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
浏览文件 @
d6011cb6
...
...
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
modeling
.
init_global
()
train_program
=
static
.
Program
()
start_program
=
static
.
Program
()
place
=
paddle
.
set_device
(
"gpu"
)
batch_size
=
8
sequence_len
=
512
vocab_size
=
1000
place
=
None
train_program
,
start_program
,
loss
,
gen_data
=
get_gpt_model
(
train_program
,
start_program
,
...
...
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len
,
vocab_size
,
)
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.dist_context
import
(
DistributedContext
,
)
from
paddle.distributed.auto_parallel.process_mesh
import
ProcessMesh
from
paddle.distributed.auto_parallel.tuner.rule_based_tuner
import
(
RuleBasedTuner
,
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.2
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
cluster
=
Cluster
()
cluster
.
gen_default_config_cluster
(
node_count
=
1
,
device_count
=
8
)
dist_context
=
DistributedContext
(
serial_main_prog
=
train_program
,
serial_startup_prog
=
start_program
,
serial_optimizer
=
opt
,
serial_loss
=
loss
,
cluster
=
cluster
,
)
dist_context
.
initialize
()
tuner
=
RuleBasedTuner
(
dist_context
)
tuner
.
cluster_operators
()
tuner
.
gen_full_program
()
tuner
.
match_program
(
tuner
.
_dist_context
.
serial_main_program
)
process_mesh
=
ProcessMesh
([
0
,
1
])
tuner
.
gen_fwd_sub_programs_by_clone
()
tuner
.
complete_sub_fwd_programs
(
process_mesh
)
tuner
.
complete_sub_bwd_programs
()
tuner
.
tune
()
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录