Commit 3649099f (unverified)
Authored on Aug 15, 2022 by zhaoyingli; committed by GitHub on Aug 15, 2022.
[AutoParallel] add collate_fn for dist_loader (#45053)
* add collate_fn
* fix number of inputs
Parent: 8788513b
Showing 8 changed files with 151 additions and 68 deletions (+151 −68).
Changed files:

* python/paddle/distributed/auto_parallel/completion.py (+5 −1)
* python/paddle/distributed/auto_parallel/dist_loader.py (+61 −30)
* python/paddle/distributed/auto_parallel/engine.py (+40 −15)
* python/paddle/distributed/auto_parallel/parallelizer_v2.py (+1 −0)
* python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py (+8 −4)
* python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py (+9 −8)
* python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py (+1 −1)
* python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py (+26 −9)
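For orientation before the per-file diffs: the commit threads a user-supplied `collate_fn` from the `Engine` entry points (`fit`, `evaluate`, `predict`) down into the distributed data loader. A minimal usage sketch, assuming an `engine` that has already been built and prepared elsewhere; `train_dataset` and `my_collate` are hypothetical stand-ins, not names from the diff:

```python
def train_with_collate(engine, train_dataset, my_collate):
    """Hedged sketch of the API surface added by this commit.

    `engine` is assumed to be a prepared auto_parallel Engine; `my_collate`
    follows the usual DataLoader collate_fn contract (list of samples -> batch).
    """
    engine.fit(train_dataset,
               batch_size=64,
               epochs=2,
               collate_fn=my_collate,  # new argument added by this commit
               use_cache=True,         # renamed from use_program_cache
               return_numpy=True)
```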
python/paddle/distributed/auto_parallel/completion.py

```diff
@@ -1300,6 +1300,10 @@ class Completer:
     def complete_update_annotation(self, serial_main_program):
         """Complete the annotation of vars and ops in the update phase for parallel program."""
         # Copy the dist tensors and dist ops annotated by users from the default context
+        # global mesh
+        from paddle.distributed.auto_parallel.process_group import get_world_process_group
+        world_ranks = get_world_process_group().ranks
+
         # Notice: serial_main_program is actually a dist_main_program of current rank,
         # and must be passed into this function.
@@ -1371,7 +1375,7 @@ class Completer:
                 if not learning_rate_completed:
                     learning_rate_completed = True
                     var_dist_attr = TensorDistributedAttribute()
-                    var_dist_attr.process_mesh = ref_process_mesh
+                    var_dist_attr.process_mesh = world_ranks
                     var_dist_attr.dims_mapping = [-1]
                     self._dist_context.set_tensor_dist_attr_for_program(
                         learning_var, var_dist_attr)
```
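The substance of this change is that the learning-rate variable is now annotated on the global mesh rather than a reference process mesh. A sketch of the annotation logic, lifted from the diff into a free function for readability (the helper name and parameter list are illustrative, not part of the commit):

```python
from paddle.distributed.auto_parallel.process_group import get_world_process_group

def annotate_learning_rate(dist_context, learning_var, TensorDistributedAttribute):
    # All ranks share one learning-rate value, so it is placed on the
    # global mesh and marked replicated: dims_mapping == [-1] shards
    # the tensor along no mesh dimension.
    world_ranks = get_world_process_group().ranks
    var_dist_attr = TensorDistributedAttribute()
    var_dist_attr.process_mesh = world_ranks
    var_dist_attr.dims_mapping = [-1]
    dist_context.set_tensor_dist_attr_for_program(learning_var, var_dist_attr)
```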
python/paddle/distributed/auto_parallel/dist_loader.py

```diff
@@ -17,7 +17,8 @@ import numpy as np
 import paddle
 from .utils import to_list
 from paddle.fluid.layers.utils import flatten
-from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.io import DataLoader, BatchSampler, IterableDataset
+from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


 class DistributedDataLoader(metaclass=abc.ABCMeta):
@@ -29,14 +30,32 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
                  data_parallel_world_size=None,
                  data_parallel_rank=None,
                  drop_last=False):
+        if isinstance(dataset, IterableDataset):
+            raise TypeError("IterableDataset is not supported.")
+        else:
+            self.dataset_kind = _DatasetKind.MAP
+
         self.dataset = dataset
-        self.batch_size = batch_size
         self.epochs = epochs
-        self.data_parallel_world_size = data_parallel_world_size
-        self.data_parallel_rank = data_parallel_rank
         self.drop_lost = drop_last
-        if data_parallel_world_size is not None and batch_size is not None:
-            assert batch_size % data_parallel_world_size == 0
+
+        if batch_size is None:
+            self.batch_size = None
+            self.batch_sampler = None
+        else:
+            if data_parallel_world_size is not None:
+                assert batch_size % data_parallel_world_size == 0, \
+                    "'batch_size' must be divisible by data parallel size"
+            self.batch_size = batch_size
+            self.batch_sampler = BatchSampler(dataset,
+                                              batch_size=batch_size,
+                                              shuffle=False,
+                                              drop_last=drop_last)
+
+        self.auto_collate_batch = self.batch_sampler is not None
+        self.sampler_iter = iter(self.index_sampler)
+        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
+        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank

     @abc.abstractmethod
     def __iter__(self):
@@ -46,6 +65,16 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
     def __next__(self):
         raise NotImplementedError

+    @property
+    def index_sampler(self):
+        if self.auto_collate_batch:
+            return self.batch_sampler
+        else:
+            if self.dataset_kind == _DatasetKind.MAP:
+                return list(range(len(self.dataset)))
+            else:
+                raise TypeError("Only support datasets in map-style.")
+

 class NonIterableGeneratorLoader(DistributedDataLoader):
@@ -56,21 +85,29 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                  batch_size=1,
                  epochs=1,
                  steps_per_epoch=None,
+                 collate_fn=None,
                  data_parallel_world_size=None,
                  data_parallel_rank=None,
                  drop_last=False):
         self.feed_list = feed_list
         self.places = places
         self.steps_per_epoch = steps_per_epoch
-        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
-        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank
+
         super(NonIterableGeneratorLoader,
               self).__init__(dataset, batch_size, epochs,
                              data_parallel_world_size, data_parallel_rank,
                              drop_last)
-        self._inner_dataloader = self._create_inner_dataloader()
+
+        if self.auto_collate_batch:
+            self.collate_fn = collate_fn or default_collate_fn
+        else:
+            self.collate_fn = collate_fn or default_convert_fn
+        self.dataset_fetcher = _DatasetKind.create_fetcher(
+            self.dataset_kind, self.dataset, self.auto_collate_batch,
+            self.collate_fn, self.drop_lost)
+
         self._steps = self._infer_steps()
+        self._inner_dataloader = self._create_inner_dataloader()

     def __iter__(self):
         self._cur_step = 0
@@ -101,31 +138,25 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
     def _create_inner_dataloader(self):

         def sample_data_generator():
-            batch_data = None
-            for step, data in enumerate(self.dataset):
-                data = flatten(data)
-                if batch_data is None:
-                    batch_data = [[] for i in range(len(data))]
-                for idx in range(len(data)):
-                    batch_data[idx].append(data[idx])
-                if (step + 1) % self.batch_size == 0:
-                    partial_data = []
-                    for d in batch_data:
-                        array = np.array(d)
-                        partial_data.append(
-                            np.split(array, self.dp_world_size)[self.dp_rank])
-                    yield partial_data[:len(self.feed_list)]
-                    batch_data = None
+            for indices in self.sampler_iter:
+                assert len(indices) % self.dp_world_size == 0, \
+                    "Please set batch_size to be divisible by data parallel size"
+                n = len(indices) // self.dp_world_size
+                cur_indices = [
+                    indices[i:i + n] for i in range(0, len(indices), n)
+                ]
+                batch = self.dataset_fetcher.fetch(cur_indices[self.dp_rank])
+                yield batch[:len(self.feed_list)]

         def batch_data_generator():
-            for data in self.dataset:
-                data = flatten(data)
+            for indices in self.sampler_iter:
                 partial_data = []
-                for d in data:
-                    assert d.shape[0] % self.dp_world_size == 0, \
-                        "Please padding dataset with data parallel size"
+                batch = self.dataset_fetcher.fetch(indices)
+                for data in batch:
+                    assert data.shape[0] % self.dp_world_size == 0, \
+                        "Please padding dataset's batch_size to be divisible by data parallel size"
                     partial_data.append(
-                        np.split(d, self.dp_world_size)[self.dp_rank])
+                        np.split(data, self.dp_world_size)[self.dp_rank])
                 yield partial_data[:len(self.feed_list)]

         dataloader = paddle.fluid.io.DataLoader.from_generator(
```
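The loader now draws index batches from `index_sampler` and materializes them through a dataset fetcher that applies `collate_fn`, instead of enumerating the dataset sample by sample. A stand-alone illustration of what a collate step does; `stack_collate_fn` is a hypothetical user function with the same contract `default_collate_fn` fulfills, not code from the diff:

```python
import numpy as np

def stack_collate_fn(samples):
    # Stack a list of (x, label) samples into batched arrays: the shape of
    # output the fetcher hands back when auto_collate_batch is True.
    xs, labels = zip(*samples)
    return [np.stack(xs), np.stack(labels)]

batch = [(np.ones(4, dtype="float32"), np.int64(0)),
         (np.zeros(4, dtype="float32"), np.int64(1))]
x, y = stack_collate_fn(batch)
print(x.shape, y.shape)  # (2, 4) (2,)
```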
python/paddle/distributed/auto_parallel/engine.py

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import time
 import copy
 import logging
 from collections import defaultdict
@@ -306,6 +307,7 @@ class Engine:
             mode].dist_startup_programs
         self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
         self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
+        self._optimizer = self._dist_contexts[mode].serial_optimizer

         if self._nranks > 1:
             # Traverse different rank programs and traverse each op of them,
@@ -403,7 +405,8 @@ class Engine:
             epochs=1,
             fetches=None,
             steps_per_epoch=None,
-            use_program_cache=False,
+            collate_fn=None,
+            use_cache=False,
             return_numpy=True):
         # TODO: callbacks
         # TODO: evaluate after training
@@ -417,18 +420,24 @@ class Engine:
         assert self.mode in self._dist_main_progs, \
             "train model is not ready, please call `engine.prepare()` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
-                                                   epochs, steps_per_epoch)
+                                                   epochs, steps_per_epoch,
+                                                   collate_fn)

         usr_fetch = self._validate_fetches(fetches)
         fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
         fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
+        lr_scheduler = self.get_lr_scheduler(self.main_program)

         for epoch in range(epochs):
             train_logs = {"epoch": epoch}
             for step, _ in enumerate(train_dataloader):
                 outs = self._executor.run(self.main_program,
                                           fetch_list=fetch_list,
-                                          use_program_cache=use_program_cache,
+                                          use_program_cache=use_cache,
                                           return_numpy=return_numpy)
+                if lr_scheduler is not None:
+                    lr_scheduler.step()
+                    train_logs["lr"] = self._optimizer.get_lr()
                 train_logs["step"] = step
                 # inner fetches
                 if fetch_loss:
@@ -444,7 +453,8 @@ class Engine:
             eval_data,
             batch_size=1,
             fetches=None,
-            use_program_cache=False,
+            collate_fn=None,
+            use_cache=False,
             return_numpy=True):
         self.mode = 'eval'
         if not self._mode_init_states[self.mode]:
@@ -452,7 +462,9 @@ class Engine:
         assert self.mode in self._dist_main_progs, \
             "eval model is not ready, please call `engine.prepare()` first."
-        eval_dataloader = self._create_dataloader(eval_data, batch_size)
+        eval_dataloader = self._create_dataloader(eval_data,
+                                                  batch_size,
+                                                  collate_fn=collate_fn)

         usr_fetch = self._validate_fetches(fetches)
         fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
@@ -464,7 +476,7 @@ class Engine:
             eval_logs = {"step": step}
             outs = self._executor.run(self.main_program,
                                       fetch_list=fetch_list,
-                                      use_program_cache=use_program_cache,
+                                      use_program_cache=use_cache,
                                       return_numpy=return_numpy)
             # inner fetches
             if fetch_loss:
@@ -489,7 +501,8 @@ class Engine:
             test_data,
             batch_size=1,
             fetches=None,
-            use_program_cache=False,
+            collate_fn=None,
+            use_cache=False,
             return_numpy=True):
         self.mode = 'predict'
         if not self._mode_init_states[self.mode]:
@@ -497,7 +510,9 @@ class Engine:
         assert self.mode in self._dist_main_progs, \
             "predict model is not ready, please call `engine.prepare()` first."
-        test_dataloader = self._create_dataloader(test_data, batch_size)
+        test_dataloader = self._create_dataloader(test_data,
+                                                  batch_size,
+                                                  collate_fn=collate_fn)

         usr_fetch = self._validate_fetches(fetches)
         fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
@@ -508,7 +523,7 @@ class Engine:
             predict_logs = {"step": step}
             outs = self._executor.run(self.main_program,
                                       fetch_list=fetch_list,
-                                      use_program_cache=use_program_cache,
+                                      use_program_cache=use_cache,
                                       return_numpy=return_numpy)
             outputs.append(outs[:len(fetch_outputs)])
             for i, out in enumerate(outs):
@@ -521,7 +536,8 @@ class Engine:
                            dataset,
                            batch_size,
                            epochs=1,
-                           steps_per_epoch=None):
+                           steps_per_epoch=None,
+                           collate_fn=None):
         dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
         dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
         dist_context = self._dist_contexts[self.mode]
@@ -554,6 +570,7 @@ class Engine:
                 batch_size,
                 epochs,
                 steps_per_epoch,
+                collate_fn,
                 data_parallel_world_size=self._input_split_size,
                 data_parallel_rank=self._input_split_rank)
@@ -645,12 +662,11 @@ class Engine:
             config = self.strategy.recompute_configs

             # extract ckpts by specific model
             if isinstance(self.model, paddle.nn.Layer):
-                if hasattr(self.model, "model") and \
-                        self.model.model.__class__.__name__ == 'GPTForPretraining':
-                    exact_ckpts = self.model.model.gpt.checkpoints
+                if hasattr(self.model, "gpt") and \
+                        self.model.__class__.__name__ == 'GPTForPretraining':
+                    exact_ckpts = self.model.gpt.checkpoints
             else:
                 exact_ckpts = config["checkpoints"]
@@ -659,7 +675,7 @@ class Engine:
             config["checkpoints"] = exact_ckpts[:]
             self.strategy.recompute_configs = config
             logs = {
-                'Model Class': self.model.model.__class__.__name__,
+                'Model Class': self.model.__class__.__name__,
                 'Applied Recompute ckpts': exact_ckpts
             }
             self._logger.info(logs)
@@ -699,6 +715,15 @@ class Engine:
         self._saver.load(path, dist_main_prog, dist_context, strict,
                          load_optimizer)

+    @staticmethod
+    def get_lr_scheduler(program):
+        lr_sheduler = None
+        if hasattr(program, 'lr_sheduler'):
+            from paddle.optimizer.lr import LRScheduler
+            lr_sheduler = program.lr_sheduler
+            assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
+        return lr_sheduler
+
     @property
     def mode(self):
         return self._mode
```
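Three Engine-level changes work together here: `collate_fn` is forwarded to `_create_dataloader` in all three modes, `use_program_cache` is renamed to `use_cache`, and the new `get_lr_scheduler` lets `fit` step a program-attached `LRScheduler` once per iteration. The helper is small enough to sanity-check standalone; a sketch under the assumption that an optimizer's `minimize` is what attaches the `lr_sheduler` attribute (spelled as in Paddle's fluid internals) to the program:

```python
import paddle
from paddle.optimizer.lr import LRScheduler

def get_lr_scheduler(program):
    # Mirrors the new Engine.get_lr_scheduler staticmethod: look up the
    # scheduler on the Program and validate its type before returning it.
    lr_sheduler = None
    if hasattr(program, 'lr_sheduler'):
        lr_sheduler = program.lr_sheduler
        assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
    return lr_sheduler

paddle.enable_static()
main_program = paddle.static.Program()
print(get_lr_scheduler(main_program))  # None until an optimizer attaches one
```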
python/paddle/distributed/auto_parallel/parallelizer_v2.py

```diff
@@ -149,6 +149,7 @@ class Parallelizer:
             paddle.enable_static()
         else:
             optimizer = copy.deepcopy(optimizer)
+            self._dist_context._serial_optimizer = optimizer
         with program_guard(main_program, startup_program):
             optimizer_ops = optimizer.apply_gradients(params_grads)
         self._completer.complete_update_annotation(main_program)
```
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py

```diff
@@ -363,11 +363,15 @@ class OptimizationTuner:
         profile_args = " ".join([
             "--rank",
-            str(self.rank), "--device_id",
-            str(self.device_id), "--ctx_filename", ctx_path,
+            str(self.rank),
+            "--device_id",
+            str(self.device_id),
+            "--ctx_filename",
+            ctx_path,
             "--profile_start_step",
-            str(self._config.profile_start_step), "--profile_end_step",
-            str(self._config.profile_end_step)
+            str(self._config.profile_start_step),
+            "--profile_end_step",
+            str(self._config.profile_end_step),
         ])
         cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args
         cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args)
```
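This hunk only reflows the argument list (one flag or value per line) and adds a trailing comma; the resulting command is unchanged. A stand-alone sketch of the composition, with hypothetical values in place of `self.rank`, `self.device_id`, and `ctx_path`:

```python
import shlex
import sys

profile_args = " ".join([
    "--rank", "0",                     # stand-in for str(self.rank)
    "--device_id", "0",                # stand-in for str(self.device_id)
    "--ctx_filename", "/tmp/ctx.pkl",  # stand-in for ctx_path
    "--profile_start_step", "10",
    "--profile_end_step", "30",
])
cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args
# The profiler subprocess is launched with the unbuffered interpreter.
cmd = [sys.executable, "-u"] + shlex.split(cmd_args)
print(cmd)
```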
python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py

```diff
@@ -31,6 +31,8 @@ from paddle.static import InputSpec
 from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.engine import Engine
+from paddle.optimizer.lr import CosineAnnealingDecay
+from paddle.fluid.dataloader.collate import default_collate_fn

 paddle.enable_static()
 global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
@@ -106,19 +108,18 @@ def train(fetch):
                          dropout_ratio=0.1,
                          initializer_range=0.02)
     loss = paddle.nn.CrossEntropyLoss()
-    optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
-                                                     beta1=0.9,
-                                                     beta2=0.999,
-                                                     epsilon=1e-08,
-                                                     grad_clip=None)
+    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001,
+                                                         T_max=10)
+    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
+                                      beta1=0.9,
+                                      beta2=0.999,
+                                      epsilon=1e-08,
+                                      grad_clip=None)

     inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
     labels_spec = InputSpec([batch_size], 'int64', 'label')

     dist_strategy = fleet.DistributedStrategy()
     dist_strategy.amp = False
     dist_strategy.pipeline = False
     dist_strategy.recompute = False
     dist_strategy.semi_auto = True
     fleet.init(is_collective=True, strategy=dist_strategy)
```
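The test now drives `paddle.optimizer.Adam` with a `CosineAnnealingDecay` schedule, which exercises the new `lr_scheduler.step()` path in `Engine.fit`. The schedule can be observed standalone:

```python
import paddle

# Same schedule the test constructs; Engine.fit steps it once per iteration
# via get_lr_scheduler(program).step().
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001,
                                                     T_max=10)
print(scheduler.get_lr())  # 1e-05 at step 0
scheduler.step()
print(scheduler.get_lr())  # slightly decayed after one step
```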
python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py

```diff
@@ -145,7 +145,7 @@ def main():
                     labels_spec=labels_spec,
                     strategy=dist_strategy)
     engine.prepare(optimizer=optimizer, loss=loss_func)
-    res = engine.fit(train_dataset, batch_size=None)
+    engine.fit(train_dataset, batch_size=None)

     dist_context = engine.dist_context
     block = engine.main_program.global_block()
```
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py

```diff
@@ -282,13 +282,16 @@ class TestMLPReshard(unittest.TestCase):
             if op.type == "gelu_grad":
                 op_need_check = op
                 break
         # print_program_with_dist_attr(dist_main_prog, dist_context)

         # grad op should have dist attr
         self.assertTrue(
             check_backward_dist_attr(dist_context, dist_main_prog,
                                      op_need_check))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+
     def test_mlp_pp(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "pp"
@@ -305,29 +308,35 @@ class TestMLPReshard(unittest.TestCase):
         rank_id = 1
         dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        for key in list(_g_process_group_map.keys()):
-            del _g_process_group_map[key]
-        _g_process_group_map[0] = ProcessGroup(0, [])
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()

         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

         # parameter initialization of every rank should be different in the pipeline scene
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+
     def test_mlp_pp_diff_process_mesh(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "pp"
+        global _global_process_mesh
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
+        global PP_MESH_0
+        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
+        global PP_MESH_1
+        PP_MESH_1 = auto.ProcessMesh(mesh=[1])

         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 1
         dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id, True)
-        for key in list(_g_process_group_map.keys()):
-            del _g_process_group_map[key]
-        _g_process_group_map[0] = ProcessGroup(0, [])
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()
@@ -335,6 +344,10 @@ class TestMLPReshard(unittest.TestCase):
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+
     def test_mlp_dp(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"
@@ -350,12 +363,16 @@ class TestMLPReshard(unittest.TestCase):
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()

         # send and recv should not exist in dp scene.
         self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))

         # all parameters should be initialized in dp scene
         self.assertTrue(check_initialization_for_dp(dist_startup_prog))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+

 if __name__ == "__main__":
     unittest.main()
```
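Each test now resets the module-level process-group registry the same way, so that groups created by one case do not leak into the next. The repeated three lines could be factored into a helper; a sketch, assuming it lives in the same test module where `_g_process_group_map` and `ProcessGroup` are already imported (the helper name is hypothetical):

```python
def reset_process_group_map():
    # Restore the global registry to its initial state: empty except for
    # the default group 0 with no ranks.
    _g_process_group_map.clear()
    _g_process_group_map[0] = ProcessGroup(0, [])
```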