Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3980e222
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3980e222
编写于
3月 23, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
3月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] engine & dist_saver (#40528)
* add dist_saver and update engine * add dist_saver and update engine
上级
ff7cbaae
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
671 addition
and
128 deletion
+671
-128
python/paddle/distributed/auto_parallel/dist_loader.py
python/paddle/distributed/auto_parallel/dist_loader.py
+26
-6
python/paddle/distributed/auto_parallel/dist_saver.py
python/paddle/distributed/auto_parallel/dist_saver.py
+241
-0
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+236
-117
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/auto_parallel/utils.py
+8
-0
python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
.../paddle/fluid/tests/unittests/auto_parallel/engine_api.py
+10
-5
python/paddle/fluid/tests/unittests/auto_parallel/engine_predict_api.py
...fluid/tests/unittests/auto_parallel/engine_predict_api.py
+122
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py
...le/fluid/tests/unittests/auto_parallel/test_engine_api.py
+28
-0
未找到文件。
python/paddle/distributed/auto_parallel/dist_loader.py
浏览文件 @
3980e222
...
...
@@ -15,6 +15,7 @@
import
abc
import
numpy
as
np
import
paddle
from
.utils
import
to_list
from
paddle.io
import
DataLoader
,
DistributedBatchSampler
...
...
@@ -51,10 +52,11 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
places
,
batch_size
=
1
,
epochs
=
1
,
steps_per_epoch
=
1000
,
steps_per_epoch
=
None
,
data_parallel_world_size
=
None
,
data_parallel_rank
=
None
,
drop_last
=
False
):
drop_last
=
False
,
inputs
=
[]):
self
.
feed_list
=
feed_list
self
.
places
=
places
self
.
steps_per_epoch
=
steps_per_epoch
...
...
@@ -62,6 +64,8 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
dataset
,
batch_size
,
epochs
,
data_parallel_world_size
,
data_parallel_rank
,
drop_last
)
self
.
_inner_dataloader
=
self
.
_create_inner_dataloader
()
self
.
_steps
=
self
.
_infer_steps
()
self
.
_inputs
=
inputs
def
__iter__
(
self
):
self
.
_cur_step
=
0
...
...
@@ -69,22 +73,38 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
return
self
def
__next__
(
self
):
if
self
.
_cur_step
<
self
.
steps_per_epoch
:
if
self
.
_cur_step
<
self
.
_steps
:
self
.
_cur_step
+=
1
else
:
self
.
_inner_dataloader
.
reset
()
raise
StopIteration
def
_infer_steps
(
self
):
if
self
.
steps_per_epoch
is
not
None
:
return
self
.
steps_per_epoch
try
:
steps_per_epoch
=
len
(
self
.
dataset
)
//
self
.
batch_size
except
:
raise
ValueError
(
"Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class."
)
return
steps_per_epoch
def
_create_inner_dataloader
(
self
):
def
data_generator
():
batch_data
=
None
for
step
,
data
in
enumerate
(
self
.
dataset
):
if
not
isinstance
(
data
,
list
):
data
=
to_list
(
data
)
if
batch_data
is
None
:
batch_data
=
[[]
for
i
in
range
(
len
(
data
))]
for
idx
,
data_item
in
enumerate
(
data
):
batch_data
[
idx
].
append
(
np
.
array
(
data_item
))
for
idx
in
range
(
len
(
data
)):
batch_data
[
idx
].
append
(
data
[
idx
])
if
(
step
+
1
)
%
self
.
batch_size
==
0
:
yield
batch_data
[
0
],
batch_data
[
1
]
yield
batch_data
batch_data
=
None
dataloader
=
paddle
.
fluid
.
io
.
DataLoader
.
from_generator
(
...
...
python/paddle/distributed/auto_parallel/dist_saver.py
0 → 100644
浏览文件 @
3980e222
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import
re
import
os
import
errno
import
pickle
import
warnings
import
logging
import
numpy
as
np
import
paddle
from
paddle
import
fluid
from
paddle.fluid
import
core
from
paddle.fluid.framework
import
static_only
from
.utils
import
get_dist_attr
from
.converter
import
Converter
from
.process_group
import
_g_process_group_map
from
..utils
import
get_logger
def
check_filename
(
re_exp
,
filename
):
if
re
.
search
(
re_exp
,
filename
):
return
True
else
:
return
False
def
_process_path
(
path
):
filename
=
os
.
path
.
basename
(
path
)
if
filename
==
""
:
raise
ValueError
(
"path should be of 'dirname/filename' format, but received filename is empty string"
)
try
:
dirname
=
os
.
path
.
dirname
(
path
)
os
.
makedirs
(
dirname
)
except
OSError
as
e
:
if
e
.
errno
!=
errno
.
EEXIST
:
raise
return
dirname
,
filename
class
DistributedSaver
:
def
__init__
(
self
):
self
.
_logger
=
get_logger
(
logging
.
INFO
)
def
save
(
self
,
path
,
serial_program
,
dist_main_program
,
dist_context
):
dirname
,
filename
=
_process_path
(
path
)
rank_id
=
paddle
.
distributed
.
get_rank
()
# save serial program when rank id is 0
if
rank_id
==
0
:
self
.
_save_rank_mapping
(
dirname
)
serial_model_filename
=
filename
+
"_serial.pdmodel"
serial_model_path
=
os
.
path
.
join
(
dirname
,
serial_model_filename
)
with
open
(
serial_model_path
,
"wb"
)
as
f
:
f
.
write
(
serial_program
.
desc
.
serialize_to_string
())
# save distributed main program
dist_model_filename
=
filename
+
"_dist"
+
str
(
rank_id
)
+
".pdmodel"
dist_model_path
=
os
.
path
.
join
(
dirname
,
dist_model_filename
)
with
open
(
dist_model_path
,
"wb"
)
as
f
:
f
.
write
(
dist_main_program
.
desc
.
serialize_to_string
())
# save distributed params
dist_param_filename
=
filename
+
"_dist"
+
str
(
rank_id
)
+
".pdparams"
dist_param_path
=
os
.
path
.
join
(
dirname
,
dist_param_filename
)
dist_param
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
dist_main_program
.
state_dict
().
items
()
}
with
open
(
dist_param_path
,
"wb"
)
as
f
:
pickle
.
dump
(
dist_param
,
f
)
# save distributed attribute
dist_attr_filename
=
filename
+
"_dist"
+
str
(
rank_id
)
+
".pdattr"
dist_attr_path
=
os
.
path
.
join
(
dirname
,
dist_attr_filename
)
dist_attrs
=
get_dist_attr
(
dist_main_program
,
dist_context
)
with
open
(
dist_attr_path
,
"wb"
)
as
f
:
pickle
.
dump
(
dist_attrs
,
f
)
# TODO:save cluster.json
def
load
(
self
,
path
,
program
,
dist_context
,
strict
=
True
,
load_optimizer
=
True
):
# TODO: if `program` is None, load `path.pdmodel`.
filename
=
os
.
path
.
basename
(
path
)
if
filename
==
""
:
raise
ValueError
(
"path should be of 'dirname/filename' format, but received filename is empty string"
)
dirname
=
os
.
path
.
dirname
(
path
)
# load path.pdparam
param_file_list
=
[]
for
param_file
in
os
.
listdir
(
dirname
):
if
check_filename
(
'{}(.*)_dist(.*).pdparams'
.
format
(
filename
),
param_file
):
param_file_list
.
append
(
os
.
path
.
join
(
dirname
,
param_file
))
param_file_list
.
sort
()
self
.
_logger
.
info
(
"Load distributed attribute file: {}"
.
format
(
param_file_list
))
param_dict
=
{}
for
param_file
in
param_file_list
:
with
open
(
param_file
,
'rb'
)
as
f
:
state_dict_info
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
for
name
,
value
in
state_dict_info
.
items
():
if
name
in
param_dict
:
param_dict
[
name
].
append
(
np
.
array
(
value
))
else
:
param_dict
[
name
]
=
[
np
.
array
(
value
)]
# load path.pdattr
dist_attr_file_list
=
[]
for
dist_attr_file
in
os
.
listdir
(
dirname
):
if
check_filename
(
'{}(.*)_dist(.*).pdattr'
.
format
(
filename
),
dist_attr_file
):
dist_attr_file_list
.
append
(
os
.
path
.
join
(
dirname
,
dist_attr_file
))
dist_attr_file_list
.
sort
()
self
.
_logger
.
info
(
"Load distributed attribute file: {}"
.
format
(
dist_attr_file_list
))
pre_dist_attr
=
{}
for
dist_attr_file
in
dist_attr_file_list
:
with
open
(
dist_attr_file
,
'rb'
)
as
f
:
dist_attr
=
pickle
.
load
(
f
,
encoding
=
'latin1'
)
for
name
,
attr
in
dist_attr
.
items
():
if
name
not
in
pre_dist_attr
:
pre_dist_attr
[
name
]
=
attr
# get current dist_attr
cur_dist_attr
=
get_dist_attr
(
program
,
dist_context
)
# param convert
converter
=
Converter
(
param_dict
,
pre_dist_attr
,
cur_dist_attr
)
param_dict
=
converter
.
convert
(
strict
=
strict
)
program
.
set_state_dict
(
param_dict
)
def
save_inference_model
(
self
,
path
,
feed_vars
,
fetch_vars
,
exe
,
**
kwargs
):
dirname
,
filename
=
_process_path
(
path
)
# save distributed inference program
rank_id
=
paddle
.
distributed
.
get_rank
()
if
rank_id
==
0
:
self
.
_save_rank_mapping
(
dirname
)
op_role_key
=
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
op_role_forward
=
int
(
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
)
dist_main_prog
=
kwargs
.
get
(
'program'
,
None
)
if
not
dist_main_prog
:
dist_main_prog
=
fluid
.
default_main_program
()
global_block
=
dist_main_prog
.
global_block
()
ops
=
global_block
.
ops
feed_vars_names
=
list
(
map
(
lambda
x
:
x
.
name
,
feed_vars
))
fetch_vars_names
=
list
(
map
(
lambda
x
:
x
.
name
,
fetch_vars
))
last_idx
=
-
1
for
idx
,
op
in
enumerate
(
ops
):
if
op
.
attr
(
op_role_key
)
!=
op_role_forward
:
continue
if
op
.
type
==
"read"
or
op
.
type
==
"feed"
or
op
.
type
==
'recv_v2'
:
feed_vars_names
+=
op
.
output
(
"Out"
)
if
op
.
type
==
"send_v2"
:
fetch_vars_names
+=
op
.
input
(
"X"
)
last_idx
=
max
(
idx
,
last_idx
)
for
out_name
in
op
.
output_arg_names
:
if
out_name
in
fetch_vars_names
:
last_idx
=
max
(
idx
,
last_idx
)
used_inputs
=
[]
used_outputs
=
[]
for
idx
,
op
in
enumerate
(
ops
):
if
idx
>
last_idx
:
break
used_inputs
+=
op
.
input_arg_names
used_outputs
+=
op
.
output_arg_names
dist_feed_vars_names
=
list
(
set
(
feed_vars_names
)
&
set
(
used_inputs
))
dist_fetch_vars_names
=
list
(
set
(
fetch_vars_names
)
&
set
(
used_outputs
))
dist_feed_vars
=
[
global_block
.
vars
[
name
]
for
name
in
dist_feed_vars_names
]
dist_fetch_vars
=
[
global_block
.
vars
[
name
]
for
name
in
dist_fetch_vars_names
]
# NOTE: `paddle.static.save_inference_model` does not support subblock.
dist_filename
=
filename
+
"_dist"
+
str
(
rank_id
)
dist_path
=
os
.
path
.
join
(
dirname
,
dist_filename
)
paddle
.
static
.
save_inference_model
(
dist_path
,
dist_feed_vars
,
dist_fetch_vars
,
exe
,
program
=
dist_main_prog
)
def
_save_rank_mapping
(
self
,
dirname
):
path
=
os
.
path
.
join
(
dirname
,
'rank_mapping.csv'
)
f
=
open
(
path
,
'w'
)
f
.
write
(
'[ring_id -> ranks]
\n
'
)
for
process_group
in
_g_process_group_map
.
values
():
ring_id
=
process_group
.
_group_id
ranks
=
[
str
(
rank
)
for
rank
in
process_group
.
_ranks
]
id_to_rank
=
str
(
ring_id
)
+
","
+
","
.
join
(
ranks
)
+
'
\n
'
f
.
write
(
id_to_rank
)
id_to_rank
=
""
f
.
write
(
'[rank -> ring_ids]
\n
'
)
rank_to_id_dict
=
{}
for
process_group
in
_g_process_group_map
.
values
():
ring_id
=
process_group
.
_group_id
for
rank
in
process_group
.
_ranks
:
if
rank
in
rank_to_id_dict
:
rank_to_id_dict
[
rank
].
append
(
str
(
ring_id
))
else
:
rank_to_id_dict
[
rank
]
=
[
str
(
ring_id
)]
rank_to_id
=
""
for
item
,
val
in
rank_to_id_dict
.
items
():
rank_to_id
+=
str
(
item
)
+
","
rank_to_id
+=
","
.
join
(
val
)
+
"
\n
"
f
.
write
(
rank_to_id
)
rank_to_id
=
""
f
.
close
()
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
3980e222
...
...
@@ -19,138 +19,158 @@ from collections import defaultdict
import
paddle
from
paddle
import
fluid
from
paddle.io
import
Dataset
from
paddle.fluid.backward
import
append_backward
import
paddle.fluid.core
as
core
from
paddle.metric
import
Metric
from
paddle.static
import
InputSpec
from
paddle.fluid
import
core
from
paddle.fluid
import
program_guard
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.framework
import
Operator
from
paddle.fluid.framework
import
_current_expected_place
as
_get_device
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.distributed.passes
import
new_pass
,
PassContext
from
paddle.distributed.utils
import
get_logger
from
.dist_loader
import
NonIterableGeneratorLoader
from
.dist_op
import
DistributedOperator
from
.dist_tensor
import
DistributedTensor
from
.dist_context
import
DistributedContext
from
.dist_context
import
get_default_distributed_context
from
.dist_context
import
set_default_distributed_context
from
.process_group
import
get_all_process_groups
from
.process_group
import
get_process_group
from
.process_group
import
get_world_process_group
from
.process_group
import
_g_process_group_map
,
ProcessGroup
from
.completion
import
Completer
from
.partitioner
import
Partitioner
from
.reshard
import
reshard
,
HAS_SENT
,
HAS_RECV
,
HAS_ALLGATHER
from
.cluster
import
Cluster
from
.mapper
import
mapping
from
.cluster
import
Cluster
from
.reshard
import
reshard
from
.planner
import
Planner
from
.utils
import
make_data_unshard
from
.utils
import
set_grad_var_shape
from
.utils
import
print_program_with_dist_attr
from
.utils
import
SerialProgramInfo
from
.completion
import
Completer
from
.partitioner
import
Partitioner
from
.dist_op
import
DistributedOperator
from
.dist_saver
import
DistributedSaver
from
.dist_loader
import
NonIterableGeneratorLoader
from
.utils
import
make_data_unshard
,
set_grad_var_shape
from
.utils
import
print_program_with_dist_attr
,
to_list
from
.process_group
import
get_all_process_groups
,
get_world_process_group
from
.dist_context
import
DistributedContext
,
get_default_distributed_context
paddle
.
enable_static
()
def
to_list
(
value
):
if
value
is
None
:
return
value
if
isinstance
(
value
,
(
list
,
tuple
)):
return
list
(
value
)
return
[
value
]
class
Engine
:
def
__init__
(
self
,
model
=
None
,
data_spec
=
None
,
cluster
=
None
,
strategy
=
None
):
def
__init__
(
self
,
model
=
None
,
inputs_spec
=
None
,
labels_spec
=
None
,
cluster
=
None
,
strategy
=
None
):
self
.
model
=
model
self
.
data_spec
=
data_spec
self
.
inputs_spec
=
self
.
_validate_spec
(
inputs_spec
)
self
.
labels_spec
=
self
.
_validate_spec
(
labels_spec
)
self
.
cluster
=
cluster
self
.
strategy
=
strategy
self
.
_executor
=
None
self
.
_orig_main_prog
=
fluid
.
default_main_program
()
self
.
_orig_startup_prog
=
fluid
.
default_startup_program
()
self
.
_orig_dist_context
=
get_default_distributed_context
()
self
.
_serial_main_progs
=
{}
self
.
_serial_startup_progs
=
{}
self
.
_dist_main_progs
=
defaultdict
(
dict
)
self
.
_dist_startup_progs
=
defaultdict
(
dict
)
self
.
_orig_dist_context
=
get_default_distributed_context
()
self
.
_dist_main_progs
=
defaultdict
(
dict
)
# dist main programs
self
.
_dist_startup_progs
=
defaultdict
(
dict
)
# dist startup programs
self
.
_dist_contexts
=
{}
self
.
_pass_contexts
=
{}
self
.
_cur_rank
=
paddle
.
distributed
.
get_rank
()
self
.
_logger
=
get_logger
(
logging
.
INFO
)
self
.
_saver
=
DistributedSaver
()
self
.
_feed_vars
=
{}
self
.
_fetch_vars
=
{}
def
prepare
(
self
,
optimizer
=
None
,
loss
=
None
,
metrics
=
None
,
mode
=
"train"
,
mode
=
'train'
,
all_ranks
=
False
):
self
.
optimizer
=
optimizer
self
.
loss
=
loss
self
.
metrics
=
metrics
self
.
_optimizer
=
optimizer
# TODO: check loss type
self
.
_loss
=
loss
self
.
_metrics
=
to_list
(
metrics
)
for
m
in
[
'train'
,
'predict'
]:
self
.
mode
=
m
self
.
_build
(
m
)
# build forward program
self
.
_plan
(
m
)
# completion & planner
self
.
_parallel
(
m
,
all_ranks
)
# parallel
self
.
_initialize
(
m
)
# init comm and startup program
self
.
mode
=
mode
self
.
_build
()
self
.
_plan
()
if
not
all_ranks
:
self
.
_parallel
(
self
.
_cur_rank
)
else
:
world_process_group
=
get_world_process_group
()
all_ranks
=
world_process_group
.
ranks
for
rank
in
all_ranks
:
self
.
_parallel
(
rank
)
self
.
_place
=
_get_device
()
if
isinstance
(
self
.
_place
,
fluid
.
CUDAPlace
):
self
.
_place
=
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
if
self
.
_executor
is
None
:
self
.
_executor
=
paddle
.
static
.
Executor
(
self
.
_place
)
def
_build
(
self
):
serial_main_prog
=
self
.
_serial_main_progs
.
get
(
self
.
mode
,
None
)
def
_build
(
self
,
mode
):
serial_main_prog
=
self
.
_serial_main_progs
.
get
(
mode
,
None
)
if
serial_main_prog
is
not
None
:
return
losses
=
[]
metrics
=
[]
serial_main_prog
=
self
.
_orig_main_prog
.
clone
()
serial_startup_prog
=
self
.
_orig_startup_prog
.
clone
()
with
fluid
.
program_guard
(
serial_main_prog
,
serial_startup_prog
):
inputs_spec
=
self
.
data_spec
[
0
]
labels_spec
=
self
.
data_spec
[
1
]
inputs
=
[
s
.
_create_feed_layer
()
for
s
in
to_list
(
inputs_spec
)]
labels
=
[
s
.
_create_feed_layer
()
for
s
in
to_list
(
labels_spec
)]
self
.
_input_vars
=
inputs
self
.
_label_vars
=
labels
self
.
_feed_vars
=
self
.
_input_vars
+
self
.
_label_vars
inputs_spec
=
self
.
inputs_spec
labels_spec
=
self
.
labels_spec
if
self
.
labels_spec
else
[]
inputs
=
[
s
.
_create_feed_layer
()
for
s
in
inputs_spec
]
labels
=
[
s
.
_create_feed_layer
()
for
s
in
labels_spec
]
outputs
=
to_list
(
self
.
model
(
*
inputs
))
if
self
.
mode
!=
"predict"
and
self
.
loss
:
loss
=
self
.
loss
(
*
(
outputs
+
labels
))
self
.
_loss_var
=
loss
self
.
_fetch_vars
=
{
"outputs"
:
outputs
,
"loss"
:
loss
}
self
.
_serial_main_progs
[
self
.
mode
]
=
serial_main_prog
self
.
_serial_startup_progs
[
self
.
mode
]
=
serial_startup_prog
self
.
_dist_contexts
[
self
.
mode
]
=
DistributedContext
(
serial_main_prog
,
serial_startup_prog
,
self
.
_dist_main_progs
[
self
.
mode
],
self
.
_dist_startup_progs
[
self
.
mode
])
self
.
_pass_contexts
[
self
.
mode
]
=
PassContext
()
def
_plan
(
self
):
if
mode
!=
"predict"
and
self
.
_loss
:
losses
=
to_list
(
self
.
_loss
(
*
(
outputs
+
labels
)))
self
.
_feed_vars
[
mode
]
=
{
"inputs"
:
inputs
,
"labels"
:
labels
}
self
.
_fetch_vars
[
mode
]
=
{
"outputs"
:
outputs
,
"loss"
:
losses
,
"metrics"
:
metrics
}
self
.
_serial_main_progs
[
mode
]
=
serial_main_prog
self
.
_serial_startup_progs
[
mode
]
=
serial_startup_prog
self
.
_dist_contexts
[
mode
]
=
DistributedContext
(
serial_main_prog
,
serial_startup_prog
,
self
.
_dist_main_progs
[
mode
],
self
.
_dist_startup_progs
[
mode
])
self
.
_pass_contexts
[
mode
]
=
PassContext
()
def
_plan
(
self
,
mode
):
# Complete the distributed annotation
serial_main_prog
=
self
.
_serial_main_progs
[
self
.
mode
]
self
.
_completer
=
Completer
(
self
.
_dist_contexts
[
self
.
mode
])
serial_main_prog
=
self
.
_serial_main_progs
[
mode
]
self
.
_completer
=
Completer
(
self
.
_dist_contexts
[
mode
])
self
.
_completer
.
complete_forward_annotation
(
serial_main_prog
)
# TODO: add auto planner process
# parse forward sub block
self
.
_dist_contexts
[
self
.
mode
].
block_state
.
parse_forward_blocks
(
self
.
_dist_contexts
[
mode
].
block_state
.
parse_forward_blocks
(
serial_main_prog
)
def
_parallel
(
self
,
rank
):
serial_main_program
=
self
.
_serial_main_progs
[
self
.
mode
]
serial_startup_program
=
self
.
_serial_startup_progs
[
self
.
mode
]
dist_context
=
self
.
_dist_contexts
[
self
.
mode
]
if
self
.
mode
!=
"predict"
and
self
.
loss
:
def
_parallel
(
self
,
mode
,
all_ranks
=
False
):
if
not
all_ranks
:
self
.
_parallel_program
(
mode
,
self
.
_cur_rank
)
else
:
world_process_group
=
get_world_process_group
()
all_ranks
=
world_process_group
.
ranks
for
rank
in
all_ranks
:
self
.
_parallel_program
(
mode
,
rank
)
def
_initialize
(
self
,
mode
):
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups
=
get_all_process_groups
()
for
process_group
in
all_process_groups
:
if
self
.
_cur_rank
not
in
process_group
.
ranks
:
continue
process_group
.
instantiate
()
# initialize
self
.
_place
=
_get_device
()
if
isinstance
(
self
.
_place
,
fluid
.
CUDAPlace
):
self
.
_place
=
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
if
self
.
_executor
is
None
:
self
.
_executor
=
paddle
.
static
.
Executor
(
self
.
_place
)
dist_startup_prog
=
self
.
_dist_startup_progs
[
mode
][
self
.
_cur_rank
]
self
.
_executor
.
run
(
dist_startup_prog
)
def
_parallel_program
(
self
,
mode
,
rank
):
serial_main_program
=
self
.
_serial_main_progs
[
mode
]
serial_startup_program
=
self
.
_serial_startup_progs
[
mode
]
dist_context
=
self
.
_dist_contexts
[
mode
]
if
mode
==
"train"
and
self
.
_optimizer
:
# Generate backward
serial_loss
=
self
.
_
loss_var
serial_loss
=
self
.
_
fetch_vars
[
mode
][
"loss"
][
0
]
params_grads
=
self
.
_generate_backward
(
serial_main_program
,
serial_startup_program
,
serial_loss
)
# Apply pre optimization passes
...
...
@@ -172,8 +192,23 @@ class Engine:
# Apply post optimization passes
self
.
_apply_post_optimization
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_params_grads
)
self
.
_dist_main_progs
[
self
.
mode
][
rank
]
=
dist_main_prog
self
.
_dist_startup_progs
[
self
.
mode
][
rank
]
=
dist_startup_prog
else
:
# Do logical partition
partitioner
=
Partitioner
(
dist_context
,
rank
)
dist_main_prog
,
dist_startup_prog
,
dist_params_grads
=
partitioner
.
partition
(
serial_main_program
,
serial_startup_program
,
[])
# Do reshard process
make_data_unshard
(
dist_main_prog
,
dist_startup_prog
,
dist_context
)
reshard
(
dist_main_prog
,
dist_startup_prog
,
rank
,
dist_context
,
[],
1
)
# clone program for test
if
mode
!=
'train'
:
dist_main_prog
=
dist_main_prog
.
clone
(
for_test
=
True
)
dist_startup_prog
=
dist_startup_prog
.
clone
(
for_test
=
True
)
self
.
_dist_main_progs
[
mode
][
rank
]
=
dist_main_prog
self
.
_dist_startup_progs
[
mode
][
rank
]
=
dist_startup_prog
def
_generate_backward
(
self
,
main_program
,
startup_program
,
loss
):
with
program_guard
(
main_program
,
startup_program
):
...
...
@@ -187,7 +222,7 @@ class Engine:
def
_generate_optimizer
(
self
,
main_program
,
startup_program
,
params_grads
):
with
program_guard
(
main_program
,
startup_program
):
optimizer_ops
=
copy
.
deepcopy
(
self
.
optimizer
).
apply_gradients
(
optimizer_ops
=
copy
.
deepcopy
(
self
.
_
optimizer
).
apply_gradients
(
params_grads
)
self
.
_completer
.
complete_update_annotation
(
main_program
)
return
optimizer_ops
...
...
@@ -239,42 +274,87 @@ class Engine:
[
main_program
],
[
startup_program
],
self
.
_pass_contexts
[
self
.
mode
])
def
fit
(
self
,
train_data
,
batch_size
=
1
,
epochs
=
1
,
steps_per_epoch
=
1000
):
def
fit
(
self
,
train_data
,
batch_size
=
1
,
epochs
=
1
,
steps_per_epoch
=
None
):
# TODO: callbacks
# TODO: evaluate after training
self
.
mode
=
'train'
assert
isinstance
(
train_data
,
Dataset
)
assert
steps_per_epoch
is
not
None
train_dataloader
=
self
.
_create_dataloader
(
train_data
,
batch_size
,
epochs
,
steps_per_epoch
)
self
.
_init_communication
()
dist_startup_prog
=
self
.
_dist_startup_progs
[
"train"
][
self
.
_cur_rank
]
self
.
_executor
.
run
(
dist_startup_prog
)
outputs
=
[]
for
epoch
in
range
(
epochs
):
# train_dataloader.start()
# for step in range(steps_per_epoch):
# logs = self.train_step(None)
# self._logger.info(logs)
# train_dataloader.reset()
for
step
,
data
in
enumerate
(
train_dataloader
):
logs
=
self
.
_train_step
(
data
)
logs
,
loss
=
self
.
_train_step
(
data
)
outputs
.
append
(
loss
)
train_logs
=
{
"train_"
+
name
:
val
for
name
,
val
in
logs
.
items
()
}
self
.
_logger
.
info
(
train_logs
)
return
outputs
def
predict
(
self
,
test_data
,
batch_size
=
1
,
use_program_cache
=
False
,
return_numpy
=
True
):
self
.
mode
=
'predict'
# TODO: need check dataset
test_dataloader
=
self
.
_create_dataloader
(
test_data
,
batch_size
)
outputs
=
[]
for
step
,
data
in
enumerate
(
test_dataloader
):
logs
,
outs
=
self
.
_predict_step
(
data
,
use_program_cache
,
return_numpy
)
outputs
.
append
(
outs
)
predict_logs
=
{
"predict_"
+
name
:
val
for
name
,
val
in
logs
.
items
()
}
self
.
_logger
.
info
(
predict_logs
)
return
outputs
def
_train_step
(
self
,
data
):
logs
=
{}
dist_main_prog
=
self
.
_dist_main_progs
[
"train"
][
self
.
_cur_rank
]
if
self
.
_loss_var
.
name
not
in
dist_main_prog
.
global_block
().
vars
:
dist_main_prog
=
self
.
_dist_main_progs
[
self
.
mode
][
self
.
_cur_rank
]
fetch_var
=
self
.
_fetch_vars
[
self
.
mode
][
"loss"
][
0
]
if
fetch_var
.
name
not
in
dist_main_prog
.
global_block
().
vars
:
loss
=
self
.
_executor
.
run
(
dist_main_prog
)
logs
[
"loss"
]
=
None
else
:
fetch_list
=
self
.
_loss_var
loss
=
self
.
_executor
.
run
(
dist_main_prog
,
fetch_list
=
fetch_list
)
loss
=
self
.
_executor
.
run
(
dist_main_prog
,
fetch_list
=
to_list
(
fetch_var
)
)
logs
[
"loss"
]
=
loss
return
logs
return
logs
,
loss
def
_create_dataloader
(
self
,
dataset
,
batch_size
,
epochs
,
steps_per_epoch
):
feed_list
=
self
.
_input_vars
+
self
.
_label_vars
def
_predict_step
(
self
,
data
,
use_program_cache
=
False
,
return_numpy
=
True
):
logs
=
{}
dist_main_prog
=
self
.
_dist_main_progs
[
self
.
mode
][
self
.
_cur_rank
]
fetch_var
=
[]
for
var
in
self
.
_fetch_vars
[
self
.
mode
][
"outputs"
]:
if
var
.
name
in
dist_main_prog
.
global_block
().
vars
:
fetch_var
.
append
(
var
)
if
fetch_var
is
[]:
outs
=
self
.
_executor
.
run
(
dist_main_prog
,
use_program_cache
=
use_program_cache
)
logs
[
"pred"
]
=
outs
else
:
outs
=
self
.
_executor
.
run
(
dist_main_prog
,
fetch_list
=
fetch_var
,
use_program_cache
=
use_program_cache
,
return_numpy
=
return_numpy
)
logs
[
"pred"
]
=
outs
return
logs
,
outs
def
_create_dataloader
(
self
,
dataset
,
batch_size
,
epochs
=
1
,
steps_per_epoch
=
None
):
feed_list
=
self
.
_feed_vars
[
self
.
mode
][
"inputs"
]
+
self
.
_feed_vars
[
self
.
mode
][
"labels"
]
dist_main_prog
=
self
.
_dist_main_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_startup_prog
=
self
.
_dist_startup_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_context
=
self
.
_dist_contexts
[
self
.
mode
]
...
...
@@ -284,8 +364,15 @@ class Engine:
op_size
=
len
(
dist_main_block
.
ops
)
places
=
paddle
.
static
.
cuda_places
()
with
fluid
.
program_guard
(
dist_main_prog
,
dist_startup_prog
):
inputs
=
self
.
_feed_vars
[
self
.
mode
][
"inputs"
]
dataloader
=
NonIterableGeneratorLoader
(
dataset
,
feed_list
,
places
,
batch_size
,
epochs
,
steps_per_epoch
)
dataset
,
feed_list
,
places
,
batch_size
,
epochs
,
steps_per_epoch
,
inputs
=
inputs
)
new_op_size
=
len
(
dist_main_block
.
ops
)
for
_
in
range
(
new_op_size
-
1
,
op_size
-
1
,
-
1
):
op
=
dist_main_block
.
ops
[
new_op_size
-
1
]
...
...
@@ -312,17 +399,49 @@ class Engine:
dist_main_block
.
_sync_with_cpp
()
return
dataloader
def
_init_communication
(
self
):
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups
=
get_all_process_groups
()
for
process_group
in
all_process_groups
:
if
self
.
_cur_rank
not
in
process_group
.
ranks
:
continue
process_group
.
instantiate
()
def
_validate_spec
(
self
,
specs
):
specs
=
to_list
(
specs
)
if
specs
is
not
None
:
for
i
,
spec
in
enumerate
(
specs
):
assert
isinstance
(
spec
,
InputSpec
)
if
spec
.
name
is
None
:
raise
ValueError
(
"Requires Input[{}].name != None, but receive `None` with {}."
.
format
(
i
,
spec
))
return
specs
def
save
(
self
,
path
,
training
=
True
,
mode
=
None
):
if
not
mode
:
mode
=
self
.
mode
if
training
:
assert
'train'
in
self
.
_serial_main_progs
,
"training model is not ready, please call `engine.prepare(mode='train')` first."
serial_program
=
self
.
_serial_main_progs
[
"train"
]
dist_main_prog
=
self
.
_dist_main_progs
[
"train"
][
self
.
_cur_rank
]
dist_context
=
self
.
_dist_contexts
[
"train"
]
self
.
_saver
.
save
(
path
,
serial_program
=
serial_program
,
dist_main_program
=
dist_main_prog
,
dist_context
=
dist_context
)
else
:
assert
mode
,
"Please set the 'mode' you want to save."
feed_vars
=
self
.
_feed_vars
[
mode
][
'inputs'
]
fetch_vars
=
self
.
_fetch_vars
[
mode
][
'outputs'
]
dist_main_prog
=
self
.
_dist_main_progs
[
mode
][
self
.
_cur_rank
]
self
.
_saver
.
save_inference_model
(
path
,
feed_vars
,
fetch_vars
,
self
.
_executor
,
program
=
dist_main_prog
)
# def save(self, path, training=True):
# pass
def
load
(
self
,
path
,
strict
=
True
,
load_optimizer
=
True
,
mode
=
None
):
if
not
mode
:
mode
=
self
.
mode
assert
mode
,
"Please set the 'mode' you want to load."
# def load(self, path, strict=True, load_optimizer=True):
# pass
dist_main_prog
=
self
.
_dist_main_progs
[
mode
][
self
.
_cur_rank
]
dist_context
=
self
.
_dist_contexts
[
mode
]
self
.
_saver
.
load
(
path
,
dist_main_prog
,
dist_context
,
strict
,
load_optimizer
)
python/paddle/distributed/auto_parallel/utils.py
浏览文件 @
3980e222
...
...
@@ -1416,3 +1416,11 @@ def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context):
# Third, print error infomation if we cannot find the original id
else
:
assert
False
,
"Cannot find the original id in the distributed context"
def
to_list
(
value
):
if
value
is
None
:
return
value
if
isinstance
(
value
,
(
list
,
tuple
)):
return
list
(
value
)
return
[
value
]
python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
浏览文件 @
3980e222
...
...
@@ -108,10 +108,8 @@ def train():
grad_clip
=
None
)
dataset
=
MyDataset
(
batch_num
*
batch_size
)
data_spec
=
[
InputSpec
([
batch_size
,
hidden_size
],
'float32'
,
'x'
),
InputSpec
([
batch_size
],
'int64'
,
'label'
)
]
inputs_spec
=
InputSpec
([
batch_size
,
hidden_size
],
'float32'
,
'x'
)
labels_spec
=
InputSpec
([
batch_size
],
'int64'
,
'label'
)
dist_strategy
=
fleet
.
DistributedStrategy
()
dist_strategy
.
amp
=
False
...
...
@@ -121,11 +119,18 @@ def train():
dist_strategy
.
semi_auto
=
True
fleet
.
init
(
is_collective
=
True
,
strategy
=
dist_strategy
)
engine
=
Engine
(
mlp
,
data_spec
,
strategy
=
dist_strategy
)
engine
=
Engine
(
mlp
,
inputs_spec
=
inputs_spec
,
labels_spec
=
labels_spec
,
strategy
=
dist_strategy
)
engine
.
prepare
(
optimizer
,
loss
)
engine
.
fit
(
dataset
,
batch_size
=
batch_size
,
steps_per_epoch
=
batch_num
*
batch_size
)
engine
.
save
(
'./mlp'
)
engine
.
load
(
'./mlp'
)
engine
.
save
(
'./mlp_inf'
,
training
=
False
,
mode
=
'predict'
)
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/auto_parallel/engine_predict_api.py
0 → 100644
浏览文件 @
3980e222
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
time
import
paddle.fluid
as
fluid
import
copy
import
os
import
numpy
as
np
import
subprocess
import
paddle
import
paddle.nn
as
nn
import
paddle.fluid
as
fluid
import
paddle.static
as
static
import
paddle.nn.functional
as
F
import
paddle.utils
as
utils
from
paddle.fluid
import
layers
from
paddle.io
import
Dataset
,
IterableDataset
,
DataLoader
from
paddle.static
import
InputSpec
from
paddle.distributed
import
fleet
import
paddle.distributed.auto_parallel
as
auto
from
paddle.distributed.auto_parallel.engine
import
Engine
paddle
.
enable_static
()
global_process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[
0
,
1
])
batch_size
=
1
batch_num
=
10
hidden_size
=
1024
image_size
=
hidden_size
paddle
.
seed
(
44
)
class
MyDataset
(
Dataset
):
def
__init__
(
self
,
num_samples
):
super
(
MyDataset
,
self
).
__init__
()
self
.
num_samples
=
num_samples
def
__getitem__
(
self
,
index
):
input
=
np
.
random
.
uniform
(
size
=
image_size
).
astype
(
"float32"
)
return
input
def
__len__
(
self
):
return
self
.
num_samples
class
MLPLayer
(
nn
.
Layer
):
def
__init__
(
self
,
hidden_size
=
1024
,
intermediate_size
=
4
*
1024
,
dropout_ratio
=
0.1
,
initializer_range
=
0.02
):
super
(
MLPLayer
,
self
).
__init__
()
d_model
=
hidden_size
dim_feedforward
=
intermediate_size
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
Normal
(
mean
=
0.0
,
std
=
initializer_range
))
bias_attr
=
None
self
.
linear0
=
nn
.
Linear
(
d_model
,
dim_feedforward
,
weight_attr
,
bias_attr
=
bias_attr
)
self
.
linear1
=
nn
.
Linear
(
dim_feedforward
,
d_model
,
weight_attr
,
bias_attr
=
bias_attr
)
self
.
linear2
=
nn
.
Linear
(
d_model
,
1
,
weight_attr
,
bias_attr
=
bias_attr
)
self
.
norm
=
nn
.
LayerNorm
(
d_model
,
epsilon
=
1e-5
)
self
.
dropout
=
nn
.
Dropout
(
dropout_ratio
,
mode
=
"upscale_in_train"
)
def
forward
(
self
,
input
):
out
=
self
.
norm
(
input
)
out
=
self
.
linear0
(
input
)
auto
.
shard_tensor
(
self
.
linear0
.
weight
,
dist_attr
=
{
"process_mesh"
:
global_process_mesh
,
"dims_mapping"
:
[
-
1
,
0
]
})
out
=
F
.
gelu
(
out
,
approximate
=
True
)
out
=
self
.
linear1
(
out
)
auto
.
shard_tensor
(
self
.
linear1
.
weight
,
dist_attr
=
{
"process_mesh"
:
global_process_mesh
,
"dims_mapping"
:
[
0
,
-
1
]
})
out
=
self
.
dropout
(
out
)
out
=
self
.
linear2
(
out
)
return
out
def
train
():
mlp
=
MLPLayer
(
hidden_size
=
hidden_size
,
intermediate_size
=
4
*
hidden_size
,
dropout_ratio
=
0.1
,
initializer_range
=
0.02
)
dataset
=
MyDataset
(
batch_num
*
batch_size
)
inputs_spec
=
InputSpec
([
batch_size
,
hidden_size
],
'float32'
,
'x'
)
dist_strategy
=
fleet
.
DistributedStrategy
()
# init parallel optimizer
dist_strategy
.
semi_auto
=
True
fleet
.
init
(
is_collective
=
True
,
strategy
=
dist_strategy
)
engine
=
Engine
(
mlp
,
inputs_spec
=
inputs_spec
,
strategy
=
dist_strategy
)
engine
.
prepare
(
mode
=
'predict'
)
engine
.
predict
(
dataset
,
batch_size
=
batch_size
)
if
__name__
==
"__main__"
:
train
()
python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py
浏览文件 @
3980e222
...
...
@@ -42,6 +42,34 @@ class TestEngineAPI(unittest.TestCase):
log_path
=
os
.
path
.
join
(
file_dir
,
"log"
)
if
os
.
path
.
exists
(
log_path
):
shutil
.
rmtree
(
log_path
)
files_path
=
[
path
for
path
in
os
.
listdir
(
'.'
)
if
'.pd'
in
path
]
for
path
in
files_path
:
if
os
.
path
.
exists
(
path
):
os
.
remove
(
path
)
if
os
.
path
.
exists
(
'rank_mapping.csv'
):
os
.
remove
(
'rank_mapping.csv'
)
def
test_engine_predict
(
self
):
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
launch_model_path
=
os
.
path
.
join
(
file_dir
,
"engine_predict_api.py"
)
if
os
.
environ
.
get
(
"WITH_COVERAGE"
,
"OFF"
)
==
"ON"
:
coverage_args
=
[
"-m"
,
"coverage"
,
"run"
,
"--branch"
,
"-p"
]
else
:
coverage_args
=
[]
cmd
=
[
sys
.
executable
,
"-u"
]
+
coverage_args
+
[
"-m"
,
"launch"
,
"--gpus"
,
"0,1"
,
launch_model_path
]
process
=
subprocess
.
Popen
(
cmd
)
process
.
wait
()
self
.
assertEqual
(
process
.
returncode
,
0
)
# Remove unnecessary files
log_path
=
os
.
path
.
join
(
file_dir
,
"log"
)
if
os
.
path
.
exists
(
log_path
):
shutil
.
rmtree
(
log_path
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录