Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d6011cb6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d6011cb6
编写于
3月 28, 2023
作者:
C
caozhou
提交者:
GitHub
3月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Add o1 level tune (#52041)
* add tune o1 level * add unittest
上级
418b983c
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
820 addition
and
721 deletion
+820
-721
python/paddle/distributed/auto_parallel/cluster.py
python/paddle/distributed/auto_parallel/cluster.py
+76
-14
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+38
-25
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+1
-1
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+0
-640
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
...on/paddle/distributed/auto_parallel/cost/estimate_cost.py
+82
-32
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+2
-0
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
...addle/distributed/auto_parallel/tuner/rule_based_tuner.py
+614
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
...id/tests/unittests/auto_parallel/test_rule_based_tuner.py
+7
-9
未找到文件。
python/paddle/distributed/auto_parallel/cluster.py
浏览文件 @
d6011cb6
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
json
import
json
import
os
import
os
import
re
from
enum
import
IntEnum
,
unique
from
enum
import
IntEnum
,
unique
import
paddle
import
paddle
...
@@ -449,7 +450,6 @@ class Cluster:
...
@@ -449,7 +450,6 @@ class Cluster:
npu_models
=
[
"NPU"
]
npu_models
=
[
"NPU"
]
dcu_models
=
[
"DCU"
]
dcu_models
=
[
"DCU"
]
all_gpu_models
=
gpu_models
+
xpu_models
+
npu_models
+
dcu_models
all_gpu_models
=
gpu_models
+
xpu_models
+
npu_models
+
dcu_models
assert
gpu_model
in
all_gpu_models
self
.
_num_devices_per_machine
=
device_count
self
.
_num_devices_per_machine
=
device_count
def
_convert_to_type
(
gpu_model
):
def
_convert_to_type
(
gpu_model
):
...
@@ -462,6 +462,8 @@ class Cluster:
...
@@ -462,6 +462,8 @@ class Cluster:
type
=
"NPU"
type
=
"NPU"
elif
gpu_model
in
dcu_models
:
elif
gpu_model
in
dcu_models
:
type
=
"DCU"
type
=
"DCU"
else
:
type
=
"GPU"
assert
type
is
not
None
assert
type
is
not
None
return
type
return
type
...
@@ -470,6 +472,12 @@ class Cluster:
...
@@ -470,6 +472,12 @@ class Cluster:
model
=
None
model
=
None
if
gpu_model
==
"V100"
:
if
gpu_model
==
"V100"
:
model
=
"Tesla V100-SXM2-"
+
str
(
gpu_memory
)
+
"GB"
model
=
"Tesla V100-SXM2-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A100"
:
model
=
"Tesla A100-SXM-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A30"
:
model
=
"Tesla A30-SXM-"
+
str
(
gpu_memory
)
+
"GB"
else
:
model
=
gpu_model
+
str
(
gpu_memory
)
+
"GB"
assert
model
is
not
None
assert
model
is
not
None
return
model
return
model
...
@@ -527,6 +535,8 @@ class Cluster:
...
@@ -527,6 +535,8 @@ class Cluster:
device
[
"memory"
]
=
memory
device
[
"memory"
]
=
memory
device
[
"sp_gflops"
]
=
sp_gflops
device
[
"sp_gflops"
]
=
sp_gflops
device
[
"dp_gflops"
]
=
dp_gflops
device
[
"dp_gflops"
]
=
dp_gflops
# hard code
device
[
"type"
]
=
"GPU"
global_id_to_device_type
[
global_id
]
=
type
global_id_to_device_type
[
global_id
]
=
type
global_id_to_node
[
global_id
]
=
i
global_id_to_node
[
global_id
]
=
i
devices
.
append
(
device
)
devices
.
append
(
device
)
...
@@ -820,30 +830,82 @@ class Cluster:
...
@@ -820,30 +830,82 @@ class Cluster:
return
self
.
__str__
()
return
self
.
__str__
()
def
get_default_cluster
():
def
get_default_cluster
(
json_config
=
None
):
def
is_by_json_config
(
json_config
):
if
not
json_config
:
return
False
if
"cluster"
not
in
json_config
:
return
False
else
:
if
"path"
not
in
json_config
[
"cluster"
]:
if
"num_nodes"
not
in
json_config
[
"cluster"
]:
return
False
if
"num_gpus"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_model"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_memory"
not
in
json_config
[
"cluster"
]:
return
False
return
True
else
:
return
True
cluster
=
Cluster
()
cluster
=
Cluster
()
local_device_count
=
os
.
getenv
(
"PADDLE_LOCAL_SIZE"
)
if
json_config
and
is_by_json_config
(
json_config
):
if
local_device_count
is
None
:
# Get GPU info by json config
local_device_count
=
1
if
"path"
in
json_config
[
"cluster"
]:
else
:
cluster
.
build_from_file
(
json_config
[
"cluster"
][
"path"
])
local_device_count
=
int
(
local_device_count
)
return
cluster
global_device_count
=
os
.
getenv
(
"PADDLE_GLOBAL_SIZE"
)
else
:
if
global_device_count
is
None
:
node_count
=
json_config
[
"cluster"
][
"num_nodes"
]
node_count
=
1
local_device_count
=
json_config
[
"cluster"
][
"num_gpus"
]
gpu_model
=
json_config
[
"cluster"
][
"gpu_model"
]
memory
=
json_config
[
"cluster"
][
"gpu_memory"
]
else
:
else
:
global_device_count
=
int
(
global_device_count
)
# Get GPU info by get_device_properties
assert
global_device_count
%
local_device_count
==
0
local_device_count
=
os
.
getenv
(
"PADDLE_LOCAL_SIZE"
)
node_count
=
int
(
global_device_count
)
//
local_device_count
if
local_device_count
is
None
:
local_device_count
=
1
else
:
local_device_count
=
int
(
local_device_count
)
global_device_count
=
os
.
getenv
(
"PADDLE_GLOBAL_SIZE"
)
if
global_device_count
is
None
:
node_count
=
1
else
:
global_device_count
=
int
(
global_device_count
)
assert
global_device_count
%
local_device_count
==
0
node_count
=
int
(
global_device_count
)
//
local_device_count
gpu_info
=
paddle
.
device
.
cuda
.
get_device_properties
()
assert
gpu_info
,
"Auto parallel just runs on gpu now."
gpu_name
=
gpu_info
.
name
try
:
re_result
=
re
.
split
(
r
'[ , -]'
,
gpu_name
)
gpu_model
=
re_result
[
1
]
memory
=
int
(
re_result
[
-
1
][:
-
2
])
except
:
memory
=
int
(
gpu_info
.
total_memory
)
//
(
1000
**
3
)
gpu_model
=
gpu_name
print
(
print
(
"Node Count: "
,
"Node Count: "
,
node_count
,
node_count
,
"Local Device Size: "
,
"Local Device Size: "
,
local_device_count
,
local_device_count
,
"GPU Model: "
,
gpu_model
,
"GPU Memory: "
,
memory
,
"World size: "
,
"World size: "
,
paddle
.
distributed
.
get_world_size
(),
paddle
.
distributed
.
get_world_size
(),
flush
=
True
,
flush
=
True
,
)
)
cluster
.
gen_default_config_cluster
(
cluster
.
gen_default_config_cluster
(
node_count
=
node_count
,
device_count
=
local_device_count
node_count
=
node_count
,
device_count
=
local_device_count
,
gpu_model
=
gpu_model
,
gpu_memory
=
memory
,
)
)
return
cluster
return
cluster
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
d6011cb6
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
from
functools
import
reduce
from
functools
import
reduce
import
paddle
import
paddle
from
paddle.utils.flops
import
flops
from
..cluster
import
LinkType
from
..cluster
import
LinkType
from
..dist_tensor
import
DistributedTensor
from
..dist_tensor
import
DistributedTensor
...
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
...
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
output_desc
=
OrderedDict
()
output_desc
=
OrderedDict
()
# Get partitioned shape of input
# Get partitioned shape of input
input_var_desc
=
{}
for
input_name
in
op
.
input_names
:
for
input_name
in
op
.
input_names
:
var_name_list
=
op
.
input
(
input_name
)
var_name_list
=
op
.
input
(
input_name
)
var_desc
=
[]
input_var_desc
[
input_name
]
=
[]
for
var_name
in
var_name_list
:
for
var_name
in
var_name_list
:
var
=
get_var_with_recursion
(
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
var_name
,
op
.
block
,
op
.
block
.
program
...
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
...
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process
,
process
,
shard_sizes
,
shard_sizes
,
)
)
var_desc
.
append
((
var
.
dtype
,
shape
)
)
input_var_desc
[
input_name
].
append
(
shape
)
# For special op such as embedding and its grad op
# For special op such as embedding and its grad op
if
(
if
(
...
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
...
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
relative_idx
=
relative_idx
*
per_part_size
relative_idx
=
relative_idx
*
per_part_size
desc
[
"attrs"
][
"start_index"
]
=
relative_idx
desc
[
"attrs"
][
"start_index"
]
=
relative_idx
input_desc
[
input_name
]
=
var_desc
desc
[
"inputs"
]
=
input_var_desc
desc
[
"inputs"
]
=
input_desc
for
out_name
in
op
.
output_names
:
for
out_name
in
op
.
output_names
:
var_name_list
=
op
.
output
(
out_name
)
var_name_list
=
op
.
output
(
out_name
)
...
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
...
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
return
desc
return
desc
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
):
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
,
is_dp
=
False
):
"""Build comm costs by descriptions"""
"""Build comm costs by descriptions"""
comm_context
=
CommContext
(
cluster
)
comm_context
=
CommContext
(
cluster
)
group_ranks_list
=
[]
group_ranks_list
=
[]
...
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
...
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
comm_op_cost
=
op_cost_class
(
comm_op_cost
=
op_cost_class
(
op_desc
=
desc
,
comm_context
=
comm_context
op_desc
=
desc
,
comm_context
=
comm_context
)
)
if
is_dp
:
comm_op_cost
.
cost
.
time
*=
0.9
comm_op_cost_list
.
append
(
comm_op_cost
)
comm_op_cost_list
.
append
(
comm_op_cost
)
return
comm_op_cost_list
return
comm_op_cost_list
...
@@ -389,6 +394,7 @@ def build_dp_costs(
...
@@ -389,6 +394,7 @@ def build_dp_costs(
vars
=
dist_op
.
serial_op
.
block
.
vars
vars
=
dist_op
.
serial_op
.
block
.
vars
var_name
=
var_names
[
0
]
var_name
=
var_names
[
0
]
has_found
=
False
has_found
=
False
is_input
=
True
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
if
var_name
in
name
:
if
var_name
in
name
:
var_name
=
name
var_name
=
name
...
@@ -400,6 +406,7 @@ def build_dp_costs(
...
@@ -400,6 +406,7 @@ def build_dp_costs(
if
var_name
in
name
:
if
var_name
in
name
:
var_name
=
name
var_name
=
name
has_found
=
True
has_found
=
True
is_input
=
False
break
break
if
not
has_found
:
if
not
has_found
:
return
return
...
@@ -418,6 +425,7 @@ def build_dp_costs(
...
@@ -418,6 +425,7 @@ def build_dp_costs(
processes
,
processes
,
c_allreduce_sum_descs
,
c_allreduce_sum_descs
,
cluster
,
cluster
,
is_dp
=
True
,
)
)
result
.
append
(
comm_cost_list
)
result
.
append
(
comm_cost_list
)
...
@@ -431,22 +439,11 @@ def build_dp_costs(
...
@@ -431,22 +439,11 @@ def build_dp_costs(
desc
=
{}
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"op"
]
=
op_type
desc
[
"inputs"
]
=
{}
desc
[
"inputs"
]
=
{}
if
var_name
in
dist_attr
.
inputs_dist_attrs
:
dims_mapping
=
(
dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_name
)
dist_attr
.
get_input_dims_mapping
(
var_name
)
elif
var_name
in
dist_attr
.
outputs_dist_attrs
:
if
is_input
dims_mapping
=
dist_attr
.
get_output_dims_mapping
(
var_name
)
else
dist_attr
.
get_output_dims_mapping
(
var_name
)
else
:
)
raise
AssertionError
(
"cannot find dims_mapping for {} in {}"
.
format
(
var_name
,
dist_attr
)
)
# dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name)
# if dist_attr.get_input_dims_mapping(var_name) is not None
# else dist_attr.get_output_dims_mapping(var_name)
# )
var
=
get_var_with_recursion
(
var
=
get_var_with_recursion
(
var_name
,
var_name
,
dist_op
.
serial_op
.
block
,
dist_op
.
serial_op
.
block
,
...
@@ -493,8 +490,6 @@ class CommContext:
...
@@ -493,8 +490,6 @@ class CommContext:
# if cluster has no info about those vars, it will be set by default
# if cluster has no info about those vars, it will be set by default
self
.
base_ring
=
None
self
.
base_ring
=
None
self
.
base_tree
=
None
self
.
base_tree
=
None
# self.base_inter_ring = None
# self.base_inter_tree = None
self
.
intra_ring
=
None
self
.
intra_ring
=
None
self
.
intra_tree
=
None
self
.
intra_tree
=
None
self
.
inter_ring
=
None
self
.
inter_ring
=
None
...
@@ -508,8 +503,6 @@ class CommContext:
...
@@ -508,8 +503,6 @@ class CommContext:
# set default
# set default
self
.
base_ring
=
8.4
self
.
base_ring
=
8.4
self
.
base_tree
=
0.0
self
.
base_tree
=
0.0
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVL in default
# NVL in default
self
.
intra_ring
=
3.4
self
.
intra_ring
=
3.4
self
.
intra_tree
=
28
self
.
intra_tree
=
28
...
@@ -681,6 +674,8 @@ class Cost:
...
@@ -681,6 +674,8 @@ class Cost:
class
OpCost
:
class
OpCost
:
OP_TYPE
=
"op"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
):
self
.
_op
=
op
self
.
_op
=
op
self
.
_op_desc
=
op_desc
self
.
_op_desc
=
op_desc
...
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
...
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
)
)
)
)
def
calc_flops
(
self
):
if
not
self
.
op_desc
:
return
0
if
"_grad"
in
self
.
__class__
.
OP_TYPE
:
op_type
=
self
.
__class__
.
OP_TYPE
[:
len
(
self
.
__class__
.
OP_TYPE
)
-
5
]
return
2
*
flops
(
op_type
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
]
)
return
flops
(
self
.
__class__
.
OP_TYPE
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
],
)
def
calc_time
(
self
):
flops_count
=
self
.
calc_flops
()
return
flops_count
*
2.9e-7
def
register_op_cost
(
cls
):
def
register_op_cost
(
cls
):
op_type
=
cls
.
OP_TYPE
op_type
=
cls
.
OP_TYPE
...
...
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
浏览文件 @
d6011cb6
...
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
...
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
def
calc_time
(
self
):
return
0
return
self
.
comm_count
*
1
/
(
144
*
1e3
)
@
register_op_cost
@
register_op_cost
...
...
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
浏览文件 @
d6011cb6
...
@@ -22,15 +22,6 @@ class AdamOpCost(CompOpCost):
...
@@ -22,15 +22,6 @@ class AdamOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ArgsortOpCost
(
CompOpCost
):
class
ArgsortOpCost
(
CompOpCost
):
...
@@ -39,15 +30,6 @@ class ArgsortOpCost(CompOpCost):
...
@@ -39,15 +30,6 @@ class ArgsortOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
AssignOpCost
(
CompOpCost
):
class
AssignOpCost
(
CompOpCost
):
...
@@ -56,15 +38,6 @@ class AssignOpCost(CompOpCost):
...
@@ -56,15 +38,6 @@ class AssignOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
AssignValueOpCost
(
CompOpCost
):
class
AssignValueOpCost
(
CompOpCost
):
...
@@ -73,15 +46,6 @@ class AssignValueOpCost(CompOpCost):
...
@@ -73,15 +46,6 @@ class AssignValueOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
BeamSearchOpCost
(
CompOpCost
):
class
BeamSearchOpCost
(
CompOpCost
):
...
@@ -90,15 +54,6 @@ class BeamSearchOpCost(CompOpCost):
...
@@ -90,15 +54,6 @@ class BeamSearchOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
BeamSearchDecodeOpCost
(
CompOpCost
):
class
BeamSearchDecodeOpCost
(
CompOpCost
):
...
@@ -107,15 +62,6 @@ class BeamSearchDecodeOpCost(CompOpCost):
...
@@ -107,15 +62,6 @@ class BeamSearchDecodeOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
CastOpCost
(
CompOpCost
):
class
CastOpCost
(
CompOpCost
):
...
@@ -124,15 +70,6 @@ class CastOpCost(CompOpCost):
...
@@ -124,15 +70,6 @@ class CastOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ConcatOpCost
(
CompOpCost
):
class
ConcatOpCost
(
CompOpCost
):
...
@@ -141,15 +78,6 @@ class ConcatOpCost(CompOpCost):
...
@@ -141,15 +78,6 @@ class ConcatOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
DropoutOpCost
(
CompOpCost
):
class
DropoutOpCost
(
CompOpCost
):
...
@@ -158,15 +86,6 @@ class DropoutOpCost(CompOpCost):
...
@@ -158,15 +86,6 @@ class DropoutOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
DropoutGradOpCost
(
CompOpCost
):
class
DropoutGradOpCost
(
CompOpCost
):
...
@@ -175,15 +94,6 @@ class DropoutGradOpCost(CompOpCost):
...
@@ -175,15 +94,6 @@ class DropoutGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseAddOpCost
(
CompOpCost
):
class
ElementwiseAddOpCost
(
CompOpCost
):
...
@@ -192,15 +102,6 @@ class ElementwiseAddOpCost(CompOpCost):
...
@@ -192,15 +102,6 @@ class ElementwiseAddOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseAddGradOpCost
(
CompOpCost
):
class
ElementwiseAddGradOpCost
(
CompOpCost
):
...
@@ -209,15 +110,6 @@ class ElementwiseAddGradOpCost(CompOpCost):
...
@@ -209,15 +110,6 @@ class ElementwiseAddGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseDivOpCost
(
CompOpCost
):
class
ElementwiseDivOpCost
(
CompOpCost
):
...
@@ -226,15 +118,6 @@ class ElementwiseDivOpCost(CompOpCost):
...
@@ -226,15 +118,6 @@ class ElementwiseDivOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseDivGradOpCost
(
CompOpCost
):
class
ElementwiseDivGradOpCost
(
CompOpCost
):
...
@@ -243,15 +126,6 @@ class ElementwiseDivGradOpCost(CompOpCost):
...
@@ -243,15 +126,6 @@ class ElementwiseDivGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseMulOpCost
(
CompOpCost
):
class
ElementwiseMulOpCost
(
CompOpCost
):
...
@@ -260,15 +134,6 @@ class ElementwiseMulOpCost(CompOpCost):
...
@@ -260,15 +134,6 @@ class ElementwiseMulOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseMulGradOpCost
(
CompOpCost
):
class
ElementwiseMulGradOpCost
(
CompOpCost
):
...
@@ -277,15 +142,6 @@ class ElementwiseMulGradOpCost(CompOpCost):
...
@@ -277,15 +142,6 @@ class ElementwiseMulGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseSubOpCost
(
CompOpCost
):
class
ElementwiseSubOpCost
(
CompOpCost
):
...
@@ -294,15 +150,6 @@ class ElementwiseSubOpCost(CompOpCost):
...
@@ -294,15 +150,6 @@ class ElementwiseSubOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ElementwiseSubGradOpCost
(
CompOpCost
):
class
ElementwiseSubGradOpCost
(
CompOpCost
):
...
@@ -311,15 +158,6 @@ class ElementwiseSubGradOpCost(CompOpCost):
...
@@ -311,15 +158,6 @@ class ElementwiseSubGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
EqualOpCost
(
CompOpCost
):
class
EqualOpCost
(
CompOpCost
):
...
@@ -328,15 +166,6 @@ class EqualOpCost(CompOpCost):
...
@@ -328,15 +166,6 @@ class EqualOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
EmbeddingOpCost
(
CompOpCost
):
class
EmbeddingOpCost
(
CompOpCost
):
...
@@ -345,15 +174,6 @@ class EmbeddingOpCost(CompOpCost):
...
@@ -345,15 +174,6 @@ class EmbeddingOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
EmbeddingGradOpCost
(
CompOpCost
):
class
EmbeddingGradOpCost
(
CompOpCost
):
...
@@ -362,15 +182,6 @@ class EmbeddingGradOpCost(CompOpCost):
...
@@ -362,15 +182,6 @@ class EmbeddingGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
FillConstantOpCost
(
CompOpCost
):
class
FillConstantOpCost
(
CompOpCost
):
...
@@ -379,15 +190,6 @@ class FillConstantOpCost(CompOpCost):
...
@@ -379,15 +190,6 @@ class FillConstantOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
FillConstantBatchSizeLikeOpCost
(
CompOpCost
):
class
FillConstantBatchSizeLikeOpCost
(
CompOpCost
):
...
@@ -396,15 +198,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
...
@@ -396,15 +198,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
FusedSoftmaxMaskUpperTriangleOpCost
(
CompOpCost
):
class
FusedSoftmaxMaskUpperTriangleOpCost
(
CompOpCost
):
...
@@ -413,15 +206,6 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
...
@@ -413,15 +206,6 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
FusedSoftmaxMaskUpperTriangleGradOpCost
(
CompOpCost
):
class
FusedSoftmaxMaskUpperTriangleGradOpCost
(
CompOpCost
):
...
@@ -430,15 +214,6 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
...
@@ -430,15 +214,6 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
GatherOpCost
(
CompOpCost
):
class
GatherOpCost
(
CompOpCost
):
...
@@ -447,15 +222,6 @@ class GatherOpCost(CompOpCost):
...
@@ -447,15 +222,6 @@ class GatherOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
GeluOpCost
(
CompOpCost
):
class
GeluOpCost
(
CompOpCost
):
...
@@ -464,15 +230,6 @@ class GeluOpCost(CompOpCost):
...
@@ -464,15 +230,6 @@ class GeluOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
GeluGradOpCost
(
CompOpCost
):
class
GeluGradOpCost
(
CompOpCost
):
...
@@ -481,15 +238,6 @@ class GeluGradOpCost(CompOpCost):
...
@@ -481,15 +238,6 @@ class GeluGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
GreaterEqualOpCost
(
CompOpCost
):
class
GreaterEqualOpCost
(
CompOpCost
):
...
@@ -498,15 +246,6 @@ class GreaterEqualOpCost(CompOpCost):
...
@@ -498,15 +246,6 @@ class GreaterEqualOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
IncrementOpCost
(
CompOpCost
):
class
IncrementOpCost
(
CompOpCost
):
...
@@ -515,11 +254,6 @@ class IncrementOpCost(CompOpCost):
...
@@ -515,11 +254,6 @@ class IncrementOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
IsEmptyOpCost
(
CompOpCost
):
class
IsEmptyOpCost
(
CompOpCost
):
...
@@ -528,11 +262,6 @@ class IsEmptyOpCost(CompOpCost):
...
@@ -528,11 +262,6 @@ class IsEmptyOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LayerNormOpCost
(
CompOpCost
):
class
LayerNormOpCost
(
CompOpCost
):
...
@@ -541,15 +270,6 @@ class LayerNormOpCost(CompOpCost):
...
@@ -541,15 +270,6 @@ class LayerNormOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LayerNormGradOpCost
(
CompOpCost
):
class
LayerNormGradOpCost
(
CompOpCost
):
...
@@ -558,15 +278,6 @@ class LayerNormGradOpCost(CompOpCost):
...
@@ -558,15 +278,6 @@ class LayerNormGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LessThanOpCost
(
CompOpCost
):
class
LessThanOpCost
(
CompOpCost
):
...
@@ -575,15 +286,6 @@ class LessThanOpCost(CompOpCost):
...
@@ -575,15 +286,6 @@ class LessThanOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LogicalNotOpCost
(
CompOpCost
):
class
LogicalNotOpCost
(
CompOpCost
):
...
@@ -592,15 +294,6 @@ class LogicalNotOpCost(CompOpCost):
...
@@ -592,15 +294,6 @@ class LogicalNotOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LogicalAndOpCost
(
CompOpCost
):
class
LogicalAndOpCost
(
CompOpCost
):
...
@@ -609,15 +302,6 @@ class LogicalAndOpCost(CompOpCost):
...
@@ -609,15 +302,6 @@ class LogicalAndOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LodResetOpCost
(
CompOpCost
):
class
LodResetOpCost
(
CompOpCost
):
...
@@ -626,15 +310,6 @@ class LodResetOpCost(CompOpCost):
...
@@ -626,15 +310,6 @@ class LodResetOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LogOpCost
(
CompOpCost
):
class
LogOpCost
(
CompOpCost
):
...
@@ -643,15 +318,6 @@ class LogOpCost(CompOpCost):
...
@@ -643,15 +318,6 @@ class LogOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LookupTableV2OpCost
(
CompOpCost
):
class
LookupTableV2OpCost
(
CompOpCost
):
...
@@ -660,15 +326,6 @@ class LookupTableV2OpCost(CompOpCost):
...
@@ -660,15 +326,6 @@ class LookupTableV2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
LookupTableV2GradOpCost
(
CompOpCost
):
class
LookupTableV2GradOpCost
(
CompOpCost
):
...
@@ -677,15 +334,6 @@ class LookupTableV2GradOpCost(CompOpCost):
...
@@ -677,15 +334,6 @@ class LookupTableV2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
MatmulOpCost
(
CompOpCost
):
class
MatmulOpCost
(
CompOpCost
):
...
@@ -694,15 +342,6 @@ class MatmulOpCost(CompOpCost):
...
@@ -694,15 +342,6 @@ class MatmulOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
MatmulGradOpCost
(
CompOpCost
):
class
MatmulGradOpCost
(
CompOpCost
):
...
@@ -711,15 +350,6 @@ class MatmulGradOpCost(CompOpCost):
...
@@ -711,15 +350,6 @@ class MatmulGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
MatmulV2OpCost
(
CompOpCost
):
class
MatmulV2OpCost
(
CompOpCost
):
...
@@ -728,15 +358,6 @@ class MatmulV2OpCost(CompOpCost):
...
@@ -728,15 +358,6 @@ class MatmulV2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
MatmulV2GradOpCost
(
CompOpCost
):
class
MatmulV2GradOpCost
(
CompOpCost
):
...
@@ -745,15 +366,6 @@ class MatmulV2GradOpCost(CompOpCost):
...
@@ -745,15 +366,6 @@ class MatmulV2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
MemcpyOpCost
(
CompOpCost
):
class
MemcpyOpCost
(
CompOpCost
):
...
@@ -762,15 +374,6 @@ class MemcpyOpCost(CompOpCost):
...
@@ -762,15 +374,6 @@ class MemcpyOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
MulOpCost
(
CompOpCost
):
class
MulOpCost
(
CompOpCost
):
...
@@ -779,15 +382,6 @@ class MulOpCost(CompOpCost):
...
@@ -779,15 +382,6 @@ class MulOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
MulGradOpCost
(
CompOpCost
):
class
MulGradOpCost
(
CompOpCost
):
...
@@ -796,15 +390,6 @@ class MulGradOpCost(CompOpCost):
...
@@ -796,15 +390,6 @@ class MulGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
OneHotOpCost
(
CompOpCost
):
class
OneHotOpCost
(
CompOpCost
):
...
@@ -813,15 +398,6 @@ class OneHotOpCost(CompOpCost):
...
@@ -813,15 +398,6 @@ class OneHotOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ReadFromArrayOpCost
(
CompOpCost
):
class
ReadFromArrayOpCost
(
CompOpCost
):
...
@@ -830,15 +406,6 @@ class ReadFromArrayOpCost(CompOpCost):
...
@@ -830,15 +406,6 @@ class ReadFromArrayOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ReduceSumOpCost
(
CompOpCost
):
class
ReduceSumOpCost
(
CompOpCost
):
...
@@ -847,15 +414,6 @@ class ReduceSumOpCost(CompOpCost):
...
@@ -847,15 +414,6 @@ class ReduceSumOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ReduceSumGradOpCost
(
CompOpCost
):
class
ReduceSumGradOpCost
(
CompOpCost
):
...
@@ -864,15 +422,6 @@ class ReduceSumGradOpCost(CompOpCost):
...
@@ -864,15 +422,6 @@ class ReduceSumGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
Reshape2OpCost
(
CompOpCost
):
class
Reshape2OpCost
(
CompOpCost
):
...
@@ -881,15 +430,6 @@ class Reshape2OpCost(CompOpCost):
...
@@ -881,15 +430,6 @@ class Reshape2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
Reshape2GradOpCost
(
CompOpCost
):
class
Reshape2GradOpCost
(
CompOpCost
):
...
@@ -898,15 +438,6 @@ class Reshape2GradOpCost(CompOpCost):
...
@@ -898,15 +438,6 @@ class Reshape2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ReduceMeanOpCost
(
CompOpCost
):
class
ReduceMeanOpCost
(
CompOpCost
):
...
@@ -915,15 +446,6 @@ class ReduceMeanOpCost(CompOpCost):
...
@@ -915,15 +446,6 @@ class ReduceMeanOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ReduceMeanGradOpCost
(
CompOpCost
):
class
ReduceMeanGradOpCost
(
CompOpCost
):
...
@@ -932,15 +454,6 @@ class ReduceMeanGradOpCost(CompOpCost):
...
@@ -932,15 +454,6 @@ class ReduceMeanGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SamplingIdOpCost
(
CompOpCost
):
class
SamplingIdOpCost
(
CompOpCost
):
...
@@ -949,15 +462,6 @@ class SamplingIdOpCost(CompOpCost):
...
@@ -949,15 +462,6 @@ class SamplingIdOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
ScaleOpCost
(
CompOpCost
):
class
ScaleOpCost
(
CompOpCost
):
...
@@ -966,15 +470,6 @@ class ScaleOpCost(CompOpCost):
...
@@ -966,15 +470,6 @@ class ScaleOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SliceOpCost
(
CompOpCost
):
class
SliceOpCost
(
CompOpCost
):
...
@@ -983,15 +478,6 @@ class SliceOpCost(CompOpCost):
...
@@ -983,15 +478,6 @@ class SliceOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SoftmaxOpCost
(
CompOpCost
):
class
SoftmaxOpCost
(
CompOpCost
):
...
@@ -1000,15 +486,6 @@ class SoftmaxOpCost(CompOpCost):
...
@@ -1000,15 +486,6 @@ class SoftmaxOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SoftmaxGradOpCost
(
CompOpCost
):
class
SoftmaxGradOpCost
(
CompOpCost
):
...
@@ -1017,15 +494,6 @@ class SoftmaxGradOpCost(CompOpCost):
...
@@ -1017,15 +494,6 @@ class SoftmaxGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SoftmaxWithCrossEntropyOpCost
(
CompOpCost
):
class
SoftmaxWithCrossEntropyOpCost
(
CompOpCost
):
...
@@ -1034,15 +502,6 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost):
...
@@ -1034,15 +502,6 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SoftmaxWithCrossEntropyGradOpCost
(
CompOpCost
):
class
SoftmaxWithCrossEntropyGradOpCost
(
CompOpCost
):
...
@@ -1051,15 +510,6 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
...
@@ -1051,15 +510,6 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SplitOpCost
(
CompOpCost
):
class
SplitOpCost
(
CompOpCost
):
...
@@ -1068,15 +518,6 @@ class SplitOpCost(CompOpCost):
...
@@ -1068,15 +518,6 @@ class SplitOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
Squeeze2OpCost
(
CompOpCost
):
class
Squeeze2OpCost
(
CompOpCost
):
...
@@ -1085,15 +526,6 @@ class Squeeze2OpCost(CompOpCost):
...
@@ -1085,15 +526,6 @@ class Squeeze2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SquareOpCost
(
CompOpCost
):
class
SquareOpCost
(
CompOpCost
):
...
@@ -1102,15 +534,6 @@ class SquareOpCost(CompOpCost):
...
@@ -1102,15 +534,6 @@ class SquareOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SquareGradOpCost
(
CompOpCost
):
class
SquareGradOpCost
(
CompOpCost
):
...
@@ -1119,15 +542,6 @@ class SquareGradOpCost(CompOpCost):
...
@@ -1119,15 +542,6 @@ class SquareGradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
SumOpCost
(
CompOpCost
):
class
SumOpCost
(
CompOpCost
):
...
@@ -1136,15 +550,6 @@ class SumOpCost(CompOpCost):
...
@@ -1136,15 +550,6 @@ class SumOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
TopKOpCost
(
CompOpCost
):
class
TopKOpCost
(
CompOpCost
):
...
@@ -1153,15 +558,6 @@ class TopKOpCost(CompOpCost):
...
@@ -1153,15 +558,6 @@ class TopKOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
Transpose2OpCost
(
CompOpCost
):
class
Transpose2OpCost
(
CompOpCost
):
...
@@ -1170,15 +566,6 @@ class Transpose2OpCost(CompOpCost):
...
@@ -1170,15 +566,6 @@ class Transpose2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
Transpose2GradOpCost
(
CompOpCost
):
class
Transpose2GradOpCost
(
CompOpCost
):
...
@@ -1187,15 +574,6 @@ class Transpose2GradOpCost(CompOpCost):
...
@@ -1187,15 +574,6 @@ class Transpose2GradOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
Unsqueeze2OpCost
(
CompOpCost
):
class
Unsqueeze2OpCost
(
CompOpCost
):
...
@@ -1204,15 +582,6 @@ class Unsqueeze2OpCost(CompOpCost):
...
@@ -1204,15 +582,6 @@ class Unsqueeze2OpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
@
register_op_cost
@
register_op_cost
class
WriteToArrayOpCost
(
CompOpCost
):
class
WriteToArrayOpCost
(
CompOpCost
):
...
@@ -1220,12 +589,3 @@ class WriteToArrayOpCost(CompOpCost):
...
@@ -1220,12 +589,3 @@ class WriteToArrayOpCost(CompOpCost):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
,
cluster
=
None
):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
cluster
=
cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def
calc_flops
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
def
calc_time
(
self
):
# NOTE: The actual formula will be filled in the future
return
0
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
浏览文件 @
d6011cb6
...
@@ -189,6 +189,9 @@ class CostEstimator:
...
@@ -189,6 +189,9 @@ class CostEstimator:
# Calc dist op cost
# Calc dist op cost
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
=
dist_op
.
dist_attr
processes
=
op_dist_attr
.
process_mesh
.
process_ids
processes
=
op_dist_attr
.
process_mesh
.
process_ids
...
@@ -225,6 +228,8 @@ class CostEstimator:
...
@@ -225,6 +228,8 @@ class CostEstimator:
for
rank
in
group_ranks
:
for
rank
in
group_ranks
:
self
.
local_cost
(
rank
).
time
=
(
self
.
local_cost
(
rank
).
time
=
(
max_time
+
comm_op_cost
.
time
max_time
+
comm_op_cost
.
time
if
op
.
attr
(
'op_role'
)
!=
OpRole
.
Backward
else
max_time
+
0.9
*
comm_op_cost
.
time
)
)
if
rank
not
in
self
.
_bubble_time_mapping
:
if
rank
not
in
self
.
_bubble_time_mapping
:
self
.
_bubble_time_mapping
[
rank
]
=
0
self
.
_bubble_time_mapping
[
rank
]
=
0
...
@@ -290,6 +295,7 @@ class CostEstimator:
...
@@ -290,6 +295,7 @@ class CostEstimator:
self
.
_ordered_ops
.
append
([
op
.
desc
.
id
(),
op
])
self
.
_ordered_ops
.
append
([
op
.
desc
.
id
(),
op
])
self
.
_ordered_ops
.
sort
(
key
=
lambda
x
:
x
[
0
])
self
.
_ordered_ops
.
sort
(
key
=
lambda
x
:
x
[
0
])
parameters
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
if
op
.
type
in
[
"create_py_reader"
,
"create_py_reader"
,
...
@@ -298,11 +304,14 @@ class CostEstimator:
...
@@ -298,11 +304,14 @@ class CostEstimator:
]:
]:
continue
continue
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
var_name
var_name
)
)
if
var_name
not
in
var_info
:
if
var_name
not
in
var_info
:
var_info
[
var_name
]
=
{}
var_info
[
var_name
]
=
{}
key
=
_convert_pm_and_dm_to_str
(
key
=
_convert_pm_and_dm_to_str
(
...
@@ -311,6 +320,10 @@ class CostEstimator:
...
@@ -311,6 +320,10 @@ class CostEstimator:
if
key
not
in
var_info
[
var_name
]:
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
var_info
[
var_name
][
key
]
=
{}
# It is even partition now
# It is even partition now
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_input
(
var_name
)
var
=
dist_op
.
get_serial_input
(
var_name
)
global_sizes
=
var
.
shape
global_sizes
=
var
.
shape
...
@@ -324,9 +337,16 @@ class CostEstimator:
...
@@ -324,9 +337,16 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
sizes
,
dtype
)
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
if
var
.
persistable
:
var_info
[
var_name
][
key
][
"position"
]
=
[]
name
=
var_name
+
key
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
@@ -339,6 +359,10 @@ class CostEstimator:
...
@@ -339,6 +359,10 @@ class CostEstimator:
)
)
if
key
not
in
var_info
[
var_name
]:
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
var_info
[
var_name
][
key
]
=
{}
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_output
(
var_name
)
var
=
dist_op
.
get_serial_output
(
var_name
)
global_sizes
=
var
.
shape
global_sizes
=
var
.
shape
...
@@ -352,11 +376,19 @@ class CostEstimator:
...
@@ -352,11 +376,19 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
sizes
,
dtype
)
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
if
var
.
persistable
:
var_info
[
var_name
][
key
][
"position"
]
=
[]
name
=
var_name
+
key
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
has_used_vars
=
set
()
has_used_vars
=
set
()
not_calc_vars
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
if
op
.
type
in
[
"create_py_reader"
,
"create_py_reader"
,
...
@@ -367,6 +399,8 @@ class CostEstimator:
...
@@ -367,6 +399,8 @@ class CostEstimator:
can_free_memories
=
{}
can_free_memories
=
{}
can_free_vars
=
set
()
can_free_vars
=
set
()
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
...
@@ -378,24 +412,30 @@ class CostEstimator:
...
@@ -378,24 +412,30 @@ class CostEstimator:
has_used_var
=
var_name
+
key
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_input
(
var_name
)
var
=
dist_op
.
get_serial_input
(
var_name
)
# Not used
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
if
has_used_var
in
not_calc_vars
:
continue
has_used_vars
.
add
(
has_used_var
)
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
(
if
has_used_var
not
in
can_free_vars
:
has_used_var
not
in
can_free_vars
can_free_vars
.
add
(
has_used_var
)
and
not
var
.
persistable
if
not
var
.
persistable
:
):
for
process
in
process_mesh
.
process_ids
:
can_free_vars
.
add
(
has_used_var
)
if
process
not
in
can_free_memories
:
for
process
in
process_mesh
.
process_ids
:
can_free_memories
[
process
]
=
0
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
+=
var_info
[
can_free_memories
[
process
]
=
0
var_name
can_free_memories
[
process
]
+=
var_info
[
var_name
][
][
key
][
"memory"
]
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
@@ -406,25 +446,36 @@ class CostEstimator:
...
@@ -406,25 +446,36 @@ class CostEstimator:
)
)
has_used_var
=
var_name
+
key
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_output
(
var_name
)
var
=
dist_op
.
get_serial_output
(
var_name
)
if
(
op
.
type
==
"reshape2"
or
op
.
type
==
"transpose2"
or
op
.
type
==
"elementwise_add"
):
not_calc_vars
.
add
(
has_used_var
)
continue
# Not used
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
has_used_vars
.
add
(
has_used_var
)
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
(
if
has_used_var
not
in
can_free_vars
:
has_used_var
not
in
can_free_vars
can_free_vars
.
add
(
has_used_var
)
and
not
var
.
persistable
if
not
var
.
persistable
:
):
for
process
in
process_mesh
.
process_ids
:
can_free_vars
.
add
(
has_used_var
)
if
process
not
in
can_free_memories
:
for
process
in
process_mesh
.
process_ids
:
can_free_memories
[
process
]
=
0
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
+=
var_info
[
can_free_memories
[
process
]
=
0
var_name
can_free_memories
[
process
]
+=
var_info
[
var_name
][
][
key
][
"memory"
]
key
][
"memory"
]
# Calc peak memory
# Calc peak memory
for
process
in
memories
:
for
process
in
memories
:
...
@@ -433,7 +484,6 @@ class CostEstimator:
...
@@ -433,7 +484,6 @@ class CostEstimator:
else
:
else
:
if
memories
[
process
]
>
self
.
max_memories
[
process
]:
if
memories
[
process
]
>
self
.
max_memories
[
process
]:
self
.
max_memories
[
process
]
=
memories
[
process
]
self
.
max_memories
[
process
]
=
memories
[
process
]
# Free memory
# Free memory
for
process
in
can_free_memories
:
for
process
in
can_free_memories
:
if
process
in
memories
:
if
process
in
memories
:
...
@@ -513,7 +563,7 @@ class CostEstimator:
...
@@ -513,7 +563,7 @@ class CostEstimator:
# Padding automatically
# Padding automatically
max_len
=
0
max_len
=
0
header
=
[
"Execution Time(
m
s)"
,
"Max Memory(MiB)"
]
header
=
[
"Execution Time(
u
s)"
,
"Max Memory(MiB)"
]
vals
=
[
round
(
self
.
global_cost
.
time
,
3
),
int
(
self
.
max_memory
//
1e6
)]
vals
=
[
round
(
self
.
global_cost
.
time
,
3
),
int
(
self
.
max_memory
//
1e6
)]
for
memory
in
vals
+
header
:
for
memory
in
vals
+
header
:
if
len
(
str
(
memory
))
>
max_len
:
if
len
(
str
(
memory
))
>
max_len
:
...
...
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
d6011cb6
...
@@ -2716,6 +2716,8 @@ class Resharder:
...
@@ -2716,6 +2716,8 @@ class Resharder:
)
)
# simplified processing: ignore union process mesh and output reshard
# simplified processing: ignore union process mesh and output reshard
dist_op
=
self
.
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
self
.
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_tensor
or
not
dist_op
:
return
reshard_op_cost
dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
tensor
.
name
tensor
.
name
)
)
...
...
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
浏览文件 @
d6011cb6
...
@@ -16,17 +16,31 @@ import copy
...
@@ -16,17 +16,31 @@ import copy
import
logging
import
logging
import
math
import
math
import
os
import
os
import
pickle
import
sys
import
time
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
functools
import
reduce
import
numpy
as
np
import
paddle
import
paddle
from
paddle.distributed.auto_parallel.cluster_v2
import
DeviceMesh
from
paddle.distributed.auto_parallel.completion
import
Completer
from
paddle.distributed.auto_parallel.completion
import
Completer
from
paddle.distributed.auto_parallel.cost
import
CostEstimator
from
paddle.distributed.auto_parallel.dist_attribute
import
(
from
paddle.distributed.auto_parallel.dist_attribute
import
(
OperatorDistAttr
,
OperatorDistAttr
,
TensorDistAttr
,
TensorDistAttr
,
)
)
from
paddle.distributed.auto_parallel.dist_context
import
DistributedContext
from
paddle.distributed.auto_parallel.dist_context
import
DistributedContext
from
paddle.distributed.auto_parallel.dist_tensor
import
DistributedTensor
from
paddle.distributed.auto_parallel.dist_tensor
import
DistributedTensor
from
paddle.distributed.auto_parallel.process_mesh
import
ProcessMesh
from
paddle.distributed.auto_parallel.utils
import
(
is_gradient_clip_op
,
print_program_with_dist_attr
,
)
from
paddle.distributed.fleet.meta_optimizers.common
import
OpRole
from
paddle.fluid
import
program_guard
from
paddle.fluid
import
program_guard
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.framework
import
Parameter
,
unique_name
from
paddle.fluid.framework
import
Parameter
,
unique_name
...
@@ -1610,3 +1624,603 @@ class RuleBasedTuner:
...
@@ -1610,3 +1624,603 @@ class RuleBasedTuner:
idx
idx
][
parallelism
][
key
]
][
parallelism
][
key
]
self
.
_complete_sub_bwd_program
(
sub_program_dist_context
)
self
.
_complete_sub_bwd_program
(
sub_program_dist_context
)
def _complete_sub_update_program(self, sub_program_dist_context):
    """
    Complete the opt OP according to the tensor.
    Most of the logic is the same as the update completion in the completer.

    Walks every op in the global block of ``self.full_main_program`` and, for
    ops whose ``op_role`` equals ``OpRole.Optimize``, records tensor and
    operator dist attrs in ``sub_program_dist_context``.

    Args:
        sub_program_dist_context: dist context that is queried through
            ``_dist_tensors_for_program`` and
            ``get_tensor_dist_attr_for_program`` and is updated in place.

    Returns:
        None. All results are written into ``sub_program_dist_context``.
    """
    # NOTE(review): the block nesting below was reconstructed from a flattened
    # token listing; confirm the indentation against the upstream file.

    # Process mesh spanning machines * devices-per-machine process ids.
    world_ranks = ProcessMesh(
        [
            i
            for i in range(
                self._cluster.get_num_machines()
                * self._cluster._num_devices_per_machine
            )
        ]
    )
    dist_tensors = sub_program_dist_context._dist_tensors_for_program

    vars = self.full_main_program.global_block().vars
    ops = self.full_main_program.global_block().ops
    # The learning-rate tensor's dist attr is set only once (see below).
    learning_rate_completed = False
    for idx in range(len(ops)):
        op = ops[idx]
        if int(op.attr('op_role')) == int(OpRole.Optimize):
            if is_gradient_clip_op(op):
                if op.type in [
                    "sum",
                    "sqrt",
                    "fill_constant",
                    "elementwise_max",
                    "elementwise_div",
                ]:
                    # These gradient-clip op types are placed on the world mesh.
                    op_dist_attr = OperatorDistAttr()
                    op_dist_attr.process_mesh = world_ranks
                    for in_name in op.input_arg_names:
                        in_var = vars[in_name]
                        if in_var.desc.original_id() in dist_tensors:
                            # Reuse the dist attr already recorded for this input.
                            in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                                in_var
                            )
                            op_dist_attr.set_input_dist_attr(
                                in_name, in_dist_attr
                            )
                        else:
                            # No recorded dist tensor: create one on the world
                            # mesh with -1 for every dimension
                            # (presumably "not sharded" - TODO confirm).
                            in_dist_attr = TensorDistAttr()
                            in_dist_attr.process_mesh = world_ranks
                            in_dist_attr.dims_mapping = [
                                -1 for _ in range(len(in_var.shape))
                            ]
                            op_dist_attr.set_input_dist_attr(
                                in_name, in_dist_attr
                            )
                            sub_program_dist_context.set_tensor_dist_attr_for_program(
                                in_var, in_dist_attr
                            )
                    for out_name in op.output_arg_names:
                        out_var = vars[out_name]
                        if out_var.desc.original_id() in dist_tensors:
                            out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                                out_var
                            )
                            op_dist_attr.set_output_dist_attr(
                                out_name, out_dist_attr
                            )
                        else:
                            # Same fallback as for inputs: world mesh, all -1.
                            out_dist_attr = TensorDistAttr()
                            out_dist_attr.process_mesh = world_ranks
                            out_dist_attr.dims_mapping = [
                                -1 for _ in range(len(out_var.shape))
                            ]
                            sub_program_dist_context.set_tensor_dist_attr_for_program(
                                out_var, out_dist_attr
                            )
                            op_dist_attr.set_output_dist_attr(
                                out_name, out_dist_attr
                            )
                    sub_program_dist_context.set_op_dist_attr_for_program(
                        op, op_dist_attr
                    )
                else:
                    # Other gradient-clip ops take their reference mesh and
                    # dims mapping from the first "X" input.
                    in_var = vars[op.input("X")[0]]
                    if in_var.desc.original_id() in dist_tensors:
                        in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                            in_var
                        )
                        assert in_dist_attr is not None
                        ref_process_mesh = in_dist_attr.process_mesh
                        ref_dims_mapping = in_dist_attr.dims_mapping

                        # A "cast" directly followed by "elementwise_mul" uses
                        # the process mesh of that next op's first "X" input.
                        if (
                            op.type == "cast"
                            and ops[idx + 1].type == "elementwise_mul"
                        ):
                            ref_var = vars[ops[idx + 1].input("X")[0]]
                            ref_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                                ref_var
                            )
                            assert ref_dist_attr is not None
                            ref_process_mesh = ref_dist_attr.process_mesh

                        out_var = vars[op.output("Out")[0]]
                        out_dist_attr = TensorDistAttr()
                        out_dist_attr.process_mesh = ref_process_mesh
                        if out_var.shape == in_var.shape:
                            out_dist_attr.dims_mapping = ref_dims_mapping
                        else:
                            # Otherwise the output must have shape [1].
                            assert (
                                len(out_var.shape) == 1
                                and out_var.shape[0] == 1
                            )
                            out_dist_attr.dims_mapping = [-1]
                        sub_program_dist_context.set_tensor_dist_attr_for_program(
                            out_var, out_dist_attr
                        )

                        op_dist_attr = OperatorDistAttr()
                        op_dist_attr.process_mesh = ref_process_mesh
                        for in_name in op.input_arg_names:
                            in_var = vars[in_name]
                            in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                                in_var
                            )
                            op_dist_attr.set_input_dims_mapping(
                                in_name, in_dist_attr.dims_mapping
                            )
                        for out_name in op.output_arg_names:
                            out_var = vars[out_name]
                            out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                                out_var
                            )
                            op_dist_attr.set_output_dims_mapping(
                                out_name, out_dist_attr.dims_mapping
                            )
                        # NOTE(review): in_var/in_dist_attr and
                        # out_var/out_dist_attr appear to be whatever the two
                        # loops above bound last - confirm this is intended.
                        op_dist_attr.set_input_dist_attr(
                            in_var.name, in_dist_attr
                        )
                        op_dist_attr.set_output_dist_attr(
                            out_var.name, out_dist_attr
                        )

                        sub_program_dist_context.set_op_dist_attr_for_program(
                            op, op_dist_attr
                        )
                    else:
                        continue

            # Ops that have both "Grad" and "Param" inputs; ops[idx] is the
            # same object as op here.
            if "Grad" in op.input_names and "Param" in ops[idx].input_names:
                assert (
                    len(op.input("Param")) == 1
                ), "Only support one-to-one now."
                assert (
                    len(op.input("Grad")) == 1
                ), "Only support one-to-one now."
                param = vars[op.input("Param")[0]]
                grad_var = vars[op.input("Grad")[0]]
                if param.desc.original_id() in dist_tensors:
                    # The parameter's dist attr is the reference for this op.
                    param_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                        param
                    )
                    assert param_dist_attr is not None
                    ref_process_mesh = sub_program_dist_context.get_tensor_dist_attr_for_program(
                        param
                    ).process_mesh
                    assert ref_process_mesh is not None
                    ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
                        param
                    ).dims_mapping
                    assert ref_dims_mapping is not None
                    op_dist_attr = OperatorDistAttr()
                    op_dist_attr.process_mesh = ref_process_mesh
                    op_dist_attr.set_input_dims_mapping(
                        grad_var.name, ref_dims_mapping
                    )
                    op_dist_attr.set_input_dims_mapping(
                        param.name, ref_dims_mapping
                    )
                    op_dist_attr.set_output_dims_mapping(
                        param.name, ref_dims_mapping
                    )
                    # The learning rate always gets the mapping [-1].
                    learning_var = vars[op.input("LearningRate")[0]]
                    op_dist_attr.set_input_dims_mapping(
                        learning_var.name, [-1]
                    )
                    op_dist_attr.set_output_dims_mapping(
                        learning_var.name, [-1]
                    )

                    # Only the first such op records a tensor dist attr for
                    # its learning-rate variable (on the world mesh).
                    if not learning_rate_completed:
                        learning_rate_completed = True
                        var_dist_attr = TensorDistAttr()
                        var_dist_attr.process_mesh = world_ranks
                        var_dist_attr.dims_mapping = [-1]
                        sub_program_dist_context.set_tensor_dist_attr_for_program(
                            learning_var, var_dist_attr
                        )

                    # Remaining input slots, skipping the ones listed below
                    # and any slot with no variables.
                    for input_name in op.desc.input_names():
                        if input_name in [
                            'Param',
                            'Grad',
                            'LearningRate',
                            "SkipUpdate",
                            "Beta1Tensor",
                            "Beta2Tensor",
                            "EpsilonTensor",
                        ]:
                            continue
                        if len(op.desc.input(input_name)) == 0:
                            continue

                        assert len(op.desc.input(input_name)) == 1
                        input_var = vars[op.desc.input(input_name)[0]]
                        input_var_attr = TensorDistAttr()

                        if (
                            "Beta1Pow" in input_name
                            or "Beta2Pow" in input_name
                        ):
                            # Beta1Pow / Beta2Pow slots are mapped as [-1].
                            input_var_attr.dims_mapping = [-1]
                            op_dist_attr.set_input_dims_mapping(
                                input_var.name, [-1]
                            )
                            op_dist_attr.set_output_dims_mapping(
                                input_var.name, [-1]
                            )
                        else:
                            # Every other slot follows the parameter's mapping.
                            input_var_attr.dims_mapping = ref_dims_mapping
                            op_dist_attr.set_input_dims_mapping(
                                input_var.name, ref_dims_mapping
                            )
                            op_dist_attr.set_output_dims_mapping(
                                input_var.name, ref_dims_mapping
                            )

                        input_var_attr.process_mesh = ref_process_mesh
                        sub_program_dist_context.set_tensor_dist_attr_for_program(
                            input_var, input_var_attr
                        )

                    sub_program_dist_context.set_op_dist_attr_for_program(
                        op, op_dist_attr
                    )
                    continue
                else:
                    continue
