PaddlePaddle / Paddle
Commit 118a7415 (unverified)
Authored by caozhou on Apr 17, 2023 · committed via GitHub on Apr 17, 2023
Parent: d7659ce4

[Auto Parallel]Add o2 tune of rule based tuner (#52928)

* add o2 tune
* add unittest
* fix error
* set unittest timeout
Showing 7 changed files with 582 additions and 44 deletions (+582 -44):

python/paddle/distributed/auto_parallel/cluster.py                              +15  -12
python/paddle/distributed/auto_parallel/engine.py                               +21   -1
python/paddle/distributed/auto_parallel/planner_v2.py                          +129   -7
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py              +273  -23
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt                 +2   -0
python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_full.py    +1   -1
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner_o2.py  +141   -0  (new file)
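Note: the easiest way to see the new o2 level end to end is the unit test added by this commit (test_rule_based_tuner_o2.py, reproduced in full at the end of the diff). The sketch below condenses that flow into one helper; the helper name is hypothetical, while the Cluster / DistributedContext / RuleBasedTuner calls simply mirror the test.

# Condensed sketch of the o2 tuning flow exercised by the new unit test.
# The helper name is hypothetical; train_program/start_program/loss are the
# static-graph programs and loss variable of any model you have already built.
import paddle
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import RuleBasedTuner


def tune_with_rule_based_o2(train_program, start_program, loss):
    clip = paddle.nn.ClipGradByGlobalNorm(0.2)
    opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)

    # Describe the hardware the strategy should be tuned for.
    cluster = Cluster()
    cluster.gen_default_config_cluster(node_count=1, device_count=8)

    dist_context = DistributedContext(
        serial_main_prog=train_program,
        serial_startup_prog=start_program,
        serial_optimizer=opt,
        serial_loss=loss,
        cluster=cluster,
    )
    dist_context.initialize()

    # level="o2" selects the new stage/layer placement search.
    tuner = RuleBasedTuner(dist_context, level="o2")
    tuner.tune()
    return dist_context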
python/paddle/distributed/auto_parallel/cluster.py

@@ -13,12 +13,15 @@
 # limitations under the License.
 
 import json
+import logging
 import os
 import re
 from enum import IntEnum, unique
 
 import paddle
 
+from ..utils.log_utils import get_logger
+
 
 @unique
 class DeviceType(IntEnum):

@@ -830,6 +833,9 @@ class Cluster:
         return self.__str__()
 
 
+logger = get_logger(logging.INFO)
+
+
 def get_default_cluster(json_config=None):
     def is_by_json_config(json_config):
         if not json_config:

@@ -889,18 +895,15 @@ def get_default_cluster(json_config=None):
         memory = int(gpu_info.total_memory) // (1000**3)
         gpu_model = gpu_name
 
-    print(
-        "Node Count: ",
-        node_count,
-        "Local Device Size: ",
-        local_device_count,
-        "GPU Model: ",
-        gpu_model,
-        "GPU Memory: ",
-        memory,
-        "World size: ",
-        paddle.distributed.get_world_size(),
-        flush=True,
+    logger.info(
+        "Node Count: {}, Local Device Size: {}, GPU Model: {}, GPU Memory: {}GB, World size: {}, EndPoint: {}.".format(
+            node_count,
+            local_device_count,
+            gpu_model,
+            memory,
+            paddle.distributed.get_world_size(),
+            os.getenv("PADDLE_CURRENT_ENDPOINT", None),
+        )
     )
     cluster.gen_default_config_cluster(
         node_count=node_count,
python/paddle/distributed/auto_parallel/engine.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import json
 import logging
 import numbers
 import os

@@ -177,6 +178,23 @@ class Engine:
         self._strategy = strategy or Strategy()
         self._logger = get_logger(logging.INFO)
 
+        self._json_config = None
+        if cluster:
+            self._cluster = cluster
+        else:
+            if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG"):
+                try:
+                    path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")
+                    with open(path, "r") as f:
+                        self._json_config = json.load(f)
+                except Exception as e:
+                    self._logger.info(
+                        "Load json failed, please check json file, engine will run default config."
+                    )
+                    self._json_config = None
+            self._cluster = get_default_cluster(self._json_config)
+
         if os.getenv("POD_NAME"):
             self._logger.info(
                 "Distribute training by paddle.distributed.launch"

@@ -653,6 +671,7 @@ class Engine:
                 fetch_vars,
                 self._cluster,
                 self._strategy,
+                self._json_config,
             )
             self._fwd_dist_contexts[mode] = DistributedContext(
                 serial_main_prog,

@@ -663,6 +682,7 @@ class Engine:
                 fetch_vars,
                 self._cluster,
                 self._strategy,
+                self._json_config,
             )
             self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
             self._fwd_main_progs[mode] = serial_main_prog.clone()

@@ -769,7 +789,7 @@ class Engine:
         # instantiate communication by process_mapping.
         all_process_groups = get_all_process_groups()
-        if self._strategy.auto_mode == "full":
+        if self._strategy.auto_mode == "full_random":
             auto_utils.initialize_pg_in_full_mode(
                 all_process_groups, self._cur_rank
             )
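Note: with the change above, Engine can pick up an auto-parallel JSON config from the PADDLE_AUTO_PARALLEL_CONFIG environment variable when no explicit cluster is passed. A minimal sketch of wiring that up follows; the only key this commit demonstrably reads from the file elsewhere is "tuner_load_path" (see planner_v2.py below), and the paths used here are placeholders.

# Hypothetical sketch: preparing a JSON config for Engine to load. Only
# "tuner_load_path" is known (from planner_v2.py in this commit) to be read
# downstream; the file locations are illustrative placeholders.
import json
import os
import tempfile

config = {
    # Where Planner.plan() should look for a previously saved strategy.
    "tuner_load_path": os.path.join(tempfile.gettempdir(), "best_strategy.pkl"),
}

config_path = os.path.join(tempfile.gettempdir(), "auto_parallel_config.json")
with open(config_path, "w") as f:
    json.dump(config, f)

# Engine.__init__ reads this variable when no cluster argument is given.
os.environ["PADDLE_AUTO_PARALLEL_CONFIG"] = config_path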
python/paddle/distributed/auto_parallel/planner_v2.py

@@ -12,9 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+import os
+import pickle
+
+import numpy as np
+
+from paddle.distributed.auto_parallel.dist_attribute import (
+    OperatorDistAttr,
+    TensorDistAttr,
+)
+from paddle.distributed.auto_parallel.dist_op import DistributedOperator
+from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
+from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+
+from ..utils.log_utils import get_logger
 from .completion import Completer
 from .dist_context import get_default_distributed_context
 from .tuner.parallel_tuner import ParallelTuner
+from .tuner.rule_based_tuner import RuleBasedTuner
 from .utils import is_naive_data_parallel

@@ -22,6 +38,7 @@ class Planner:
     def __init__(self, mode, dist_context):
         self._mode = mode
         self._dist_context = dist_context
+        self._load = False
 
         # NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
         # dependency of backward-forward ops in forward completion.

@@ -29,30 +46,135 @@ class Planner:
         self._dist_context._dist_op_context = default_ctx.dist_op_context
         self._dist_context.data_parallel = default_ctx.data_parallel
         if not is_naive_data_parallel(self._dist_context):
             # Use SSA graph for complex parallelism
             self._dist_context.initialize(with_graph=True)
         else:
             # Use program for data parallel parallelism
             self._dist_context.initialize(with_graph=False)
 
         self._completer = Completer(self._dist_context)
 
         self._strategy = dist_context.strategy
         # set parallel tuner for auto search
-        if self._strategy.auto_mode == "full":
+        if self._strategy.auto_mode == "full_random":
             self._parallel_tuner = ParallelTuner(
                 self._dist_context, mode=self._mode
             )
+        elif self._strategy.auto_mode == "full_rule_based":
+            self._parallel_tuner = RuleBasedTuner(
+                self._dist_context, mode=self._mode
+            )
 
     @property
     def completer(self):
         return self._completer
 
     def plan(self):
-        if self._strategy.auto_mode == "full":
-            self._parallel_tuner.tune()
-        else:
-            self._completer.complete_forward_annotation()
+        logger = get_logger(logging.INFO)
+        path = None
+        if self._dist_context._json_config:
+            try:
+                path = self._dist_context._json_config["tuner_load_path"]
+            except:
+                path = None
+        if path and os.path.exists(path):
+            try:
+                with open(path, "rb") as f:
+                    dist_attrs = pickle.load(f)
+                tensor_dist_attrs = dist_attrs["tensor"]
+                op_dist_attrs = dist_attrs["op"]
+                process_meshes = dist_attrs["process_meshes"]
+                cluster = dist_attrs["cluster"]
+                last_gpu_model = cluster.machines[0].devices[0].model
+                last_gpu_memory = cluster.machines[0].devices[0].memory
+                last_node_count = len(cluster.machines)
+                last_device_count = len(cluster.get_all_devices("GPU"))
+                gpu_model = (
+                    self._dist_context.cluster.machines[0].devices[0].model
+                )
+                gpu_memory = (
+                    self._dist_context.cluster.machines[0].devices[0].memory
+                )
+                node_count = len(self._dist_context.cluster.machines)
+                device_count = len(
+                    self._dist_context.cluster.get_all_devices("GPU")
+                )
+                if (
+                    gpu_model != last_gpu_model
+                    or gpu_memory != last_gpu_memory
+                    or last_node_count != node_count
+                    or device_count != last_device_count
+                ):
+                    logger.info(
+                        "The cluster {} nodes {} {} devices is different from the saved last cluster {} nodes {} {} devices, so we run the planner again.".format(
+                            node_count,
+                            device_count,
+                            gpu_model,
+                            last_node_count,
+                            last_device_count,
+                            last_gpu_model,
+                        )
+                    )
+                    need_set_dist_attr = False
+                else:
+                    need_set_dist_attr = True
+            except:
+                need_set_dist_attr = False
+
+            if need_set_dist_attr:
+                for key in op_dist_attrs:
+                    serial_op = self._dist_context._dist_ops_for_program[
+                        key
+                    ].serial_op
+                    # clear dist attr
+                    serial_op.dist_attr = OperatorDistAttr(serial_op.desc)
+                    serial_op.dist_attr.parse_from_string(op_dist_attrs[key])
+                    self._dist_context._dist_ops_for_program[
+                        key
+                    ] = DistributedOperator(serial_op)
+                for key in tensor_dist_attrs:
+                    serial_tensor = (
+                        self._dist_context._dist_tensors_for_program[
+                            key
+                        ].serial_tensor
+                    )
+                    # clear dist attr
+                    serial_tensor.dist_attr = TensorDistAttr(
+                        serial_tensor.desc
+                    )
+                    serial_tensor.dist_attr.parse_from_string(
+                        tensor_dist_attrs[key]
+                    )
+                    self._dist_context._dist_tensors_for_program[
+                        key
+                    ] = DistributedTensor(serial_tensor)
+
+                process_meshes = []
+                for item in dist_attrs["process_meshes"]:
+                    process_ids = item[0]
+                    shape = item[1]
+                    process_meshes.append(
+                        ProcessMesh(
+                            np.array(process_ids).reshape(shape).tolist()
+                        )
+                    )
+                self._dist_context.process_meshes = process_meshes
+                self._load = True
+                logger.info(
+                    f"The parallel strategy has been loaded from {path}"
+                )
+
+        if not self._load:
+            if self._strategy.auto_mode != "semi":
+                self._parallel_tuner.tune()
+            else:
+                self._completer.complete_forward_annotation()
+
+        if os.getenv("PADDLE_AUTO_PARALLEL_STAGE", "run") != "run":
+            quit()
+
         # parse forward sub block
         self._dist_context.block_state.parse_forward_blocks(
             self._dist_context.serial_main_program
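Note: the new loading path in Planner.plan() expects a pickled dict of distributed attributes. Its layout can be inferred from the load code above together with RuleBasedTuner.save_strategy in the next file; the sketch below only illustrates the keys involved, with placeholder values (the real entries hold serialized TensorDistAttr/OperatorDistAttr strings and the Cluster used for tuning).

# Illustrative only: shape of the strategy file Planner.plan() loads. Values
# here are placeholders; save_strategy() fills them with real serialized
# dist attrs, process mesh tuples and the tuned-for Cluster.
import pickle

dist_attrs = {
    "tensor": {},          # tensor key -> serialized TensorDistAttr string
    "op": {},              # op key -> serialized OperatorDistAttr string
    "process_meshes": [    # each item: (flat process ids, mesh shape)
        ([0, 1, 2, 3, 4, 5, 6, 7], [2, 4]),
    ],
    "cluster": None,       # placeholder; the saved Cluster goes here so
                           # plan() can detect a changed hardware setup
}

with open("/tmp/best_strategy.pkl", "wb") as f:
    pickle.dump(dist_attrs, f)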
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py

@@ -35,6 +35,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
 )
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
+from paddle.distributed.auto_parallel.process_group import (
+    get_world_process_group,
+)
 from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
 from paddle.distributed.auto_parallel.utils import (
     is_gradient_clip_op,

@@ -579,7 +582,6 @@ class GraphUtil:
         def _match_core(src_node, tgt_node):
             nonlocal not_matched
-
             # not support one input name or output name corresponding to multiple vars
             if not_matched:
                 return

@@ -1126,13 +1128,6 @@ class RuleBasedTuner:
     def level(self):
         return self._level
 
-    def convert_process_mesh_to_key(self, process_mesh):
-        """Convert process mesh object to str."""
-        processes = ",".join([str(x) for x in process_mesh._process_ids])
-        topology = ",".join([str(x) for x in process_mesh._shape])
-        key = processes + ";" + topology
-        return key
-
     def gen_full_program(self):
         """Generate full program that contain backward and update phase program if mode is train."""
         self.full_main_program = self.dist_context.serial_main_program.clone()

@@ -1878,6 +1873,13 @@ class RuleBasedTuner:
             ][parallelism][key]
             self._complete_sub_update_program(sub_program_dist_context)
 
+    def convert_process_mesh_to_key(self, process_mesh):
+        """Convert process mesh object to str."""
+        processes = ",".join([str(x) for x in process_mesh._process_ids])
+        topology = ",".join([str(x) for x in process_mesh._shape])
+        key = processes + ";" + topology
+        return key
+
     def convert_device_mesh_to_key(self, device_mesh):
         """Convert device mesh object to str."""
         processes = ",".join([str(x) for x in device_mesh.device_ids])

@@ -1894,6 +1896,168 @@ class RuleBasedTuner:
         )
         return global_cost.time, max_memory
 
+    def _local_stage_pass(self, start, end, process_mesh):
+        """Get the best cost and the corresponding strategy of layers on the given process mesh."""
+        # convert process mesh to dict key
+        key = self.convert_process_mesh_to_key(process_mesh)
+        if start in self.stage_best_cost_of_pm:
+            if end in self.stage_best_cost_of_pm[start]:
+                if key in self.stage_best_cost_of_pm[start][end]:
+                    return self.stage_best_cost_of_pm[start][end][key]["cost"]
+        assert end >= start
+        selective_parallelisms = (
+            ["dp", "mp"] if len(process_mesh.shape) == 1 else ["dp_mp", "mp_dp"]
+        )
+        if start not in self.stage_best_cost_of_pm:
+            self.stage_best_cost_of_pm[start] = {}
+        if end not in self.stage_best_cost_of_pm[start]:
+            self.stage_best_cost_of_pm[start][end] = {}
+        if key not in self.stage_best_cost_of_pm[start][end]:
+            self.stage_best_cost_of_pm[start][end][key] = {}
+
+        if end == start:
+            dist_contexts_x = [DistributedContext(), DistributedContext()]
+        else:
+            dist_contexts_x = self.stage_best_cost_of_pm[start][end - 1][key][
+                "dist_context"
+            ]
+
+        # Use beam search, the beam size is 2.
+        # When the process mesh is 1-D, the selecetive parallelsim can be dp or mp.
+        # Because the first layer often contains more ops than other layer, using beam search can find more accurate strategy.
+        count = 0
+        for dist_context_x in dist_contexts_x:
+            if end == start and count == 1:
+                break
+            for parallelism in selective_parallelisms:
+                dist_context_y = self.sub_programs_dist_context[end][
+                    parallelism
+                ][key]
+                dist_context = self.combine_dist_contexts(
+                    [dist_context_x, dist_context_y]
+                )
+                if (
+                    "dist_context"
+                    not in self.stage_best_cost_of_pm[start][end][key]
+                ):
+                    self.stage_best_cost_of_pm[start][end][key][
+                        "dist_context"
+                    ] = [None, None]
+                    self.stage_best_cost_of_pm[start][end][key]["cost"] = [
+                        sys.maxsize,
+                        sys.maxsize,
+                    ]
+
+                # estimate cost and memory
+                cost, local_stage_memory = self._get_sub_program_cost(
+                    dist_context
+                )
+                if local_stage_memory > 0.9 * self.cluster.machines[0].devices[
+                    0
+                ].memory * (1024**3):
+                    cost = sys.maxsize
+                index = -1
+                for idx, item in enumerate(
+                    self.stage_best_cost_of_pm[start][end][key]["cost"]
+                ):
+                    if cost <= item:
+                        index = idx
+                        break
+                if index == 0:
+                    self.stage_best_cost_of_pm[start][end][key]["cost"][
+                        1
+                    ] = self.stage_best_cost_of_pm[start][end][key]["cost"][0]
+                    self.stage_best_cost_of_pm[start][end][key][
+                        "dist_context"
+                    ][1] = self.stage_best_cost_of_pm[start][end][key][
+                        "dist_context"
+                    ][0]
+                    self.stage_best_cost_of_pm[start][end][key]["cost"][
+                        0
+                    ] = cost
+                    self.stage_best_cost_of_pm[start][end][key][
+                        "dist_context"
+                    ][0] = dist_context
+                elif index == 1:
+                    self.stage_best_cost_of_pm[start][end][key]["cost"][
+                        1
+                    ] = cost
+                    self.stage_best_cost_of_pm[start][end][key][
+                        "dist_context"
+                    ][1] = dist_context
+            count += 1
+
+        if (
+            self.stage_best_cost_of_pm[start][end][key]["cost"][1]
+            < self.stage_best_cost_of_pm[start][end][key]["cost"][0]
+        ):
+            self.stage_best_cost_of_pm[start][end][key][
+                "best_cost"
+            ] = self.stage_best_cost_of_pm[start][end][key]["cost"][1]
+            self.stage_best_cost_of_pm[start][end][key][
+                "best_dist_context"
+            ] = self.stage_best_cost_of_pm[start][end][key]["dist_context"][1]
+        else:
+            self.stage_best_cost_of_pm[start][end][key][
+                "best_cost"
+            ] = self.stage_best_cost_of_pm[start][end][key]["cost"][0]
+            self.stage_best_cost_of_pm[start][end][key][
+                "best_dist_context"
+            ] = self.stage_best_cost_of_pm[start][end][key]["dist_context"][0]
+
+        return self.stage_best_cost_of_pm[start][end][key]["best_cost"]
+
+    def local_stage_pass(self, start, end, device_mesh):
+        """Get the best cost and the corresponding strategy of layers on the given device mesh."""
+        dm_key = self.convert_device_mesh_to_key(device_mesh)
+        device_mesh_shape = device_mesh.shape
+        if len(device_mesh_shape) == 1:
+            device_mesh_shape.insert(0, 1)
+        process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)
+        best_cost = sys.maxsize
+        if start not in self.stage_best_cost_of_dm:
+            self.stage_best_cost_of_dm[start] = {}
+        if end not in self.stage_best_cost_of_dm[start]:
+            self.stage_best_cost_of_dm[start][end] = {}
+        if dm_key not in self.stage_best_cost_of_dm[start][end]:
+            self.stage_best_cost_of_dm[start][end][dm_key] = {}
+
+        for process_mesh_shape in process_mesh_shapes:
+            process_mesh = ProcessMesh(
+                np.array(device_mesh.device_ids)
+                .reshape(process_mesh_shape)
+                .tolist()
+            )
+            key = self.convert_process_mesh_to_key(process_mesh)
+            for i in range(start, end + 1):
+                self._local_stage_pass(start, i, process_mesh)
+            if (
+                self.stage_best_cost_of_pm[start][end][key]["best_cost"]
+                <= best_cost
+            ):
+                best_cost = self.stage_best_cost_of_pm[start][end][key][
+                    "best_cost"
+                ]
+                self.stage_best_cost_of_dm[start][end][dm_key][
+                    "cost"
+                ] = best_cost
+                self.stage_best_cost_of_dm[start][end][dm_key][
+                    "dist_context"
+                ] = self.stage_best_cost_of_pm[start][end][key][
+                    "best_dist_context"
+                ]
+
+        return best_cost
+
     def combine_dist_contexts(self, dist_contexts):
         """Combine the dist attr in dist contexts to one dist context."""
         combined_dist_context = DistributedContext()

@@ -1927,7 +2091,7 @@ class RuleBasedTuner:
         self.layers = self.cluster_operators()
         end = time.time()
         self._logger.info(
-            "Cluster operators to {} layers in {}s.".format(
+            "Cluster operators to {} layers in {:.2f}s.".format(
                 len(self.layers), end - begin
             )
         )

@@ -1937,7 +2101,7 @@ class RuleBasedTuner:
         self.gen_fwd_sub_programs_by_clone()
         end = time.time()
         self._logger.info(
-            f"Generate programs of every layer in {end - begin}s."
+            f"Generate programs of every layer in {end - begin:.2f}s."
         )
 
         # step3: partition devices to device meshes

@@ -1948,7 +2112,7 @@ class RuleBasedTuner:
         )
         device_meshes_list = ClusterPartitionUtil.partition_cluster(n, m)
         end = time.time()
-        self._logger.info(f"Partition cluster in {end - begin}s.")
+        self._logger.info(f"Partition cluster in {end - begin:.2f}s.")
 
         # step4: transform device mesh to process meshes
         dm_idx = 0

@@ -1987,7 +2151,7 @@ class RuleBasedTuner:
         begin = time.time()
         self.gen_full_program()
         end = time.time()
-        self._logger.info(f"Generate full program in {end - begin}s.")
+        self._logger.info(f"Generate full program in {end - begin:.2f}s.")
 
         # step6: complete forward sub programs
         begin = time.time()

@@ -1995,7 +2159,7 @@ class RuleBasedTuner:
             self.complete_sub_fwd_programs(process_mesh)
         end = time.time()
         self._logger.info(
-            f"Complete all sub forward programs in {end - begin}s."
+            f"Complete all sub forward programs in {end - begin:.2f}s."
         )
 
         if self.mode == "train":

@@ -2004,7 +2168,9 @@ class RuleBasedTuner:
             self.complete_sub_bwd_programs()
             end = time.time()
             self._logger.info(
-                f"Complete all sub backward programs in {end - begin}s."
+                "Complete all sub backward programs in {:.2f}s.".format(
+                    end - begin
+                )
             )
 
             # step8: complete update sub programs

@@ -2015,6 +2181,88 @@ class RuleBasedTuner:
                 f"Complete all sub update programs in {end - begin}s."
             )
 
+    def layer_placement_pass(self, stages, layers, device_meshes):
+        """Get the best cost and the corresponding strategy of the given layers on the stages which running on the devices."""
+        stage_layer_cost = [
+            [sys.maxsize for i in range(layers)] for j in range(stages)
+        ]
+        # To get the balance among the stages, we select the minimum maximum cost of stages.
+        min_max_stage_costs = [
+            [None for i in range(layers)] for j in range(stages)
+        ]
+        best_strategies = [[None for i in range(layers)] for j in range(stages)]
+        for s in range(len(device_meshes)):
+            for i in range(0, layers):
+                if s == 0:
+                    stage_layer_cost[s][i] = self.local_stage_pass(
+                        0, i, device_meshes[s]
+                    )
+                    min_max_stage_costs[s][i] = stage_layer_cost[s][i]
+                    key = self.convert_device_mesh_to_key(device_meshes[s])
+                    best_strategies[s][i] = self.stage_best_cost_of_dm[0][i][
+                        key
+                    ]["dist_context"]
+                else:
+                    min_cost = sys.maxsize
+                    min_max_stage_cost = sys.maxsize
+                    for j in range(0, i):
+                        key = self.convert_device_mesh_to_key(device_meshes[s])
+                        local_stage_cost = self.local_stage_pass(
+                            j + 1, i, device_meshes[s]
+                        )
+                        dist_context = self.combine_dist_contexts(
+                            [
+                                best_strategies[s - 1][j],
+                                self.stage_best_cost_of_dm[j + 1][i][key][
+                                    "dist_context"
+                                ],
+                            ]
+                        )
+                        cost, _ = self._get_sub_program_cost(dist_context)
+                        max_stage_cost = (
+                            min_max_stage_costs[s - 1][j]
+                            if local_stage_cost < min_max_stage_costs[s - 1][j]
+                            else local_stage_cost
+                        )
+
+                        if cost <= min_cost:
+                            if cost == min_cost:
+                                if max_stage_cost < min_max_stage_cost:
+                                    min_max_stage_cost = max_stage_cost
+                                    best_strategies[s][i] = dist_context
+                                else:
+                                    break
+                            else:
+                                best_strategies[s][i] = dist_context
+                            min_cost = cost
+                    stage_layer_cost[s][i] = min_cost
+                    min_max_stage_costs[s][i] = min_max_stage_cost
+
+        return (
+            stage_layer_cost[stages - 1][layers - 1],
+            best_strategies[stages - 1][layers - 1],
+        )
+
+    def tune_o2(self):
+        """The o2 level tuning."""
+        best_dist_context = None
+        best_cost = sys.maxsize
+        for device_meshes in self.device_meshes_list:
+            cost, dist_context = self.layer_placement_pass(
+                len(device_meshes), len(self.layers), device_meshes
+            )
+            if cost <= best_cost:
+                self._logger.info(
+                    "O2 level: a better strategy has be found as follows: "
+                )
+                print_program_with_dist_attr(
+                    self.full_main_program, best_dist_context
+                )
+                best_cost = cost
+                best_dist_context = dist_context
+
+        return best_dist_context
+
     def tune_o1(self):
         """The o1 level tuning."""
         best_cost = sys.maxsize

@@ -2082,7 +2330,7 @@ class RuleBasedTuner:
                 )
                 self._logger.info(
-                    "Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages.".format(
+                    "Cost Model: The max memory is {:.2f}GB and cost is {:.2f} when {} parallelism under process mesh shape {} on {} stages.".format(
                         memory / (1024**3),
                         cost,
                         parallelism,

@@ -2090,8 +2338,8 @@ class RuleBasedTuner:
                         len(device_meshes),
                     )
                 )
-                # 15% buffer is reserved for memory cost
-                if memory > 0.85 * self.cluster.machines[0].devices[
+                # 10% buffer is reserved safely for memory cost
+                if memory > 0.9 * self.cluster.machines[0].devices[
                     0
                 ].memory * (1024**3):
                     cost = sys.maxsize

@@ -2100,7 +2348,7 @@ class RuleBasedTuner:
                     best_cost = cost
                     best_dist_context = dist_context_of_device_meshes
                     self._logger.info(
-                        "O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB.".format(
+                        "O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {:.2f}GB.".format(
                             parallelism,
                             process_mesh_shape,
                             len(device_meshes),

@@ -2110,9 +2358,6 @@ class RuleBasedTuner:
 
         return best_dist_context
 
-    def tune_o2(self):
-        return None
-
     def save_strategy(self, best_dist_context, path):
         dist_attrs = {"tensor": {}, "op": {}, "process_meshes": []}
         for key in best_dist_context._dist_tensors_for_program:

@@ -2151,9 +2396,14 @@ class RuleBasedTuner:
         begin = time.time()
         self.match_program(self._dist_context.serial_main_program)
         end = time.time()
-        self._logger.info(f"Pattern match in {end - begin}s.")
+        self._logger.info(f"Pattern match in {end - begin:.2f}s.")
 
         if self._use_dp:
+            total_rank = (
+                self._cluster.get_num_machines()
+                * self._cluster._num_devices_per_machine
+            )
+            get_world_process_group().add_ranks(list(range(total_rank)))
             completer = Completer(self._dist_context)
             completer.complete_forward_annotation()
             print_program_with_dist_attr(

@@ -2213,7 +2463,7 @@ class RuleBasedTuner:
         self._dist_context._process_meshes = best_dist_context._process_meshes
         end = time.time()
-        self._logger.info(f"Rule-based tuner end in {end - begin}s.")
+        self._logger.info(f"Rule-based tuner end in {end - begin:.2f}s.")
         self._logger.info("The best strategy found is as follows: ")
         print_program_with_dist_attr(self.full_main_program, best_dist_context)
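Note: the heart of the new o2 level is the stage/layer dynamic program in layer_placement_pass above. Each pipeline stage runs a contiguous block of clustered layers on one device mesh, _local_stage_pass keeps a beam of the two best per-stage plans, and ties on total cost are broken by the smaller maximum per-stage cost. The sketch below strips that down to the bare recurrence with a stand-in cost function; the real tuner re-estimates each combined plan with its cost model rather than summing per-stage costs.

# Bare-bones sketch (simplifying assumptions noted above) of the pipeline
# placement recurrence: best[s][i] = best cost of running layers 0..i on
# stages 0..s, choosing a split point j so that stage s gets layers j+1..i.
# stage_cost() stands in for local_stage_pass / the cost model.
import sys


def stage_cost(first_layer, last_layer, stage_idx):
    # Placeholder cost: proportional to the number of layers on the stage.
    return float(last_layer - first_layer + 1)


def place_layers(num_stages, num_layers):
    best = [[sys.maxsize] * num_layers for _ in range(num_stages)]
    split = [[None] * num_layers for _ in range(num_stages)]
    for s in range(num_stages):
        for i in range(num_layers):
            if s == 0:
                best[s][i] = stage_cost(0, i, 0)
            else:
                for j in range(i):
                    cand = best[s - 1][j] + stage_cost(j + 1, i, s)
                    if cand < best[s][i]:
                        best[s][i] = cand
                        split[s][i] = j
    return best[num_stages - 1][num_layers - 1], split


if __name__ == "__main__":
    total_cost, split_points = place_layers(num_stages=2, num_layers=6)
    print(total_cost, split_points[1][5])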
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt

@@ -85,6 +85,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   set_tests_properties(test_pass_base_list PROPERTIES TIMEOUT 20)
   py_test_modules(test_fuse_adamw_pass MODULES test_fuse_adamw_pass)
   set_tests_properties(test_fuse_adamw_pass PROPERTIES TIMEOUT 20)
+  py_test_modules(test_rule_based_tuner_o2 MODULES test_rule_based_tuner_o2)
+  set_tests_properties(test_rule_based_tuner_o2 PROPERTIES TIMEOUT 50)
   # End of unittests WITH single card and timeout
 
   # NOTE(zyl): unittests WITH single card and WITHOUT timeout
python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_full.py

@@ -153,7 +153,7 @@ class TestParallelTunerFull(unittest.TestCase):
         cluster = Cluster()
         cluster.gen_default_config_cluster(node_count=1, device_count=8)
         strategy = Strategy()
-        strategy.auto_mode = "full"
+        strategy.auto_mode = "full_random"
         dist_context = DistributedContext(
             train_program,
             start_program,
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner_o2.py  (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import numpy as np

import paddle
from paddle import static

sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
    GPTForPretraining,
    GPTModel,
    GPTPretrainingCriterion,
)


def get_gpt_model(
    train_program, start_program, place, batch_size, sequence_len, vocab_size
):
    with static.program_guard(train_program, start_program):
        tokens = paddle.static.data(
            name="tokens", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = paddle.static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        attention_mask = paddle.static.data(
            name="attention_mask",
            shape=[batch_size, 1, sequence_len, sequence_len],
            dtype='float32',
        )
        labels = paddle.static.data(
            name="labels", shape=[batch_size, sequence_len], dtype='int64'
        )
        loss_mask = paddle.static.data(
            name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
        )

        gpt = GPTModel(
            vocab_size=1000,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=8,
            intermediate_size=256,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=1024,
            type_vocab_size=1,
            initializer_range=0.02,
            pad_token_id=0,
            eos_token_id=7,
            bos_token_id=0,
            eol_token_id=3,
        )

        model = GPTForPretraining(
            gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
        )
        preds = model(tokens, position_ids, attention_mask)
        criterion = GPTPretrainingCriterion()
        loss = criterion(preds, labels, loss_mask)

    def gen_data():
        np.random.seed(2021)
        tokens = []
        position_ids = []
        attention_mask = []
        labels = []
        loss_mask = []
        for _ in range(batch_size):
            tokens.append(np.random.randint(vocab_size, size=sequence_len))
            position_ids.append(np.arange(sequence_len))
            attention_mask.append([np.tril(np.ones(sequence_len))])
            labels.append(np.random.randint(vocab_size, size=sequence_len))
            loss_mask.append(np.ones(sequence_len))

        return tokens, position_ids, attention_mask, labels, loss_mask

    return train_program, start_program, loss, gen_data


class TestRuleBasedTuner(unittest.TestCase):
    def test_gpt_o2(self):
        modeling.init_global()
        train_program = static.Program()
        start_program = static.Program()
        batch_size = 8
        sequence_len = 512
        vocab_size = 1000
        place = None
        train_program, start_program, loss, gen_data = get_gpt_model(
            train_program,
            start_program,
            place,
            batch_size,
            sequence_len,
            vocab_size,
        )
        from paddle.distributed.auto_parallel.cluster import Cluster
        from paddle.distributed.auto_parallel.dist_context import (
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            RuleBasedTuner,
        )

        clip = paddle.nn.ClipGradByGlobalNorm(0.2)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        cluster = Cluster()
        cluster.gen_default_config_cluster(node_count=1, device_count=8)
        dist_context = DistributedContext(
            serial_main_prog=train_program,
            serial_startup_prog=start_program,
            serial_optimizer=opt,
            serial_loss=loss,
            cluster=cluster,
        )
        dist_context.initialize()
        tuner = RuleBasedTuner(dist_context, level="o2")
        tuner.tune()


if __name__ == "__main__":
    unittest.main()