Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d6011cb6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d6011cb6
编写于
3月 28, 2023
作者:
C
caozhou
提交者:
GitHub
3月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Add o1 level tune (#52041)
* add tune o1 level * add unittest
上级
418b983c
变更
8
展开全部
显示空白变更内容
内联
并排
Showing
8 changed file
with
820 addition
and
721 deletion
+820
-721
python/paddle/distributed/auto_parallel/cluster.py
python/paddle/distributed/auto_parallel/cluster.py
+76
-14
python/paddle/distributed/auto_parallel/cost/base_cost.py
python/paddle/distributed/auto_parallel/cost/base_cost.py
+38
-25
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
+1
-1
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
+0
-640
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
...on/paddle/distributed/auto_parallel/cost/estimate_cost.py
+82
-32
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+2
-0
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
...addle/distributed/auto_parallel/tuner/rule_based_tuner.py
+614
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
...id/tests/unittests/auto_parallel/test_rule_based_tuner.py
+7
-9
未找到文件。
python/paddle/distributed/auto_parallel/cluster.py
浏览文件 @
d6011cb6
...
...
@@ -14,6 +14,7 @@
import
json
import
os
import
re
from
enum
import
IntEnum
,
unique
import
paddle
...
...
@@ -449,7 +450,6 @@ class Cluster:
npu_models
=
[
"NPU"
]
dcu_models
=
[
"DCU"
]
all_gpu_models
=
gpu_models
+
xpu_models
+
npu_models
+
dcu_models
assert
gpu_model
in
all_gpu_models
self
.
_num_devices_per_machine
=
device_count
def
_convert_to_type
(
gpu_model
):
...
...
@@ -462,6 +462,8 @@ class Cluster:
type
=
"NPU"
elif
gpu_model
in
dcu_models
:
type
=
"DCU"
else
:
type
=
"GPU"
assert
type
is
not
None
return
type
...
...
@@ -470,6 +472,12 @@ class Cluster:
model
=
None
if
gpu_model
==
"V100"
:
model
=
"Tesla V100-SXM2-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A100"
:
model
=
"Tesla A100-SXM-"
+
str
(
gpu_memory
)
+
"GB"
elif
gpu_model
==
"A30"
:
model
=
"Tesla A30-SXM-"
+
str
(
gpu_memory
)
+
"GB"
else
:
model
=
gpu_model
+
str
(
gpu_memory
)
+
"GB"
assert
model
is
not
None
return
model
...
...
@@ -527,6 +535,8 @@ class Cluster:
device
[
"memory"
]
=
memory
device
[
"sp_gflops"
]
=
sp_gflops
device
[
"dp_gflops"
]
=
dp_gflops
# hard code
device
[
"type"
]
=
"GPU"
global_id_to_device_type
[
global_id
]
=
type
global_id_to_node
[
global_id
]
=
i
devices
.
append
(
device
)
...
...
@@ -820,13 +830,45 @@ class Cluster:
return
self
.
__str__
()
def
get_default_cluster
():
def
get_default_cluster
(
json_config
=
None
):
def
is_by_json_config
(
json_config
):
if
not
json_config
:
return
False
if
"cluster"
not
in
json_config
:
return
False
else
:
if
"path"
not
in
json_config
[
"cluster"
]:
if
"num_nodes"
not
in
json_config
[
"cluster"
]:
return
False
if
"num_gpus"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_model"
not
in
json_config
[
"cluster"
]:
return
False
if
"gpu_memory"
not
in
json_config
[
"cluster"
]:
return
False
return
True
else
:
return
True
cluster
=
Cluster
()
if
json_config
and
is_by_json_config
(
json_config
):
# Get GPU info by json config
if
"path"
in
json_config
[
"cluster"
]:
cluster
.
build_from_file
(
json_config
[
"cluster"
][
"path"
])
return
cluster
else
:
node_count
=
json_config
[
"cluster"
][
"num_nodes"
]
local_device_count
=
json_config
[
"cluster"
][
"num_gpus"
]
gpu_model
=
json_config
[
"cluster"
][
"gpu_model"
]
memory
=
json_config
[
"cluster"
][
"gpu_memory"
]
else
:
# Get GPU info by get_device_properties
local_device_count
=
os
.
getenv
(
"PADDLE_LOCAL_SIZE"
)
if
local_device_count
is
None
:
local_device_count
=
1
else
:
local_device_count
=
int
(
local_device_count
)
global_device_count
=
os
.
getenv
(
"PADDLE_GLOBAL_SIZE"
)
if
global_device_count
is
None
:
node_count
=
1
...
...
@@ -834,16 +876,36 @@ def get_default_cluster():
global_device_count
=
int
(
global_device_count
)
assert
global_device_count
%
local_device_count
==
0
node_count
=
int
(
global_device_count
)
//
local_device_count
gpu_info
=
paddle
.
device
.
cuda
.
get_device_properties
()
assert
gpu_info
,
"Auto parallel just runs on gpu now."
gpu_name
=
gpu_info
.
name
try
:
re_result
=
re
.
split
(
r
'[ , -]'
,
gpu_name
)
gpu_model
=
re_result
[
1
]
memory
=
int
(
re_result
[
-
1
][:
-
2
])
except
:
memory
=
int
(
gpu_info
.
total_memory
)
//
(
1000
**
3
)
gpu_model
=
gpu_name
print
(
"Node Count: "
,
node_count
,
"Local Device Size: "
,
local_device_count
,
"GPU Model: "
,
gpu_model
,
"GPU Memory: "
,
memory
,
"World size: "
,
paddle
.
distributed
.
get_world_size
(),
flush
=
True
,
)
cluster
.
gen_default_config_cluster
(
node_count
=
node_count
,
device_count
=
local_device_count
node_count
=
node_count
,
device_count
=
local_device_count
,
gpu_model
=
gpu_model
,
gpu_memory
=
memory
,
)
return
cluster
python/paddle/distributed/auto_parallel/cost/base_cost.py
浏览文件 @
d6011cb6
...
...
@@ -16,6 +16,7 @@ from collections import OrderedDict
from
functools
import
reduce
import
paddle
from
paddle.utils.flops
import
flops
from
..cluster
import
LinkType
from
..dist_tensor
import
DistributedTensor
...
...
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
output_desc
=
OrderedDict
()
# Get partitioned shape of input
input_var_desc
=
{}
for
input_name
in
op
.
input_names
:
var_name_list
=
op
.
input
(
input_name
)
var_desc
=
[]
input_var_desc
[
input_name
]
=
[]
for
var_name
in
var_name_list
:
var
=
get_var_with_recursion
(
var_name
,
op
.
block
,
op
.
block
.
program
...
...
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process
,
shard_sizes
,
)
var_desc
.
append
((
var
.
dtype
,
shape
)
)
input_var_desc
[
input_name
].
append
(
shape
)
# For special op such as embedding and its grad op
if
(
...
...
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
relative_idx
=
relative_idx
*
per_part_size
desc
[
"attrs"
][
"start_index"
]
=
relative_idx
input_desc
[
input_name
]
=
var_desc
desc
[
"inputs"
]
=
input_desc
desc
[
"inputs"
]
=
input_var_desc
for
out_name
in
op
.
output_names
:
var_name_list
=
op
.
output
(
out_name
)
...
...
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
return
desc
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
):
def
build_comm_costs_from_descs
(
op_cost_class
,
ctx
,
processes
,
descs
,
cluster
,
is_dp
=
False
):
"""Build comm costs by descriptions"""
comm_context
=
CommContext
(
cluster
)
group_ranks_list
=
[]
...
...
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
comm_op_cost
=
op_cost_class
(
op_desc
=
desc
,
comm_context
=
comm_context
)
if
is_dp
:
comm_op_cost
.
cost
.
time
*=
0.9
comm_op_cost_list
.
append
(
comm_op_cost
)
return
comm_op_cost_list
...
...
@@ -389,6 +394,7 @@ def build_dp_costs(
vars
=
dist_op
.
serial_op
.
block
.
vars
var_name
=
var_names
[
0
]
has_found
=
False
is_input
=
True
for
name
in
dist_op
.
serial_op
.
input_arg_names
:
if
var_name
in
name
:
var_name
=
name
...
...
@@ -400,6 +406,7 @@ def build_dp_costs(
if
var_name
in
name
:
var_name
=
name
has_found
=
True
is_input
=
False
break
if
not
has_found
:
return
...
...
@@ -418,6 +425,7 @@ def build_dp_costs(
processes
,
c_allreduce_sum_descs
,
cluster
,
is_dp
=
True
,
)
result
.
append
(
comm_cost_list
)
...
...
@@ -431,22 +439,11 @@ def build_dp_costs(
desc
=
{}
desc
[
"op"
]
=
op_type
desc
[
"inputs"
]
=
{}
if
var_name
in
dist_attr
.
inputs_dist_attrs
:
dims_mapping
=
dist_attr
.
get_input_dims_mapping
(
var_name
)
elif
var_name
in
dist_attr
.
outputs_dist_attrs
:
dims_mapping
=
dist_attr
.
get_output_dims_mapping
(
var_name
)
else
:
raise
AssertionError
(
"cannot find dims_mapping for {} in {}"
.
format
(
var_name
,
dist_attr
)
dims_mapping
=
(
dist_attr
.
get_input_dims_mapping
(
var_name
)
if
is_input
else
dist_attr
.
get_output_dims_mapping
(
var_name
)
)
# dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name)
# if dist_attr.get_input_dims_mapping(var_name) is not None
# else dist_attr.get_output_dims_mapping(var_name)
# )
var
=
get_var_with_recursion
(
var_name
,
dist_op
.
serial_op
.
block
,
...
...
@@ -493,8 +490,6 @@ class CommContext:
# if cluster has no info about those vars, it will be set by default
self
.
base_ring
=
None
self
.
base_tree
=
None
# self.base_inter_ring = None
# self.base_inter_tree = None
self
.
intra_ring
=
None
self
.
intra_tree
=
None
self
.
inter_ring
=
None
...
...
@@ -508,8 +503,6 @@ class CommContext:
# set default
self
.
base_ring
=
8.4
self
.
base_tree
=
0.0
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVL in default
self
.
intra_ring
=
3.4
self
.
intra_tree
=
28
...
...
@@ -681,6 +674,8 @@ class Cost:
class
OpCost
:
OP_TYPE
=
"op"
def
__init__
(
self
,
op
=
None
,
op_desc
=
None
):
self
.
_op
=
op
self
.
_op_desc
=
op_desc
...
...
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
)
)
def
calc_flops
(
self
):
if
not
self
.
op_desc
:
return
0
if
"_grad"
in
self
.
__class__
.
OP_TYPE
:
op_type
=
self
.
__class__
.
OP_TYPE
[:
len
(
self
.
__class__
.
OP_TYPE
)
-
5
]
return
2
*
flops
(
op_type
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
]
)
return
flops
(
self
.
__class__
.
OP_TYPE
,
self
.
op_desc
[
"inputs"
],
self
.
op_desc
[
"attrs"
],
)
def
calc_time
(
self
):
flops_count
=
self
.
calc_flops
()
return
flops_count
*
2.9e-7
def
register_op_cost
(
cls
):
op_type
=
cls
.
OP_TYPE
...
...
python/paddle/distributed/auto_parallel/cost/comm_op_cost.py
浏览文件 @
d6011cb6
...
...
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
super
().
__init__
(
op
=
op
,
op_desc
=
op_desc
,
comm_context
=
comm_context
)
def
calc_time
(
self
):
return
0
return
self
.
comm_count
*
1
/
(
144
*
1e3
)
@
register_op_cost
...
...
python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
浏览文件 @
d6011cb6
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/cost/estimate_cost.py
浏览文件 @
d6011cb6
...
...
@@ -189,6 +189,9 @@ class CostEstimator:
# Calc dist op cost
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
op_dist_attr
=
dist_op
.
dist_attr
processes
=
op_dist_attr
.
process_mesh
.
process_ids
...
...
@@ -225,6 +228,8 @@ class CostEstimator:
for
rank
in
group_ranks
:
self
.
local_cost
(
rank
).
time
=
(
max_time
+
comm_op_cost
.
time
if
op
.
attr
(
'op_role'
)
!=
OpRole
.
Backward
else
max_time
+
0.9
*
comm_op_cost
.
time
)
if
rank
not
in
self
.
_bubble_time_mapping
:
self
.
_bubble_time_mapping
[
rank
]
=
0
...
...
@@ -290,6 +295,7 @@ class CostEstimator:
self
.
_ordered_ops
.
append
([
op
.
desc
.
id
(),
op
])
self
.
_ordered_ops
.
sort
(
key
=
lambda
x
:
x
[
0
])
parameters
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
"create_py_reader"
,
...
...
@@ -298,11 +304,14 @@ class CostEstimator:
]:
continue
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
var_name
)
if
var_name
not
in
var_info
:
var_info
[
var_name
]
=
{}
key
=
_convert_pm_and_dm_to_str
(
...
...
@@ -311,6 +320,10 @@ class CostEstimator:
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
# It is even partition now
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_input
(
var_name
)
global_sizes
=
var
.
shape
...
...
@@ -324,9 +337,16 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
var
.
persistable
:
name
=
var_name
+
key
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
...
@@ -339,6 +359,10 @@ class CostEstimator:
)
if
key
not
in
var_info
[
var_name
]:
var_info
[
var_name
][
key
]
=
{}
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
"memory"
not
in
var_info
[
var_name
][
key
]:
var
=
dist_op
.
get_serial_output
(
var_name
)
global_sizes
=
var
.
shape
...
...
@@ -352,11 +376,19 @@ class CostEstimator:
var_info
[
var_name
][
key
][
"memory"
]
=
self
.
_calculate_bytes
(
sizes
,
dtype
)
if
"position"
not
in
var_info
[
var_name
][
key
]:
var_info
[
var_name
][
key
][
"position"
]
=
[]
var_info
[
var_name
][
key
][
"position"
].
append
(
op_id
)
if
var
.
persistable
:
name
=
var_name
+
key
if
name
not
in
parameters
:
parameters
.
add
(
name
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
has_used_vars
=
set
()
not_calc_vars
=
set
()
for
op_id
,
op
in
self
.
_ordered_ops
:
if
op
.
type
in
[
"create_py_reader"
,
...
...
@@ -367,6 +399,8 @@ class CostEstimator:
can_free_memories
=
{}
can_free_vars
=
set
()
dist_op
=
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_op
:
continue
process_mesh
=
dist_op
.
dist_attr
.
process_mesh
for
var_name
in
op
.
input_arg_names
:
input_dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
...
...
@@ -378,24 +412,30 @@ class CostEstimator:
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_input
(
var_name
)
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
if
has_used_var
in
not_calc_vars
:
continue
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
has_used_var
not
in
can_free_vars
:
if
(
has_used_var
not
in
can_free_vars
and
not
var
.
persistable
):
can_free_vars
.
add
(
has_used_var
)
if
not
var
.
persistable
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
=
0
can_free_memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
can_free_memories
[
process
]
+=
var_info
[
var_name
]
[
key
][
"memory"
]
for
var_name
in
op
.
output_arg_names
:
output_dims_mapping
=
dist_op
.
dist_attr
.
get_output_dims_mapping
(
...
...
@@ -406,25 +446,36 @@ class CostEstimator:
)
has_used_var
=
var_name
+
key
var
=
dist_op
.
get_serial_output
(
var_name
)
if
(
op
.
type
==
"reshape2"
or
op
.
type
==
"transpose2"
or
op
.
type
==
"elementwise_add"
):
not_calc_vars
.
add
(
has_used_var
)
continue
# Not used
if
var_name
+
key
not
in
has_used_vars
:
if
(
has_used_var
not
in
has_used_vars
and
has_used_var
not
in
parameters
):
has_used_vars
.
add
(
has_used_var
)
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
memories
:
memories
[
process
]
=
0
memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
# Used
else
:
if
op_id
==
var_info
[
var_name
][
key
][
"position"
][
-
1
]:
if
has_used_var
not
in
can_free_vars
:
if
(
has_used_var
not
in
can_free_vars
and
not
var
.
persistable
):
can_free_vars
.
add
(
has_used_var
)
if
not
var
.
persistable
:
for
process
in
process_mesh
.
process_ids
:
if
process
not
in
can_free_memories
:
can_free_memories
[
process
]
=
0
can_free_memories
[
process
]
+=
var_info
[
var_name
][
key
][
"memory"
]
can_free_memories
[
process
]
+=
var_info
[
var_name
]
[
key
][
"memory"
]
# Calc peak memory
for
process
in
memories
:
...
...
@@ -433,7 +484,6 @@ class CostEstimator:
else
:
if
memories
[
process
]
>
self
.
max_memories
[
process
]:
self
.
max_memories
[
process
]
=
memories
[
process
]
# Free memory
for
process
in
can_free_memories
:
if
process
in
memories
:
...
...
@@ -513,7 +563,7 @@ class CostEstimator:
# Padding automatically
max_len
=
0
header
=
[
"Execution Time(
m
s)"
,
"Max Memory(MiB)"
]
header
=
[
"Execution Time(
u
s)"
,
"Max Memory(MiB)"
]
vals
=
[
round
(
self
.
global_cost
.
time
,
3
),
int
(
self
.
max_memory
//
1e6
)]
for
memory
in
vals
+
header
:
if
len
(
str
(
memory
))
>
max_len
:
...
...
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
d6011cb6
...
...
@@ -2716,6 +2716,8 @@ class Resharder:
)
# simplified processing: ignore union process mesh and output reshard
dist_op
=
self
.
dist_context
.
get_dist_op_for_program
(
op
)
if
not
dist_tensor
or
not
dist_op
:
return
reshard_op_cost
dims_mapping
=
dist_op
.
dist_attr
.
get_input_dims_mapping
(
tensor
.
name
)
...
...
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
浏览文件 @
d6011cb6
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
浏览文件 @
d6011cb6
...
...
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
modeling
.
init_global
()
train_program
=
static
.
Program
()
start_program
=
static
.
Program
()
place
=
paddle
.
set_device
(
"gpu"
)
batch_size
=
8
sequence_len
=
512
vocab_size
=
1000
place
=
None
train_program
,
start_program
,
loss
,
gen_data
=
get_gpt_model
(
train_program
,
start_program
,
...
...
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len
,
vocab_size
,
)
from
paddle.distributed.auto_parallel.cluster
import
Cluster
from
paddle.distributed.auto_parallel.dist_context
import
(
DistributedContext
,
)
from
paddle.distributed.auto_parallel.process_mesh
import
ProcessMesh
from
paddle.distributed.auto_parallel.tuner.rule_based_tuner
import
(
RuleBasedTuner
,
)
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
0.2
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
cluster
=
Cluster
()
cluster
.
gen_default_config_cluster
(
node_count
=
1
,
device_count
=
8
)
dist_context
=
DistributedContext
(
serial_main_prog
=
train_program
,
serial_startup_prog
=
start_program
,
serial_optimizer
=
opt
,
serial_loss
=
loss
,
cluster
=
cluster
,
)
dist_context
.
initialize
()
tuner
=
RuleBasedTuner
(
dist_context
)
tuner
.
cluster_operators
()
tuner
.
gen_full_program
()
tuner
.
match_program
(
tuner
.
_dist_context
.
serial_main_program
)
process_mesh
=
ProcessMesh
([
0
,
1
])
tuner
.
gen_fwd_sub_programs_by_clone
()
tuner
.
complete_sub_fwd_programs
(
process_mesh
)
tuner
.
complete_sub_bwd_programs
()
tuner
.
tune
()
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录