def complete_sub_update_programs(self):
    """Run update completion on every stored sub-program dist context.

    ``self.sub_programs_dist_context`` is traversed three levels deep
    (outer key -> parallelism -> key) and each leaf dist context is passed
    to ``self._complete_sub_update_program``.
    """
    for layer_idx in self.sub_programs_dist_context:
        for mode in self.sub_programs_dist_context[layer_idx]:
            for mesh_key in self.sub_programs_dist_context[layer_idx][mode]:
                self._complete_sub_update_program(
                    self.sub_programs_dist_context[layer_idx][mode][mesh_key]
                )
def convert_device_mesh_to_key(self, device_mesh):
    """Build a string key from a device mesh.

    The key is the comma-joined ``device_mesh.device_ids``, a ``;``, then
    the comma-joined ``device_mesh.shape``.

    Args:
        device_mesh: object exposing iterable ``device_ids`` and ``shape``.

    Returns:
        str: for example ``"0,1,2,3;2,2"``.
    """
    id_part = ",".join(map(str, device_mesh.device_ids))
    shape_part = ",".join(map(str, device_mesh.shape))
    return ";".join((id_part, shape_part))
def _get_sub_program_cost(self, dist_context):
    """Estimate the time cost and max memory for a dist context.

    A ``CostEstimator`` is created from ``self.full_main_program`` and
    ``self._cluster``; it is then asked for the global cost and for the
    max-memory estimate of ``dist_context``.

    Args:
        dist_context: the dist context to evaluate.

    Returns:
        tuple: ``(time, max_memory)`` where ``time`` is the ``time``
        attribute of the estimated global cost.
    """
    estimator = CostEstimator(self.full_main_program, self._cluster)
    estimated = estimator.estimate(dist_context)
    peak_memory = estimator._estimate_max_memory_by_dist_op(dist_context)
    return estimated.time, peak_memory
