Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d6011cb6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d6011cb6
编写于
3月 28, 2023
作者:
C
caozhou
提交者:
GitHub
3月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Add o1 level tune (#52041)
* add tune o1 level * add unittest
上级
418b983c
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
820 addition
and
721 deletion
+820
-721
python/paddle/distributed/auto_parallel/cluster.py
python/paddle/distributed/auto_parallel/cluster.py
+76
-14
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+38
-25
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+1
-1
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+0
-640
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
...on/paddle/distributed/auto_parallel/cost/estimate_cost.py
+82
-32
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+2
-0
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
...addle/distributed/auto_parallel/tuner/rule_based_tuner.py
+614
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
...id/tests/unittests/auto_parallel/test_rule_based_tuner.py
+7
-9
未找到文件。
python/paddle/distributed/auto_parallel/cluster.py
浏览文件 @
d6011cb6
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
json
import
json
import
os
import
os
import
re
from
enum
import
IntEnum
,
unique
from
enum
import
IntEnum
,
unique
import
paddle
import
paddle
...
@@ -449,7 +450,6 @@ class Cluster:
...
@@ -449,7 +450,6 @@ class Cluster:
npu_models
=
[
"NPU"
]
npu_models
=
[
"NPU"
]
dcu_models
=
[
"DCU"
]
dcu_models
=
[
"DCU"
]
all_gpu_models
=
gpu_models
+
xpu_models
+
npu_models
+
dcu_models
all_gpu_models
=
gpu_models
+
xpu_models
+
npu_models
+
dcu_models
assert
gpu_model
in
all_gpu_models
self
.
_num_devices_per_machine
=
device_count
self
.
_num_devices_per_machine
=
device_count
def
_convert_to_type
(
gpu_model
):
def
_convert_to_type
(
gpu_model
):
...
@@ -462,6 +462,8 @@ class Cluster:
...
@@ -462,6 +462,8 @@ class Cluster:
type
=
"NPU"
type
=
"NPU"
elif
gpu_model
in
dcu_models
:
elif
gpu_model
in
dcu_models
:
type
=
"DCU"
type
=
"DCU"
else
:
type
=
"GPU"
assert
type
is
not
None
assert
type
is
not
None
return
type
return
type
...
@@ -470,6 +472,12 @@ class Cluster:
...
@@ -470,6 +472,12 @@ class Cluster:
model
=
None
model
=
None
if
gpu_model
==
"V100"
:
if
gpu_model
==
"V100"
:
model
=
"Tesla V100-SXM2-"
+
str
(
gpu_memory
)
+
"GB"
model
=
"Tesla V100-SXM2-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A100"
:
model
=
"Tesla A100-SXM-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A30"
:
model
=
"Tesla A30-SXM-"
+
str
(
gpu_memory
)
+
"GB"
else
:
model
=
gpu_model
+
str
(
gpu_memory
)
+
"GB"
assert
model
is
not
None
assert
model
is
not
None
return
model
return
model
...
@@ -527,6 +535,8 @@ class Cluster:
...
@@ -527,6 +535,8 @@ class Cluster:
device
[
"memory"
]
=
memory
device
[
"memory"
]
=
memory
device
[
"sp_gflops"
]
=
sp_gflops
device
[
"sp_gflops"
]
=
sp_gflops
device
[
"dp_gflops"
]
=
dp_gflops
device
[
"dp_gflops"
]
=
dp_gflops
# hard code
device
[
"type"
]
=
"GPU"
global_id_to_device_type
[
global_id
]
=
type
global_id_to_device_type
[
global_id
]
=
type
global_id_to_node
[
global_id
]
=
i
global_id_to_node
[
global_id
]
=
i
devices
.
append
(
device
)
devices
.
append
(
device
)
...
@@ -820,30 +830,82 @@ class Cluster:
...
@@ -820,30 +830,82 @@ class Cluster:
return
self
.
__str__
()
return
self
.
__str__
()
def
get_default_cluster
():
def
get_default_cluster
(
json_config
=
None
):
def
is_by_json_config
(
json_config
):
if
not
json_config
:
return
False
if
"cluster"
not
in
json_config
:
return
False
else
:
if
"path"
not
in
json_config
[
"cluster"
]:
if
"num_nodes"
not
in
json_config
[
"cluster"
]:
return
False
if
"num_gpus"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_model"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_memory"
not
in
json_config
[
"cluster"
]:
return
False
return
True
else
:
return
True
cluster
=
Cluster
()
cluster
=
Cluster
()
local_device_count
=
os
.
getenv
(
"PADDLE_LOCAL_SIZE"
)
if
json_config
and
is_by_json_config
(
json_config
):
if
local_device_count
is
None
:
# Get GPU info by json config
local_device_count
=
1
if
"path"
in
json_config
[
"cluster"
]:
else
:
cluster
.
build_from_file
(
json_config
[
"cluster"
][
"path"
])
local_device_count
=
int
(
local_device_count
)
return
cluster
global_device_count
=
os
.
getenv
(
"PADDLE_GLOBAL_SIZE"
)
else
:
if
global_device_count
is
None
:
node_count
=
json_config
[
"cluster"
][
"num_nodes"
]
node_count
=
1
local_device_count
=
json_config
[
"cluster"
][
"num_gpus"
]
gpu_model
=
json_config
[
"cluster"
][
"gpu_model"
]
memory
=
json_config
[
"cluster"
][
"gpu_memory"
]
else
:
else
:
global_device_count
=
int
(
global_device_count
)
# Get GPU info by get_device_properties
assert
global_device_count
%
local_device_count
==
0
local_device_count
=
os
.
getenv
(
"PADDLE_LOCAL_SIZE"
)
node_count
=
int
(
global_device_count
)
//
local_device_count
if
local_device_count
is
None
:
local_device_count
=
1
else
:
local_device_count
=
int
(
local_device_count
)
global_device_count
=
os
.
getenv
(
"PADDLE_GLOBAL_SIZE"
)
if
global_device_count
is
None
:
node_count
=
1
else
:
global_device_count
=
int
(
global_device_count
)
assert
global_device_count
%
local_device_count
==
0
node_count
=
int
(
global_device_count
)
//
local_device_count
gpu_info
=
paddle
.
device
.
cuda
.
get_device_properties
()
assert
gpu_info
,
"Auto parallel just runs on gpu now."
gpu_name
=
gpu_info
.
name
try
:
re_result
=
re
.
split
(
r
'[ , -]'
,
gpu_name
)
gpu_model
=
re_result
[
1
]
memory
=
int
(
re_result
[
-
1
][:
-
2
])
except
:
memory
=
int
(
gpu_info
.
total_memory
)
//
(
1000
**
3
)
gpu_model
=
gpu_name
print
(
print
(
"Node Count: "
,
"Node Count: "
,
node_count
,
node_count
,
"Local Device Size: "
,
"Local Device Size: "
,
local_device_count
,
local_device_count
,
"GPU Model: "
,
gpu_model
,
"GPU Memory: "
,
memory
,
"World size: "
,
"World size: "
,
paddle
.
distributed
.
get_world_size
(),
paddle
.
distributed
.
get_world_size
(),
flush
=
True
,
flush
=
True
,
)
)
cluster
.
gen_default_config_cluster
(
cluster
.
gen_default_config_cluster
(
node_count
=
node_count
,
device_count
=
local_device_count
node_count
=
node_count
,
device_count
=
local_device_count
,
gpu_model
=
gpu_model
,
gpu_memory
=
memory
,
)
)
return
cluster
return
cluster
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
d6011cb6
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
from
functools
import
reduce
from
functools
import
reduce
import
paddle
import
paddle
from
paddle.utils.flops
import
flops
from
..cluster
import
LinkType
from
..cluster
import
LinkType
from
..dist_tensor
import
DistributedTensor
from
..dist_tensor
import
DistributedTensor
...
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
...
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
output_desc
=
OrderedDict
()
output_desc
=
OrderedDict
()
# Get partitioned shape of input
# Get partitioned shape of input
input_var_desc
=
{}
for
input_name
in
op
.
input_names
:
for
input_name
in
op
.
input_names
:
var_name_list
=
op
.
input
(
input_name
)
var_name_list
=
op
.
input
(
input_name
)
var_desc
=
[]
input_var_desc
[
input_name
]
=
[]
for
var_name
in
var_name_list
:
for
var_name
in
var_name_list
:
var
=
get_var_with_recursion
(
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
var_name
,
op
.
block
,
op
.
block
.
program
...
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
...
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process
,
process
,
shard_sizes
,
shard_sizes
,
)
)
var_desc
.
append
((
var
.
dtype
,
shape
)
)
input_var_desc
[
input_name
].
append
(
shape
)
# For special op such as embedding and its grad op
# For special op such as embedding and its grad op
if
(
if
(
...
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
...
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
relative_idx
=
relative_idx
*
per_part_size
relative_idx
=
relative_idx
*
per_part_size
desc
[
"attrs"
][
"start_index"
]
=
relative_idx
desc
[
"attrs"
][
"start_index"
]
=
relative_idx
input_desc
[
input_name
]
=
var_desc
desc
[
"inputs"
]
=
input_var_desc
desc
[
"inputs"
]
=
input_desc
for
out_name
in
op
.
output_names
:
for
out_name
in
op
.
output_names
:
var_name_list
=
op
.
output
(
out_name
)
var_name_list
=
op
.
output
(
out_name
)
...
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
...
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
return
desc
return
desc
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
):
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
,
is_dp
=
False
):
"""Build comm costs by descriptions"""
"""Build comm costs by descriptions"""
comm_context
=
CommContext
(
cluster
)
comm_context
=
CommContext
(
cluster
)
group_ranks_list
=
[]
group_ranks_list
=
[]
...
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
...
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
comm_op_cost
=
op_cost_class
(
comm_op_cost
=
op_cost_class
(
op_desc
=
desc
,
comm_context
=
comm_context
op_desc
=
desc
,
comm_context
=
comm_context
)
)
if
is_dp
:
comm_op_cost
.
cost
.
time
*=
0.9
comm_op_cost_list
.
append
(
comm_op_cost
)
comm_op_cost_list
.
append
(
comm_op_cost
)
return
comm_op_cost_list
return
comm_op_cost_list
...
@@ -389,6 +394,7 @@ def build_dp_costs(
...
@@ -389,6 +394,7 @@ def build_dp_costs(
vars
=
dist_op
.
serial_op
.
block
.
vars
vars
=
dist_op
.
serial_op
.
block
.
vars
var_name
=
var_names
[
0
]
var_name
=
var_names
[
0
]
has_found
=
False
has_found
=
False
is_input
=
True
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
if
var_name
in
name
:
if
var_name
in
name
:
var_name
=
name
var_name
=
name
...
@@ -400,6 +406,7 @@ def build_dp_costs(
...
@@ -400,6 +406,7 @@ def build_dp_costs(
if
var_name
in
name
:
if
var_name
in
name
:
var_name
=
name
var_name
=
name
has_found
=
True
has_found
=
True
is_input
=
False
break
break
if
not
has_found
:
if
not
has_found
:
return
return
...
@@ -418,6 +425,7 @@ def build_dp_costs(
...
@@ -418,6 +425,7 @@ def build_dp_costs(
processes
,
processes
,
c_allreduce_sum_descs
,
c_allreduce_sum_descs
,
cluster
,
cluster
,
is_dp
=
True
,
)
)
result
.
append
(
comm_cost_list
)
result
.
append
(
comm_cost_list
)
...
@@ -431,22 +439,11 @@ def build_dp_costs(
...
@@ -431,22 +439,11 @@ def build_dp_costs(
desc
=
{}
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"op"
]
=
op_type
desc
[
"inputs"
]
=
{}
desc
[
"inputs"
]
=
{}
if
var_name
in
dist_attr
.
inputs_dist_attrs
:
dims_mapping
=
(
dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_name
)
dist_attr
.
get_input_dims_mapping
(
var_name
)
elif
var_name
in
dist_attr
.
outputs_dist_attrs
:
if
is_input
dims_mapping
=
dist_attr
.
get_output_dims_mapping
(
var_name
)
else
dist_attr
.
get_output_dims_mapping
(
var_name
)
else
:
)
raise
AssertionError
(
"cannot find dims_mapping for {} in {}"
.
format
(
var_name
,
dist_attr
)
)
# dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name)
# if dist_attr.get_input_dims_mapping(var_name) is not None
# else dist_attr.get_output_dims_mapping(var_name)
# )
var
=
get_var_with_recursion
(
var
=
get_var_with_recursion
(
var_name
,
var_name
,
dist_op
.
serial_op
.
block
,
dist_op
.
serial_op
.
block
,
...
@@ -493,8 +490,6 @@ class CommContext:
...
@@ -493,8 +490,6 @@ class CommContext:
# if cluster has no info about those vars, it will be set by default
# if cluster has no info about those vars, it will be set by default
self
.
base_ring
=
None
self
.
base_ring
=
None
self
.
base_tree
=
None
self
.
base_tree
=
None
# self.base_inter_ring = None
# self.base_inter_tree = None
self
.
intra_ring
=
None
self
.
intra_ring
=
None
self
.
intra_tree
=
None
self
.
intra_tree
=
None
self
.
inter_ring
=
None
self
.
inter_ring
=
None
...
@@ -508,8 +503,6 @@ class CommContext:
...
@@ -508,8 +503,6 @@ class CommContext:
# set default
# set default
self
.
base_ring
=
8.4
self
.
base_ring
=
8.4
self
.
base_tree
=
0.0
self
.
base_tree
=
0.0
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVL in default
# NVL in default
self
.
intra_ring
=
3.4
self
.
intra_ring
=
3.4
self
.
intra_tree
=
28
self
.
intra_tree
=
28
...
@@ -681,6 +674,8 @@ class Cost:
...
@@ -681,6 +674,8 @@ class Cost:
class
OpCost
:
class
OpCost
:
OP_TYPE
=
"op"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
):
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
):
self
.
_op
=
op
self
.
_op
=
op
self
.
_op_desc
=
op_desc
self
.
_op_desc
=
op_desc
...
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
...
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
)
)
)
)
def
calc_flops
(
self
):
if
not
self
.
op_desc
:
return
0
if
"_grad"
in
self
.
__class__
.
OP_TYPE
:
op_type
=
self
.
__class__
.
OP_TYPE
[:
len
(
self
.
__class__
.
OP_TYPE
)
-
5
]
return
2
*
flops
(
op_type
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
]
)
return
flops
(
self
.
__class__
.
OP_TYPE
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
],
)
def
calc_time
(
self
):
flops_count
=
self
.
calc_flops
()
return
flops_count
*
2.9e-7
def
register_op_cost
(
cls
):
def
register_op_cost
(
cls
):
op_type
=
cls
.
OP_TYPE
op_type
=
cls
.
OP_TYPE
...
...
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
浏览文件 @
d6011cb6
...
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
...
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
def
calc_time
(
self
):
return
0
return
self
.
comm_count
*
1
/
(
144
*
1e3
)
@
register_op_cost
@
register_op_cost
...
...
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
浏览文件 @
d6011cb6
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
浏览文件 @
d6011cb6
...
@@ -189,6 +189,9 @@ class CostEstimator:
...
@@ -189,6 +189,9 @@ class CostEstimator:
# Calc dist op cost
# Calc dist op cost
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
op_dist_attr
=
dist_op
.
dist_attr
op_dist_attr
=
dist_op
.
dist_attr
processes
=
op_dist_attr
.
process_mesh
.
process_ids
processes
=
op_dist_attr
.
process_mesh
.
process_ids
...
@@ -225,6 +228,8 @@ class CostEstimator:
...
@@ -225,6 +228,8 @@ class CostEstimator:
for
rank
in
group_ranks
:
for
rank
in
group_ranks
:
self
.
local_cost
(
rank
).
time
=
(
self
.
local_cost
(
rank
).
time
=
(
max_time
+
comm_op_cost
.
time
max_time
+
comm_op_cost
.
time
if
op
.
attr
(
'op_role'
)
!=
OpRole
.
Backward
else
max_time
+
0.9
*
comm_op_cost
.
time
)
)
if
rank
not
in
self
.
_bubble_time_mapping
:
if
rank
not
in
self
.
_bubble_time_mapping
:
self
.
_bubble_time_mapping
[
rank
]
=
0
self
.
_bubble_time_mapping
[
rank
]
=
0
...
@@ -290,6 +295,7 @@ class CostEstimator:
...
@@ -290,6 +295,7 @@ class CostEstimator:
self
.
_ordered_ops
.
append
([
op
.
desc
.
id
(),
op
])
self
.
_ordered_ops
.
append
([
op
.
desc
.
id
(),
op
])
self
.
_ordered_ops
.
sort
(
key
=
lambda
x
:
x
[
0
])
self
.
_ordered_ops
.
sort
(
key
=
lambda
x
:
x
[
0
])
parameters
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
if
op
.
type
in
[
"create_py_reader"
,
"create_py_reader"
,
...
@@ -298,11 +304,14 @@ class CostEstimator:
...
@@ -298,11 +304,14 @@ class CostEstimator:
]:
]:
continue
continue
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
var_name
var_name
)
)
if
var_name
not
in
var_info
:
if
var_name
not
in
var_info
:
var_info
[
var_name
]
=
{}
var_info
[
var_name
]
=
{}
key
=
_convert_pm_and_dm_to_str
(
key
=
_convert_pm_and_dm_to_str
(
...
@@ -311,6 +320,10 @@ class CostEstimator:
...
@@ -311,6 +320,10 @@ class CostEstimator:
if
key
not
in
var_info
[
var_name
]:
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
var_info
[
var_name
][
key
]
=
{}
# It is even partition now
# It is even partition now
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_input
(
var_name
)
var
=
dist_op
.
get_serial_input
(
var_name
)
global_sizes
=
var
.
shape
global_sizes
=
var
.
shape
...
@@ -324,9 +337,16 @@ class CostEstimator:
...
@@ -324,9 +337,16 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
sizes
,
dtype
)
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
if
var
.
persistable
:
var_info
[
var_name
][
key
][
"position"
]
=
[]
name
=
var_name
+
key
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
@@ -339,6 +359,10 @@ class CostEstimator:
...
@@ -339,6 +359,10 @@ class CostEstimator:
)
)
if
key
not
in
var_info
[
var_name
]:
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
var_info
[
var_name
][
key
]
=
{}
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_output
(
var_name
)
var
=
dist_op
.
get_serial_output
(
var_name
)
global_sizes
=
var
.
shape
global_sizes
=
var
.
shape
...
@@ -352,11 +376,19 @@ class CostEstimator:
...
@@ -352,11 +376,19 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
sizes
,
dtype
)
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
if
var
.
persistable
:
var_info
[
var_name
][
key
][
"position"
]
=
[]
name
=
var_name
+
key
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
has_used_vars
=
set
()
has_used_vars
=
set
()
not_calc_vars
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
if
op
.
type
in
[
"create_py_reader"
,
"create_py_reader"
,
...
@@ -367,6 +399,8 @@ class CostEstimator:
...
@@ -367,6 +399,8 @@ class CostEstimator:
can_free_memories
=
{}
can_free_memories
=
{}
can_free_vars
=
set
()
can_free_vars
=
set
()
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
...
@@ -378,24 +412,30 @@ class CostEstimator:
...
@@ -378,24 +412,30 @@ class CostEstimator:
has_used_var
=
var_name
+
key
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_input
(
var_name
)
var
=
dist_op
.
get_serial_input
(
var_name
)
# Not used
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
if
has_used_var
in
not_calc_vars
:
continue
has_used_vars
.
add
(
has_used_var
)
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
(
if
has_used_var
not
in
can_free_vars
:
has_used_var
not
in
can_free_vars
can_free_vars
.
add
(
has_used_var
)
and
not
var
.
persistable
if
not
var
.
persistable
:
):
for
process
in
process_mesh
.
process_ids
:
can_free_vars
.
add
(
has_used_var
)
if
process
not
in
can_free_memories
:
for
process
in
process_mesh
.
process_ids
:
can_free_memories
[
process
]
=
0
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
+=
var_info
[
can_free_memories
[
process
]
=
0
var_name
can_free_memories
[
process
]
+=
var_info
[
var_name
][
][
key
][
"memory"
]
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
@@ -406,25 +446,36 @@ class CostEstimator:
...
@@ -406,25 +446,36 @@ class CostEstimator:
)
)
has_used_var
=
var_name
+
key
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_output
(
var_name
)
var
=
dist_op
.
get_serial_output
(
var_name
)
if
(
op
.
type
==
"reshape2"
or
op
.
type
==
"transpose2"
or
op
.
type
==
"elementwise_add"
):
not_calc_vars
.
add
(
has_used_var
)
continue
# Not used
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
has_used_vars
.
add
(
has_used_var
)
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
(
if
has_used_var
not
in
can_free_vars
:
has_used_var
not
in
can_free_vars
can_free_vars
.
add
(
has_used_var
)
and
not
var
.
persistable
if
not
var
.
persistable
:
):
for
process
in
process_mesh
.
process_ids
:
can_free_vars
.
add
(
has_used_var
)
if
process
not
in
can_free_memories
:
for
process
in
process_mesh
.
process_ids
:
can_free_memories
[
process
]
=
0
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
+=
var_info
[
can_free_memories
[
process
]
=
0
var_name
can_free_memories
[
process
]
+=
var_info
[
var_name
][
][
key
][
"memory"
]
key
][
"memory"
]
# Calc peak memory
# Calc peak memory
for
process
in
memories
:
for
process
in
memories
:
...
@@ -433,7 +484,6 @@ class CostEstimator:
...
@@ -433,7 +484,6 @@ class CostEstimator:
else
:
else
:
if
memories
[
process
]
>
self
.
max_memories
[
process
]:
if
memories
[
process
]
>
self
.
max_memories
[
process
]:
self
.
max_memories
[
process
]
=
memories
[
process
]
self
.
max_memories
[
process
]
=
memories
[
process
]
# Free memory
# Free memory
for
process
in
can_free_memories
:
for
process
in
can_free_memories
:
if
process
in
memories
:
if
process
in
memories
:
...
@@ -513,7 +563,7 @@ class CostEstimator:
...
@@ -513,7 +563,7 @@ class CostEstimator:
# Padding automatically
# Padding automatically
max_len
=
0
max_len
=
0
header
=
[
"Execution Time(
m
s)"
,
"Max Memory(MiB)"
]
header
=
[
"Execution Time(
u
s)"
,
"Max Memory(MiB)"
]
vals
=
[
round
(
self
.
global_cost
.
time
,
3
),
int
(
self
.
max_memory
//
1e6
)]
vals
=
[
round
(
self
.
global_cost
.
time
,
3
),
int
(
self
.
max_memory
//
1e6
)]
for
memory
in
vals
+
header
:
for
memory
in
vals
+
header
:
if
len
(
str
(
memory
))
>
max_len
:
if
len
(
str
(
memory
))
>
max_len
:
...
...
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
d6011cb6
...
@@ -2716,6 +2716,8 @@ class Resharder:
...
@@ -2716,6 +2716,8 @@ class Resharder:
)
)
# simplified processing: ignore union process mesh and output reshard
# simplified processing: ignore union process mesh and output reshard
dist_op
=
self
.
dist_context
.
get_dist_op_for_program
(
op
)
dist_op
=
self
.
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_tensor
or
not
dist_op
:
return
reshard_op_cost
dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
tensor
.
name
tensor
.
name
)
)
...
...
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
浏览文件 @
d6011cb6
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
浏览文件 @
d6011cb6
...
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
...
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
modeling
.
init_global
()
modeling
.
init_global
()
train_program
=
static
.
Program
()
train_program
=
static
.
Program
()
start_program
=
static
.
Program
()
start_program
=
static
.
Program
()
place
=
paddle
.
set_device
(
"gpu"
)
batch_size
=
8
batch_size
=
8
sequence_len
=
512
sequence_len
=
512
vocab_size
=
1000
vocab_size
=
1000
place
=
None
train_program
,
start_program
,
loss
,
gen_data
=
get_gpt_model
(
train_program
,
start_program
,
loss
,
gen_data
=
get_gpt_model
(
train_program
,
train_program
,
start_program
,
start_program
,
...
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
...
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len
,
sequence_len
,
vocab_size
,
vocab_size
,
)
)
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.dist_context
import
(
from
paddle.distributed.auto_parallel.dist_context
import
(
DistributedContext
,
DistributedContext
,
)
)
from
paddle.distributed.auto_parallel.process_mesh
import
ProcessMesh
from
paddle.distributed.auto_parallel.tuner.rule_based_tuner
import
(
from
paddle.distributed.auto_parallel.tuner.rule_based_tuner
import
(
RuleBasedTuner
,
RuleBasedTuner
,
)
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.2
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.2
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
cluster
=
Cluster
()
cluster
.
gen_default_config_cluster
(
node_count
=
1
,
device_count
=
8
)
dist_context
=
DistributedContext
(
dist_context
=
DistributedContext
(
serial_main_prog
=
train_program
,
serial_main_prog
=
train_program
,
serial_startup_prog
=
start_program
,
serial_startup_prog
=
start_program
,
serial_optimizer
=
opt
,
serial_optimizer
=
opt
,
serial_loss
=
loss
,
serial_loss
=
loss
,
cluster
=
cluster
,
)
)
dist_context
.
initialize
()
dist_context
.
initialize
()
tuner
=
RuleBasedTuner
(
dist_context
)
tuner
=
RuleBasedTuner
(
dist_context
)
tuner
.
cluster_operators
()
tuner
.
tune
()
tuner
.
gen_full_program
()
tuner
.
match_program
(
tuner
.
_dist_context
.
serial_main_program
)
process_mesh
=
ProcessMesh
([
0
,
1
])
tuner
.
gen_fwd_sub_programs_by_clone
()
tuner
.
complete_sub_fwd_programs
(
process_mesh
)
tuner
.
complete_sub_bwd_programs
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录