PaddlePaddle / Paddle

Commit f2226441
Authored by caozhou on Jan 27, 2022; committed via GitHub on Jan 27, 2022.

【Auto Parallel】Update Planner (#39201)

* update planner
* update unitest
* update dist matmul
* update auto converter

Parent: 2b9bb8bb
Showing 7 changed files with 260 additions and 41 deletions (+260 -41).
python/paddle/distributed/auto_parallel/operators/dist_matmul.py  +16 -4
python/paddle/distributed/auto_parallel/planner.py  +24 -35
python/paddle/distributed/auto_parallel/utils.py  +7 -2
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt  +2 -0
python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_gpt_planner.py  +148 -0
python/paddle/fluid/tests/unittests/auto_parallel/test_relaunch_with_gpt_planner.py  +62 -0
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py  +1 -0
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...
@@ -165,6 +165,18 @@ def _is_auto_compatible_for_matmul(dist_op):
    if y_dims_mapping_len == 1:
        y_dims_mapping.insert(1, -1)

    # NOTE: Partition is not supported if matmul op has trans.
    if op_desc.type() == "matmul_v2":
        if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
            if x_dims_mapping[-2:] != [-1, -1] or y_dims_mapping[-2:] != [-1, -1]:
                return False
    elif op_desc.type() == "matmul":
        if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'):
            if x_dims_mapping[-2:] != [-1, -1] or y_dims_mapping[-2:] != [-1, -1]:
                return False

    # Deal with dim > 2 and take care of broadcasting
    if out_dims_mapping_len > 2:
        broadcast_x_dims_mapping = []
...
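The added check rejects any plan that shards the trailing (matrix) dimensions when the matmul carries a transpose attribute. A minimal standalone sketch of that rule, using plain lists for the dims_mapping values (the helper name and inputs below are illustrative, not Paddle API):

def trailing_dims_replicated(x_dims_mapping, y_dims_mapping):
    # A dims_mapping entry of -1 means "replicated"; any other value names a mesh axis.
    # With trans_x / trans_y (or transpose_X / transpose_Y) set, both trailing
    # matrix dimensions must stay replicated for the partition to be accepted.
    return (x_dims_mapping[-2:] == [-1, -1] and
            y_dims_mapping[-2:] == [-1, -1])

# Example: x sharded on mesh axis 0 along its last dim -> partition is rejected.
print(trailing_dims_replicated([-1, -1, 0], [-1, -1, -1]))   # False
print(trailing_dims_replicated([-1, -1, -1], [-1, -1, -1]))  # True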
@@ -550,7 +562,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         # TODO infer logic comm presentation
-        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[1]
+        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[-1]
         assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_col_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
...
@@ -775,7 +787,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         # TODO infer logic comm presentation
-        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[0]
+        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[-2]
         assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_row_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
...
@@ -1064,7 +1076,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         # TODO infer logic comm presentation
-        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[1]
+        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[-1]
         assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_col_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
...
@@ -1283,7 +1295,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         # TODO infer logic comm presentation
-        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[0]
+        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(Weight_var.name)[-2]
         assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
             matmul_row_dim_mapping)
         process_mesh_shape = op_dist_attr.process_mesh.topology
...
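All four matmul implementations switch from fixed positive indices ([1] for the column axis, [0] for the row axis) to negative indices ([-1] and [-2]) when reading the weight's dims_mapping, so the matrix axes are still addressed correctly even if a mapping carries extra leading entries. A small illustration with plain lists (the values are hypothetical, not Paddle objects; whether leading entries actually occur depends on the surrounding code):

# dims_mapping for a 2-D weight: [row_axis, col_axis]; -1 means replicated.
w2d = [0, 1]
# With a leading entry, positive indices would point at the wrong position,
# while negative indices keep addressing the trailing matrix dimensions.
w3d = [-1, 0, 1]

for mapping in (w2d, w3d):
    col_axis = mapping[-1]   # was mapping[1] before this commit
    row_axis = mapping[-2]   # was mapping[0] before this commit
    print(row_axis, col_axis)  # prints "0 1" in both cases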
python/paddle/distributed/auto_parallel/planner.py
...
@@ -84,34 +84,22 @@ class PlanFilter:
    @staticmethod
    def check_dims_mapping_for_special_op(op, op_dist_attr, vars):
        if op.type == "layer_norm":
            bias_dims_mapping = op_dist_attr.get_input_dims_mapping(op.input("Bias")[0])
            scale_dims_mapping = op_dist_attr.get_input_dims_mapping(op.input("Scale")[0])
            x_dims_mapping = op_dist_attr.get_input_dims_mapping(op.input("X")[0])
            mean_dims_mapping = op_dist_attr.get_output_dims_mapping(op.output("Mean")[0])
            variance_dims_mapping = op_dist_attr.get_output_dims_mapping(op.output("Variance")[0])
            y_dims_mapping = op_dist_attr.get_output_dims_mapping(op.output("Y")[0])
            if x_dims_mapping != y_dims_mapping:
                return False
            if scale_dims_mapping[0] != x_dims_mapping[-1]:
                return False
            if bias_dims_mapping[0] != y_dims_mapping[-1]:
                return False
            if mean_dims_mapping[0] != x_dims_mapping[0]:
                return False
            if variance_dims_mapping[0] != x_dims_mapping[0]:
                return False

        # NOTE: Those ops has some partition limits, and will be solved when corresponding dist op implemented in the future.
        if op.type == "elementwise_add" or op.type == 'layer_norm' or op.type == "softmax_with_cross_entropy":
            for name in op.input_arg_names:
                for item in op_dist_attr.get_input_dims_mapping(name):
                    if item != -1:
                        return False
            for name in op.output_arg_names:
                for item in op_dist_attr.get_output_dims_mapping(name):
                    if item != -1:
                        return False

        if op.type == "lookup_table_v2":
            for name in op.input_arg_names:
                if name == 'pos_embeddings':
                    for item in op_dist_attr.get_input_dims_mapping(name):
                        if item != -1:
                            return False

        return True
...
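For the ops named in the NOTE, the filter's rule is simply that every dims_mapping entry of every input and output must be -1 (replicated). A self-contained sketch of that predicate over plain dictionaries (the helper name and arguments are illustrative, not the planner's API):

def dims_mapping_allowed(op_type, input_dims_mappings, output_dims_mappings,
                         limited_ops=("elementwise_add", "layer_norm",
                                      "softmax_with_cross_entropy")):
    # Ops without a dedicated distributed implementation may not be partitioned:
    # every entry of every dims_mapping must be -1 (replicated).
    if op_type not in limited_ops:
        return True
    mappings = list(input_dims_mappings.values()) + list(output_dims_mappings.values())
    return all(item == -1 for mapping in mappings for item in mapping)

# A sharded input (mesh axis 0 on the last dim) makes the plan invalid:
print(dims_mapping_allowed("layer_norm", {"X": [-1, -1, 0]}, {"Y": [-1, -1, 0]}))    # False
print(dims_mapping_allowed("layer_norm", {"X": [-1, -1, -1]}, {"Y": [-1, -1, -1]}))  # True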
@@ -426,13 +414,14 @@ class MCMC(SearchAlgorithm):
                        var_name) == dims_mapping:
                    dist_context.set_op_dist_attr_for_program(search_op,
                                                              op_dist_attr)
                    tensor_dist_attr = TensorDistributedAttribute()
                    tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                    tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
                        var_name)
                    dist_context.set_tensor_dist_attr_for_program(
                        vars[var_name], tensor_dist_attr)
                    for name in search_op.output_arg_names:
                        tensor_dist_attr = TensorDistributedAttribute()
                        tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                        tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
                            name)
                        dist_context.set_tensor_dist_attr_for_program(
                            vars[name], tensor_dist_attr)
                    has_changed = True
                    break
        if has_changed:
...
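The hunk shows the MCMC search stamping every output tensor of the accepted op (the loop over output_arg_names) with the op's process_mesh and output dims_mapping. A toy version with plain dictionaries standing in for the real dist-attr objects (names illustrative):

def stamp_output_tensors(op_output_names, op_output_dims_mappings, process_mesh,
                         tensor_dist_attrs):
    # Give every output tensor the process_mesh and dims_mapping recorded on the op.
    for name in op_output_names:
        tensor_dist_attrs[name] = {
            "process_mesh": process_mesh,
            "dims_mapping": op_output_dims_mappings[name],
        }
    return tensor_dist_attrs

attrs = stamp_output_tensors(["out0", "out1"],
                             {"out0": [0, -1], "out1": [-1, -1]},
                             process_mesh=[[0, 1]],
                             tensor_dist_attrs={})
print(attrs["out0"]["dims_mapping"])  # [0, -1]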
python/paddle/distributed/auto_parallel/utils.py
...
@@ -593,8 +593,10 @@ def load_parameter_into_program(param_dict, program):
        param_dict(dict): parameters' name and value.
        program(Program): the program to be updated
    """
    _check_param_dict(param_dict)
    assert isinstance(param_dict, dict)
    assert program and isinstance(program, paddle.fluid.framework.Program)
    if not param_dict:
        return
    program.set_state_dict(param_dict)
...
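The code here validates its inputs and skips the load for an empty parameter dict before calling set_state_dict. A self-contained sketch of that guard pattern, with a stub class standing in for a Paddle Program (names hypothetical):

class _FakeProgram:
    # Stand-in exposing only the one method the guard ultimately calls.
    def __init__(self):
        self.state = None

    def set_state_dict(self, d):
        self.state = dict(d)

def load_checked(param_dict, program):
    # Validate inputs before mutating the program; skip empty checkpoints.
    assert isinstance(param_dict, dict), "param_dict must be a dict"
    assert program is not None, "program must be a valid Program-like object"
    if not param_dict:
        return False  # nothing to load
    program.set_state_dict(param_dict)
    return True

prog = _FakeProgram()
print(load_checked({}, prog))                 # False: nothing to load
print(load_checked({"w": [1.0, 2.0]}, prog))  # True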
@@ -705,7 +707,6 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
        dist_param_dict(dict): parameters' value of current rank.
    """
    assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None."
    assert _check_dist_attr(cur_dist_attr), "'pre_dist_attr' cannot be None."
    assert isinstance(dist_param_dict, dict), \
        "The type of 'dist_param_dict' should be 'dict', but got {}.".format(
            str(type(dist_param_dict)))
...
@@ -720,6 +721,9 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
                "The value of 'dist_param_dict' is parameter's value of all ranks, "
                "and its type should be 'list(numpy.ndarray)'.")

    if cur_dist_attr is None:
        return {}

    param_not_in_pre = []
    param_not_in_cur = []
    logging.info("Start to merge and slice parameters.")
...
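Two pieces of bookkeeping are visible in this hunk: return an empty dict when cur_dist_attr is absent, and track parameters that exist on only one side of the conversion. A standalone sketch of that bookkeeping, with plain dicts as stand-ins for the real attribute structures (names illustrative):

def split_param_names(pre_dist_attr, cur_dist_attr):
    # Nothing to merge or slice when the current distributed attributes are absent.
    if cur_dist_attr is None:
        return [], [], []
    shared = [name for name in cur_dist_attr if name in pre_dist_attr]
    param_not_in_pre = [name for name in cur_dist_attr if name not in pre_dist_attr]
    param_not_in_cur = [name for name in pre_dist_attr if name not in cur_dist_attr]
    return shared, param_not_in_pre, param_not_in_cur

print(split_param_names({"w1": None, "w2": None}, {"w2": None, "w3": None}))
# (['w2'], ['w3'], ['w1'])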
@@ -1268,6 +1272,7 @@ def get_all_distributed_main_program(serial_program_info, dist_context,
        used_dist_context._dist_op_context = DistributedOperatorContext()
        _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
            rank_id, used_dist_context)
        # print("dist_main_program: ", dist_main_program)
        all_dist_main_program.append(dist_main_program)
    return all_dist_main_program
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...
@@ -5,4 +5,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
    set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
        TIMEOUT 120)
    py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
    set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
        TIMEOUT 120)
    py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
    set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
        TIMEOUT 240)
endif()
python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_gpt_planner.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.static as static
from paddle.distributed import fleet
import sys
import numpy as np
import paddle.distributed.auto_parallel as auto
from auto_parallel_relaunch_model import mlp_pretrain_forward
from auto_parallel_relaunch_model import batch_generator_creator
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion


def get_gpt_model(train_program, start_program, place, batch_size, sequence_len,
                  vocab_size):
    modeling.init_global()
    with static.program_guard(train_program, start_program):
        tokens = paddle.static.data(
            name="tokens", shape=[batch_size, sequence_len], dtype='int64')
        position_ids = paddle.static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64')
        attention_mask = paddle.static.data(
            name="attention_mask",
            shape=[batch_size, 1, sequence_len, sequence_len],
            dtype='float32')
        labels = paddle.static.data(
            name="labels", shape=[batch_size, sequence_len], dtype='int64')
        loss_mask = paddle.static.data(
            name="loss_mask", shape=[batch_size, sequence_len], dtype='float32')
        data_holder = [tokens, position_ids, attention_mask, labels, loss_mask]

        gpt = GPTModel(
            vocab_size=1000,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=8,
            intermediate_size=256,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=1024,
            type_vocab_size=1,
            initializer_range=0.02,
            pad_token_id=0,
            eos_token_id=7,
            bos_token_id=0,
            eol_token_id=3)

        model = GPTForPretraining(
            gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02)

        preds = model(tokens, position_ids, attention_mask)

        criterion = GPTPretrainingCriterion()

        loss = criterion(preds, labels, loss_mask)

    def gen_data():
        np.random.seed(2021)
        tokens = []
        position_ids = []
        attention_mask = []
        labels = []
        loss_mask = []
        for _ in range(batch_size):
            tokens.append(np.random.randint(vocab_size, size=sequence_len))
            position_ids.append(np.arange(sequence_len))
            attention_mask.append([np.tril(np.ones(sequence_len))])
            labels.append(np.random.randint(vocab_size, size=sequence_len))
            loss_mask.append(np.ones(sequence_len))

        return tokens, position_ids, attention_mask, labels, loss_mask

    return train_program, start_program, loss, gen_data


def train():
    dist_strategy = fleet.DistributedStrategy()
    # init parallel optimizer
    dist_strategy.auto_search = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    train_program = static.Program()
    start_program = static.Program()
    place = paddle.set_device("gpu")
    gpus = [0, 1]
    batch_size = 8
    sequence_len = 512
    vocab_size = 1000
    train_program, start_program, loss, gen_data = get_gpt_model(
        train_program, start_program, place, batch_size, sequence_len,
        vocab_size)

    optimizer = paddle.fluid.optimizer.AdamOptimizer(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None)

    optimizer = fleet.distributed_optimizer(optimizer)
    _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
        loss, start_program)

    places = static.cuda_places()
    exe = paddle.static.Executor(places[0])
    exe.run(distributed_startup_program)

    for step in range(10):
        tokens, position_ids, attention_mask, labels, loss_mask = gen_data()
        if loss.name in distributed_main_program.global_block().vars:
            loss_print, = exe.run(
                distributed_main_program,
                feed={
                    "tokens": tokens,
                    "position_ids": position_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                    "loss_mask": loss_mask
                },
                fetch_list=[loss])
            print("step: %s, loss: %f" % (step, loss_print[0]))
        else:
            exe.run(
                distributed_main_program,
                feed={
                    "tokens": tokens,
                    "position_ids": position_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                    "loss_mask": loss_mask
                })
            print("step: %s, loss: %s" % (step, "None"))


if __name__ == "__main__":
    train()
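The feed produced by gen_data matches the static placeholders declared above, including the extra broadcast dimension on the attention mask. A quick standalone sanity sketch of those shapes (plain numpy, mirroring the test's batch_size=8, sequence_len=512, vocab_size=1000):

import numpy as np

batch_size, sequence_len, vocab_size = 8, 512, 1000
tokens = [np.random.randint(vocab_size, size=sequence_len) for _ in range(batch_size)]
attention_mask = [[np.tril(np.ones(sequence_len))] for _ in range(batch_size)]

# Shapes line up with paddle.static.data declarations in the test model.
assert np.array(tokens).shape == (batch_size, sequence_len)
assert np.array(attention_mask).shape == (batch_size, 1, sequence_len, sequence_len)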
python/paddle/fluid/tests/unittests/auto_parallel/test_relaunch_with_gpt_planner.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import json
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage


class TestPlannerReLaunch(unittest.TestCase):
    def test_relaunch_with_planner(self):
        from test_auto_parallel_relaunch import cluster_json
        file_dir = os.path.dirname(os.path.abspath(__file__))
        cluster_json_path = os.path.join(file_dir,
                                         "auto_parallel_cluster.json")
        cluster_json_object = json.loads(cluster_json)
        with open(cluster_json_path, "w") as cluster_json_file:
            json.dump(cluster_json_object, cluster_json_file)

        launch_model_path = os.path.join(
            file_dir, "auto_parallel_relaunch_with_gpt_planner.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        cmd = [sys.executable, "-u"] + coverage_args + [
            "-m", "launch", "--cluster_topo_path", cluster_json_path,
            "--enable_auto_mapping", "True", launch_model_path
        ]
        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        # Remove unnecessary files
        if os.path.exists(cluster_json_path):
            os.remove(cluster_json_path)
        rank_mapping_json_path = os.path.join(file_dir,
                                              "auto_parallel_rank_mapping.json")
        if os.path.exists(rank_mapping_json_path):
            os.remove(rank_mapping_json_path)
        log_path = os.path.join(file_dir, "log")
        if os.path.exists(log_path):
            shutil.rmtree(log_path)


if __name__ == "__main__":
    unittest.main()
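The test's structure is a common one: write the inputs the child needs, spawn the launcher as a subprocess, require a zero exit code, and clean up generated files afterwards. A minimal self-contained version of that pattern, using a trivial child command in place of the real launcher invocation:

import subprocess
import sys

def run_and_check(cmd):
    # Spawn the child, wait for it, and require a clean exit, as the unittest does.
    process = subprocess.Popen(cmd)
    process.wait()
    assert process.returncode == 0, "child process failed"

# Stand-in for the real `python -u ... <launch_model_path>` command built in the test.
run_and_check([sys.executable, "-u", "-c", "print('launched ok')"])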
python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py
...
@@ -34,6 +34,7 @@ paddle.enable_static()
def init_global():
    global _global_parallel_strategy
    _global_parallel_strategy = None
    global _global_process_mesh
    global PP_MESH_LIST
    global DPPP_MESH_LIST
...
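The one-line addition makes init_global reset _global_parallel_strategy explicitly, presumably so that repeated model builds in one process do not inherit the strategy chosen by an earlier build. A tiny standalone sketch of the same reset-before-reuse idiom (simplified globals, not the real module state):

_global_parallel_strategy = None
_global_process_mesh = None

def init_global():
    # Clear module-level parallel configuration before building a new model.
    global _global_parallel_strategy, _global_process_mesh
    _global_parallel_strategy = None
    _global_process_mesh = None

def build_model(strategy):
    global _global_parallel_strategy
    init_global()
    _global_parallel_strategy = strategy

build_model("dp")
build_model(None)
print(_global_parallel_strategy)  # None, not leaked from the previous build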