def combine_dist_contexts(self, dist_contexts):
    """Merge several dist contexts into one new ``DistributedContext``.

    Args:
        dist_contexts: iterable of dist contexts to merge.

    Returns:
        A fresh ``DistributedContext`` holding the dist tensors, dist ops
        and process meshes of all inputs.
    """
    merged = DistributedContext()
    for source in dist_contexts:
        # Dist tensors: a tensor id may occur in more than one context
        # (shared param, or a var consumed by several ops); only the first
        # occurrence is added.
        for tensor_id in source._dist_tensors_for_program:
            tensor = source._dist_tensors_for_program[tensor_id]
            if tensor_id in merged._dist_tensors_for_program:
                continue
            merged.add_dist_tensor_for_program(tensor)

        # Dist ops: every one is added.
        for op_id in source._dist_ops_for_program:
            merged.add_dist_op_for_program(
                source._dist_ops_for_program[op_id]
            )

        # Process meshes of this context.
        for mesh in source.process_meshes:
            merged.add_process_mesh(mesh)

    return merged
def prepare(self):
    """Prepare the sub program, tensor dist attr setting, device meshes and so on that tuner need.

    Runs the preparation steps in order and logs the wall-clock time of
    most of them through ``self._logger``. Results are stored on ``self``
    (``layers``, ``device_meshes_list``, ``process_meshes``) or produced as
    side effects of the helper methods that are called.
    """
    # NOTE(review): the block nesting below was reconstructed from a flattened
    # token listing; confirm the indentation against the upstream file
    # (in particular that step8 belongs to the "train" branch).

    # step1: cluster operators to layers
    begin = time.time()
    self.layers = self.cluster_operators()
    end = time.time()
    self._logger.info(
        "Cluster operators to {} layers in {}s.".format(
            len(self.layers), end - begin
        )
    )

    # step2: generate sub program of each layer
    begin = time.time()
    self.gen_fwd_sub_programs_by_clone()
    end = time.time()
    self._logger.info(
        "Generate programs of every layer in {}s.".format(end - begin)
    )

    # step3: partition devices to device meshes
    begin = time.time()
    # n = number of machines, m = devices per machine.
    n, m = (
        self._cluster.get_num_machines(),
        self._cluster._num_devices_per_machine,
    )
    device_meshes_list = ClusterPartitionUtil.partition_cluster(n, m)
    end = time.time()
    self._logger.info("Partition cluster in {}s.".format(end - begin))

    # step4: transform device mesh to process meshes
    # dm_idx numbers every DeviceMesh created, across all partitions.
    dm_idx = 0
    for device_meshes in device_meshes_list:
        # Process ids restart from 0 for each partition candidate.
        has_used_devices = 0
        self.device_meshes_list.append([])
        for device_mesh in device_meshes:
            # Number of devices = product of the mesh dimensions.
            devices = reduce(lambda x, y: x * y, device_mesh)
            processes = [
                i
                for i in range(
                    has_used_devices, has_used_devices + devices
                )
            ]
            # Drop the leading dimension when it equals 1.
            device_mesh_shape = (
                device_mesh
                if device_mesh[0] != 1
                else [device_mesh[i] for i in range(1, len(device_mesh))]
            )
            self.device_meshes_list[-1].append(
                DeviceMesh(
                    mesh=np.array(processes)
                    .reshape(device_mesh_shape)
                    .tolist(),
                    name="device_mesh_" + str(dm_idx),
                )
            )
            dm_idx += 1
            has_used_devices += devices
            process_mesh_shapes = convert_to_process_meshes(device_mesh)
            for process_mesh_shape in process_mesh_shapes:
                process_mesh = ProcessMesh(
                    np.array(processes).reshape(process_mesh_shape).tolist()
                )
                # Keep each distinct process mesh only once.
                if process_mesh not in self.process_meshes:
                    self.process_meshes.append(process_mesh)

    # step5: generate full program
    begin = time.time()
    self.gen_full_program()
    end = time.time()
    self._logger.info("Generate full program in {}s.".format(end - begin))

    # step6: complete forward sub programs
    begin = time.time()
    for process_mesh in self.process_meshes:
        self.complete_sub_fwd_programs(process_mesh)
    end = time.time()
    self._logger.info(
        "Complete all sub forward programs in {}s.".format(end - begin)
    )

    if self.mode == "train":
        # step7: complete backward sub programs
        begin = time.time()
        self.complete_sub_bwd_programs()
        end = time.time()
        self._logger.info(
            "Complete all sub backward programs in {}s.".format(end - begin)
        )

        # step8: complete update sub programs
        begin = time.time()
        self.complete_sub_update_programs()
        end = time.time()
        self._logger.info(
            "Complete all sub update programs in {}s.".format(end - begin)
        )
def tune_o1(self):
    """Run the o1 level tuning and return the cheapest distributed context.

    For every candidate in ``self.device_meshes_list`` (a list of device
    meshes, one per pipeline stage), every parallelism in ``dp``, ``mp``,
    ``dp_mp`` and ``mp_dp``, and every process mesh shape derived from the
    first device mesh, the layers are split evenly over the stages (the last
    stage also takes the remainder). The per-layer distributed contexts are
    combined and the result is scored with ``self._get_sub_program_cost``.

    Returns:
        The combined distributed context with the lowest cost among the
        candidates whose estimated memory passes the 85% memory check, or
        ``None`` when no candidate qualifies.
    """
    best_cost = sys.maxsize
    best_dist_context = None

    for device_meshes in self.device_meshes_list:
        pp_stages = len(device_meshes)
        average_layers = len(self.layers) // pp_stages
        # Work on a copy: padding a 1-D shape with ``insert`` must not alter
        # the list handed out by the device mesh itself.
        device_mesh_shape = list(device_meshes[0].shape)
        if len(device_mesh_shape) == 1:
            device_mesh_shape.insert(0, 1)
        process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)

        # For example, device_mesh is [1, 8] and process_mesh is [8].
        # The selective parallelism is dp or mp.
        # Get dp8 or mp8 cost and compare them to get the best strategy.
        for parallelism in ["dp", "mp", "dp_mp", "mp_dp"]:
            for process_mesh_shape in process_mesh_shapes:
                dist_context_of_device_meshes = None
                for idx, device_mesh in enumerate(device_meshes):
                    process_mesh = ProcessMesh(
                        np.array(device_mesh.device_ids)
                        .reshape(process_mesh_shape)
                        .tolist()
                    )
                    # 1-D process meshes only support dp/mp; others only
                    # support the two combined parallelisms.
                    selective_parallelisms = (
                        ["dp", "mp"]
                        if len(process_mesh.shape) == 1
                        else ["dp_mp", "mp_dp"]
                    )
                    if parallelism not in selective_parallelisms:
                        continue

                    key = self.convert_process_mesh_to_key(process_mesh)

                    # Even split of layers; the last stage absorbs the rest.
                    start = idx * average_layers
                    if idx == len(device_meshes) - 1:
                        end = len(self.layers)
                    else:
                        end = start + average_layers

                    dist_context = self.combine_dist_contexts(
                        [
                            self.sub_programs_dist_context[j][parallelism][key]
                            for j in range(start, end)
                        ]
                    )

                    dist_context_of_device_meshes = (
                        dist_context
                        if dist_context_of_device_meshes is None
                        else self.combine_dist_contexts(
                            [dist_context_of_device_meshes, dist_context]
                        )
                    )

                if dist_context_of_device_meshes is not None:
                    cost, memory = self._get_sub_program_cost(
                        dist_context_of_device_meshes
                    )

                    self._logger.info(
                        "Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages.".format(
                            memory / (1024**3),
                            cost,
                            parallelism,
                            process_mesh_shape,
                            len(device_meshes),
                        )
                    )
                    # 15% buffer is reserved for memory cost
                    if memory > 0.85 * self.cluster.machines[0].devices[
                        0
                    ].memory * (1024**3):
                        # Over budget: make the candidate unselectable.
                        cost = sys.maxsize

                    if cost < best_cost:
                        best_cost = cost
                        best_dist_context = dist_context_of_device_meshes
                        self._logger.info(
                            "O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB.".format(
                                parallelism,
                                process_mesh_shape,
                                len(device_meshes),
                                memory / (1024**3),
                            )
                        )

    return best_dist_context
def tune_o2(self):
    """The o2 level tuning.

    Placeholder: performs no search and always returns ``None``.
    """
    return None
def save_strategy(self, best_dist_context, path):
    """Pickle the distributed attributes of ``best_dist_context`` to ``path``.

    Only tensors and ops whose keys also exist in ``self._dist_context`` are
    kept; their dist attrs are stored via ``serialize_to_string()``. The
    dumped dict has the keys ``tensor``, ``op``, ``process_meshes`` (a list of
    ``[process_ids, shape]`` pairs) and ``cluster`` (``self._cluster``).

    Args:
        best_dist_context: The distributed context to serialize.
        path: Destination file, opened in binary write mode.

    Raises:
        AssertionError: If no tensor or no op dist attr was collected.
    """
    tuned_tensors = best_dist_context._dist_tensors_for_program
    tensor_attrs = {
        tensor_key: tuned_tensors[tensor_key].dist_attr.serialize_to_string()
        for tensor_key in tuned_tensors
        if tensor_key in self._dist_context._dist_tensors_for_program
    }
    assert tensor_attrs, "Tensor dist attrs must not be None."

    tuned_ops = best_dist_context._dist_ops_for_program
    op_attrs = {
        op_key: tuned_ops[op_key].dist_attr.serialize_to_string()
        for op_key in tuned_ops
        if op_key in self._dist_context._dist_ops_for_program
    }
    assert op_attrs, "Op dist attrs must not be None."

    mesh_records = [
        [mesh.process_ids, mesh.shape]
        for mesh in best_dist_context._process_meshes
    ]

    # Key order mirrors the order in which the payload is assembled.
    payload = {
        "tensor": tensor_attrs,
        "op": op_attrs,
        "process_meshes": mesh_records,
        "cluster": self._cluster,
    }
    with open(path, 'wb') as out_file:
        pickle.dump(payload, out_file)
    self._logger.info("The strategy has been saved at {}".format(path))
def run_or_quit(self):
    """Exit the process when the tuner is configured to tune only.

    Does nothing when ``self._is_run`` is truthy.

    Raises:
        SystemExit: When ``self._is_run`` is falsy.
    """
    # Quit if just tune
    if not self._is_run:
        self._logger.info(
            "The process will be quitted when just tune not run."
        )
        # ``quit`` is injected by the ``site`` module for interactive use and
        # may be missing (e.g. ``python -S``); ``sys.exit`` is always there
        # and raises the same ``SystemExit``.
        sys.exit()
def tune(self):
    """Search for a parallel strategy and write it into ``self._dist_context``.

    Pattern matching is run on the serial main program first. When
    ``self._use_dp`` is set, only the forward annotation is completed and the
    method returns early. Otherwise ``self.prepare()`` is called and
    ``tune_o1`` or ``tune_o2`` is selected by ``self.level``; the resulting
    dist tensors, dist ops and process meshes are copied into
    ``self._dist_context``. In both paths the strategy is saved when
    ``self._strategy_path`` is set.

    Raises:
        AssertionError: If no distributed context was found.
    """
    begin = time.time()
    self.match_program(self._dist_context.serial_main_program)
    end = time.time()
    self._logger.info("Pattern match in {}s.".format(end - begin))

    if self._use_dp:
        completer = Completer(self._dist_context)
        completer.complete_forward_annotation()
        print_program_with_dist_attr(
            self._dist_context.serial_main_program, self._dist_context
        )
        # Save strategy if need
        # NOTE(review): the nesting of run_or_quit() under `if path:` (here
        # and at the end of this method) was reconstructed from a flattened
        # listing -- confirm against the repository.
        path = self._strategy_path
        if path:
            self.save_strategy(self._dist_context, path)
            self.run_or_quit()
        return

    # prepare
    self.prepare()

    best_dist_context = None
    if self.level == "o2":
        best_dist_context = self.tune_o2()

    elif self.level == "o1":
        # If level is o1, it means all layers within same parallelism.
        # When in pipeline parallelism, it means that place layers evenly.
        use_o2_level = False
        for device_meshes in self.device_meshes_list:
            if len(device_meshes) > 1:
                # All stages of a pipeline candidate must share one mesh
                # shape; otherwise fall back to o2.
                shape = None
                for device_mesh in device_meshes:
                    if shape is None:
                        shape = device_mesh.shape
                        continue
                    else:
                        if shape != device_mesh.shape:
                            self._logger.info(
                                "Warning: The o1 level is not be supported when the number of machines is prime numer which greaters than 1. We will use o2 level to tune."
                            )
                            use_o2_level = True
                            # Leaves only the inner loop; the remaining
                            # candidates are still inspected.
                            break

        if use_o2_level:
            best_dist_context = self.tune_o2()

        else:
            best_dist_context = self.tune_o1()

    assert (
        best_dist_context is not None
    ), "can not find a parallel strategy to run, please use passes such as recompute, amp or sharding."

    # Overwrite only the entries that already exist in self._dist_context.
    for key in best_dist_context._dist_tensors_for_program:
        if key in self._dist_context._dist_tensors_for_program:
            self._dist_context._dist_tensors_for_program[
                key
            ] = best_dist_context._dist_tensors_for_program[key]
    for key in best_dist_context._dist_ops_for_program:
        if key in self._dist_context._dist_ops_for_program:
            self._dist_context._dist_ops_for_program[
                key
            ] = best_dist_context._dist_ops_for_program[key]
    self._dist_context._process_meshes = best_dist_context._process_meshes

    end = time.time()
    self._logger.info("Rule-based tuner end in {}s.".format(end - begin))
    self._logger.info("The best strategy found is as follows: ")
    print_program_with_dist_attr(self.full_main_program, best_dist_context)

    # Save strategy if need
    path = self._strategy_path
    if path:
        self.save_strategy(best_dist_context, path)
        self.run_or_quit()
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
浏览文件 @
d6011cb6
...
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
...
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
modeling
.
init_global
()
modeling
.
init_global
()
train_program
=
static
.
Program
()
train_program
=
static
.
Program
()
start_program
=
static
.
Program
()
start_program
=
static
.
Program
()
place
=
paddle
.
set_device
(
"gpu"
)
batch_size
=
8
batch_size
=
8
sequence_len
=
512
sequence_len
=
512
vocab_size
=
1000
vocab_size
=
1000
place
=
None
train_program
,
start_program
,
loss
,
gen_data
=
get_gpt_model
(
train_program
,
start_program
,
loss
,
gen_data
=
get_gpt_model
(
train_program
,
train_program
,
start_program
,
start_program
,
...
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
...
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len
,
sequence_len
,
vocab_size
,
vocab_size
,
)
)
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.dist_context
import
(
from
paddle.distributed.auto_parallel.dist_context
import
(
DistributedContext
,
DistributedContext
,
)
)
from
paddle.distributed.auto_parallel.process_mesh
import
ProcessMesh
from
paddle.distributed.auto_parallel.tuner.rule_based_tuner
import
(
from
paddle.distributed.auto_parallel.tuner.rule_based_tuner
import
(
RuleBasedTuner
,
RuleBasedTuner
,
)
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.2
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.2
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
cluster
=
Cluster
()
cluster
.
gen_default_config_cluster
(
node_count
=
1
,
device_count
=
8
)
dist_context
=
DistributedContext
(
dist_context
=
DistributedContext
(
serial_main_prog
=
train_program
,
serial_main_prog
=
train_program
,
serial_startup_prog
=
start_program
,
serial_startup_prog
=
start_program
,
serial_optimizer
=
opt
,
serial_optimizer
=
opt
,
serial_loss
=
loss
,
serial_loss
=
loss
,
cluster
=
cluster
,
)
)
dist_context
.
initialize
()
dist_context
.
initialize
()
tuner
=
RuleBasedTuner
(
dist_context
)
tuner
=
RuleBasedTuner
(
dist_context
)
tuner
.
cluster_operators
()
tuner
.
tune
()
tuner
.
gen_full_program
()
tuner
.
match_program
(
tuner
.
_dist_context
.
serial_main_program
)
process_mesh
=
ProcessMesh
([
0
,
1
])
tuner
.
gen_fwd_sub_programs_by_clone
()
tuner
.
complete_sub_fwd_programs
(
process_mesh
)
tuner
.
complete_sub_bwd_programs
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录