Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3649099f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3649099f
编写于
8月 15, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
8月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] add collate_fn for dist_loader (#45053)
* add collate_fn * fix number of inputs
上级
8788513b
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
151 addition
and
68 deletion
+151
-68
python/paddle/distributed/auto_parallel/completion.py
python/paddle/distributed/auto_parallel/completion.py
+5
-1
python/paddle/distributed/auto_parallel/dist_loader.py
python/paddle/distributed/auto_parallel/dist_loader.py
+61
-30
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+40
-15
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+1
-0
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py
...dle/distributed/auto_parallel/tuner/optimization_tuner.py
+8
-4
python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
.../paddle/fluid/tests/unittests/auto_parallel/engine_api.py
+9
-8
python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py
...le/fluid/tests/unittests/auto_parallel/high_order_grad.py
+1
-1
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
...addle/fluid/tests/unittests/test_auto_parallel_reshard.py
+26
-9
未找到文件。
python/paddle/distributed/auto_parallel/completion.py
浏览文件 @
3649099f
...
...
@@ -1300,6 +1300,10 @@ class Completer:
def
complete_update_annotation
(
self
,
serial_main_program
):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
# Copy the dist tensors and dist ops annotated by users from the default context
# global mesh
from
paddle.distributed.auto_parallel.process_group
import
get_world_process_group
world_ranks
=
get_world_process_group
().
ranks
# Notice: serial_main_program is actually a dist_main_program of current rank,
# and must be passed into this function.
...
...
@@ -1371,7 +1375,7 @@ class Completer:
if
not
learning_rate_completed
:
learning_rate_completed
=
True
var_dist_attr
=
TensorDistributedAttribute
()
var_dist_attr
.
process_mesh
=
ref_process_mesh
var_dist_attr
.
process_mesh
=
world_ranks
var_dist_attr
.
dims_mapping
=
[
-
1
]
self
.
_dist_context
.
set_tensor_dist_attr_for_program
(
learning_var
,
var_dist_attr
)
...
...
python/paddle/distributed/auto_parallel/dist_loader.py
浏览文件 @
3649099f
...
...
@@ -17,7 +17,8 @@ import numpy as np
import
paddle
from
.utils
import
to_list
from
paddle.fluid.layers.utils
import
flatten
from
paddle.io
import
DataLoader
,
DistributedBatchSampler
from
paddle.io
import
DataLoader
,
BatchSampler
,
IterableDataset
from
paddle.fluid.dataloader.dataloader_iter
import
_DatasetKind
,
default_collate_fn
,
default_convert_fn
class
DistributedDataLoader
(
metaclass
=
abc
.
ABCMeta
):
...
...
@@ -29,14 +30,32 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
data_parallel_world_size
=
None
,
data_parallel_rank
=
None
,
drop_last
=
False
):
if
isinstance
(
dataset
,
IterableDataset
):
raise
TypeError
(
"IterableDataset is not supported."
)
else
:
self
.
dataset_kind
=
_DatasetKind
.
MAP
self
.
dataset
=
dataset
self
.
batch_size
=
batch_size
self
.
epochs
=
epochs
self
.
data_parallel_world_size
=
data_parallel_world_size
self
.
data_parallel_rank
=
data_parallel_rank
self
.
drop_lost
=
drop_last
if
data_parallel_world_size
is
not
None
and
batch_size
is
not
None
:
assert
batch_size
%
data_parallel_world_size
==
0
if
batch_size
is
None
:
self
.
batch_size
=
None
self
.
batch_sampler
=
None
else
:
if
data_parallel_world_size
is
not
None
:
assert
batch_size
%
data_parallel_world_size
==
0
,
\
"'batch_size' must be divisible by data parallel size"
self
.
batch_size
=
batch_size
self
.
batch_sampler
=
BatchSampler
(
dataset
,
batch_size
=
batch_size
,
shuffle
=
False
,
drop_last
=
drop_last
)
self
.
auto_collate_batch
=
self
.
batch_sampler
is
not
None
self
.
sampler_iter
=
iter
(
self
.
index_sampler
)
self
.
dp_world_size
=
1
if
data_parallel_world_size
is
None
else
data_parallel_world_size
self
.
dp_rank
=
0
if
data_parallel_rank
is
None
else
data_parallel_rank
@
abc
.
abstractmethod
def
__iter__
(
self
):
...
...
@@ -46,6 +65,16 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
def
__next__
(
self
):
raise
NotImplementedError
@
property
def
index_sampler
(
self
):
if
self
.
auto_collate_batch
:
return
self
.
batch_sampler
else
:
if
self
.
dataset_kind
==
_DatasetKind
.
MAP
:
return
list
(
range
(
len
(
self
.
dataset
)))
else
:
raise
TypeError
(
"Only support datasets in map-style."
)
class
NonIterableGeneratorLoader
(
DistributedDataLoader
):
...
...
@@ -56,21 +85,29 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
batch_size
=
1
,
epochs
=
1
,
steps_per_epoch
=
None
,
collate_fn
=
None
,
data_parallel_world_size
=
None
,
data_parallel_rank
=
None
,
drop_last
=
False
):
self
.
feed_list
=
feed_list
self
.
places
=
places
self
.
steps_per_epoch
=
steps_per_epoch
self
.
dp_world_size
=
1
if
data_parallel_world_size
is
None
else
data_parallel_world_size
self
.
dp_rank
=
0
if
data_parallel_rank
is
None
else
data_parallel_rank
super
(
NonIterableGeneratorLoader
,
self
).
__init__
(
dataset
,
batch_size
,
epochs
,
data_parallel_world_size
,
data_parallel_rank
,
drop_last
)
self
.
_inner_dataloader
=
self
.
_create_inner_dataloader
()
if
self
.
auto_collate_batch
:
self
.
collate_fn
=
collate_fn
or
default_collate_fn
else
:
self
.
collate_fn
=
collate_fn
or
default_convert_fn
self
.
dataset_fetcher
=
_DatasetKind
.
create_fetcher
(
self
.
dataset_kind
,
self
.
dataset
,
self
.
auto_collate_batch
,
self
.
collate_fn
,
self
.
drop_lost
)
self
.
_steps
=
self
.
_infer_steps
()
self
.
_inner_dataloader
=
self
.
_create_inner_dataloader
()
def
__iter__
(
self
):
self
.
_cur_step
=
0
...
...
@@ -101,31 +138,25 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
def
_create_inner_dataloader
(
self
):
def
sample_data_generator
():
batch_data
=
None
for
step
,
data
in
enumerate
(
self
.
dataset
):
data
=
flatten
(
data
)
if
batch_data
is
None
:
batch_data
=
[[]
for
i
in
range
(
len
(
data
))]
for
idx
in
range
(
len
(
data
)):
batch_data
[
idx
].
append
(
data
[
idx
])
if
(
step
+
1
)
%
self
.
batch_size
==
0
:
partial_data
=
[]
for
d
in
batch_data
:
array
=
np
.
array
(
d
)
partial_data
.
append
(
np
.
split
(
array
,
self
.
dp_world_size
)[
self
.
dp_rank
])
yield
partial_data
[:
len
(
self
.
feed_list
)]
batch_data
=
None
for
indices
in
self
.
sampler_iter
:
assert
len
(
indices
)
%
self
.
dp_world_size
==
0
,
\
"Please set batch_size to be divisible by data parallel size"
n
=
len
(
indices
)
//
self
.
dp_world_size
cur_indices
=
[
indices
[
i
:
i
+
n
]
for
i
in
range
(
0
,
len
(
indices
),
n
)
]
batch
=
self
.
dataset_fetcher
.
fetch
(
cur_indices
[
self
.
dp_rank
])
yield
batch
[:
len
(
self
.
feed_list
)]
def
batch_data_generator
():
for
data
in
self
.
dataset
:
data
=
flatten
(
data
)
for
indices
in
self
.
sampler_iter
:
partial_data
=
[]
for
d
in
data
:
assert
d
.
shape
[
0
]
%
self
.
dp_world_size
==
0
,
\
"Please padding dataset with data parallel size"
batch
=
self
.
dataset_fetcher
.
fetch
(
indices
)
for
data
in
batch
:
assert
data
.
shape
[
0
]
%
self
.
dp_world_size
==
0
,
\
"Please padding dataset's batch_size to be divisible by data parallel size"
partial_data
.
append
(
np
.
split
(
d
,
self
.
dp_world_size
)[
self
.
dp_rank
])
np
.
split
(
d
ata
,
self
.
dp_world_size
)[
self
.
dp_rank
])
yield
partial_data
[:
len
(
self
.
feed_list
)]
dataloader
=
paddle
.
fluid
.
io
.
DataLoader
.
from_generator
(
...
...
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
3649099f
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
time
import
copy
import
logging
from
collections
import
defaultdict
...
...
@@ -306,6 +307,7 @@ class Engine:
mode
].
dist_startup_programs
self
.
_feed_vars
[
mode
]
=
self
.
_dist_contexts
[
mode
].
serial_feed_vars
self
.
_fetch_vars
[
mode
]
=
self
.
_dist_contexts
[
mode
].
serial_fetch_vars
self
.
_optimizer
=
self
.
_dist_contexts
[
mode
].
serial_optimizer
if
self
.
_nranks
>
1
:
# Traverse different rank programs and traverse each op of them,
...
...
@@ -403,7 +405,8 @@ class Engine:
epochs
=
1
,
fetches
=
None
,
steps_per_epoch
=
None
,
use_program_cache
=
False
,
collate_fn
=
None
,
use_cache
=
False
,
return_numpy
=
True
):
# TODO: callbacks
# TODO: evaluate after training
...
...
@@ -417,18 +420,24 @@ class Engine:
assert
self
.
mode
in
self
.
_dist_main_progs
,
\
"train model is not ready, please call `engine.prepare()` first."
train_dataloader
=
self
.
_create_dataloader
(
train_data
,
batch_size
,
epochs
,
steps_per_epoch
)
epochs
,
steps_per_epoch
,
collate_fn
)
usr_fetch
=
self
.
_validate_fetches
(
fetches
)
fetch_loss
=
self
.
_validate_fetches
(
self
.
fetch_vars
[
"loss"
])
fetch_list
,
fetch_map
=
self
.
_fetch_map
(
fetch_loss
,
usr_fetch
)
lr_scheduler
=
self
.
get_lr_scheduler
(
self
.
main_program
)
for
epoch
in
range
(
epochs
):
train_logs
=
{
"epoch"
:
epoch
}
for
step
,
_
in
enumerate
(
train_dataloader
):
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
fetch_list
=
fetch_list
,
use_program_cache
=
use_
program_
cache
,
use_program_cache
=
use_cache
,
return_numpy
=
return_numpy
)
if
lr_scheduler
is
not
None
:
lr_scheduler
.
step
()
train_logs
[
"lr"
]
=
self
.
_optimizer
.
get_lr
()
train_logs
[
"step"
]
=
step
# inner fetches
if
fetch_loss
:
...
...
@@ -444,7 +453,8 @@ class Engine:
eval_data
,
batch_size
=
1
,
fetches
=
None
,
use_program_cache
=
False
,
collate_fn
=
None
,
use_cache
=
False
,
return_numpy
=
True
):
self
.
mode
=
'eval'
if
not
self
.
_mode_init_states
[
self
.
mode
]:
...
...
@@ -452,7 +462,9 @@ class Engine:
assert
self
.
mode
in
self
.
_dist_main_progs
,
\
"eval model is not ready, please call `engine.prepare()` first."
eval_dataloader
=
self
.
_create_dataloader
(
eval_data
,
batch_size
)
eval_dataloader
=
self
.
_create_dataloader
(
eval_data
,
batch_size
,
collate_fn
=
collate_fn
)
usr_fetch
=
self
.
_validate_fetches
(
fetches
)
fetch_loss
=
self
.
_validate_fetches
(
self
.
fetch_vars
[
"loss"
])
...
...
@@ -464,7 +476,7 @@ class Engine:
eval_logs
=
{
"step"
:
step
}
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
fetch_list
=
fetch_list
,
use_program_cache
=
use_
program_
cache
,
use_program_cache
=
use_cache
,
return_numpy
=
return_numpy
)
# inner fetches
if
fetch_loss
:
...
...
@@ -489,7 +501,8 @@ class Engine:
test_data
,
batch_size
=
1
,
fetches
=
None
,
use_program_cache
=
False
,
collate_fn
=
None
,
use_cache
=
False
,
return_numpy
=
True
):
self
.
mode
=
'predict'
if
not
self
.
_mode_init_states
[
self
.
mode
]:
...
...
@@ -497,7 +510,9 @@ class Engine:
assert
self
.
mode
in
self
.
_dist_main_progs
,
\
"predict model is not ready, please call `engine.prepare()` first."
test_dataloader
=
self
.
_create_dataloader
(
test_data
,
batch_size
)
test_dataloader
=
self
.
_create_dataloader
(
test_data
,
batch_size
,
collate_fn
=
collate_fn
)
usr_fetch
=
self
.
_validate_fetches
(
fetches
)
fetch_outputs
=
self
.
_validate_fetches
(
self
.
fetch_vars
[
"outputs"
])
...
...
@@ -508,7 +523,7 @@ class Engine:
predict_logs
=
{
"step"
:
step
}
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
fetch_list
=
fetch_list
,
use_program_cache
=
use_
program_
cache
,
use_program_cache
=
use_cache
,
return_numpy
=
return_numpy
)
outputs
.
append
(
outs
[:
len
(
fetch_outputs
)])
for
i
,
out
in
enumerate
(
outs
):
...
...
@@ -521,7 +536,8 @@ class Engine:
dataset
,
batch_size
,
epochs
=
1
,
steps_per_epoch
=
None
):
steps_per_epoch
=
None
,
collate_fn
=
None
):
dist_main_prog
=
self
.
_dist_main_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_startup_prog
=
self
.
_dist_startup_progs
[
self
.
mode
][
self
.
_cur_rank
]
dist_context
=
self
.
_dist_contexts
[
self
.
mode
]
...
...
@@ -554,6 +570,7 @@ class Engine:
batch_size
,
epochs
,
steps_per_epoch
,
collate_fn
,
data_parallel_world_size
=
self
.
_input_split_size
,
data_parallel_rank
=
self
.
_input_split_rank
)
...
...
@@ -645,12 +662,11 @@ class Engine:
config
=
self
.
strategy
.
recompute_configs
# extract ckpts by specific model
self
.
model
if
isinstance
(
self
.
model
,
paddle
.
nn
.
Layer
):
if
hasattr
(
self
.
model
,
"
model
"
)
and
self
.
model
.
model
.
__class__
.
__name__
==
'GPTForPretraining'
:
exact_ckpts
=
self
.
model
.
model
.
gpt
.
checkpoints
self
.
model
,
"
gpt
"
)
and
self
.
model
.
__class__
.
__name__
==
'GPTForPretraining'
:
exact_ckpts
=
self
.
model
.
gpt
.
checkpoints
else
:
exact_ckpts
=
config
[
"checkpoints"
]
...
...
@@ -659,7 +675,7 @@ class Engine:
config
[
"checkpoints"
]
=
exact_ckpts
[:]
self
.
strategy
.
recompute_configs
=
config
logs
=
{
'Model Class'
:
self
.
model
.
model
.
__class__
.
__name__
,
'Model Class'
:
self
.
model
.
__class__
.
__name__
,
'Applied Recompute ckpts'
:
exact_ckpts
}
self
.
_logger
.
info
(
logs
)
...
...
@@ -699,6 +715,15 @@ class Engine:
self
.
_saver
.
load
(
path
,
dist_main_prog
,
dist_context
,
strict
,
load_optimizer
)
@
staticmethod
def
get_lr_scheduler
(
program
):
lr_sheduler
=
None
if
hasattr
(
program
,
'lr_sheduler'
):
from
paddle.optimizer.lr
import
LRScheduler
lr_sheduler
=
program
.
lr_sheduler
assert
isinstance
(
lr_sheduler
,
LRScheduler
),
"must be LRScheduler"
return
lr_sheduler
@
property
def
mode
(
self
):
return
self
.
_mode
...
...
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
3649099f
...
...
@@ -149,6 +149,7 @@ class Parallelizer:
paddle
.
enable_static
()
else
:
optimizer
=
copy
.
deepcopy
(
optimizer
)
self
.
_dist_context
.
_serial_optimizer
=
optimizer
with
program_guard
(
main_program
,
startup_program
):
optimizer_ops
=
optimizer
.
apply_gradients
(
params_grads
)
self
.
_completer
.
complete_update_annotation
(
main_program
)
...
...
python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py
浏览文件 @
3649099f
...
...
@@ -363,11 +363,15 @@ class OptimizationTuner:
profile_args
=
" "
.
join
([
"--rank"
,
str
(
self
.
rank
),
"--device_id"
,
str
(
self
.
device_id
),
"--ctx_filename"
,
ctx_path
,
str
(
self
.
rank
),
"--device_id"
,
str
(
self
.
device_id
),
"--ctx_filename"
,
ctx_path
,
"--profile_start_step"
,
str
(
self
.
_config
.
profile_start_step
),
"--profile_end_step"
,
str
(
self
.
_config
.
profile_end_step
)
str
(
self
.
_config
.
profile_start_step
),
"--profile_end_step"
,
str
(
self
.
_config
.
profile_end_step
),
])
cmd_args
=
"-m paddle.distributed.auto_parallel.tuner.profiler"
+
" "
+
profile_args
cmd
=
[
sys
.
executable
,
"-u"
]
+
coverage_args
+
shlex
.
split
(
cmd_args
)
...
...
python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
浏览文件 @
3649099f
...
...
@@ -31,6 +31,8 @@ from paddle.static import InputSpec
from
paddle.distributed
import
fleet
import
paddle.distributed.auto_parallel
as
auto
from
paddle.distributed.auto_parallel.engine
import
Engine
from
paddle.optimizer.lr
import
CosineAnnealingDecay
from
paddle.fluid.dataloader.collate
import
default_collate_fn
paddle
.
enable_static
()
global_process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[
0
,
1
])
...
...
@@ -106,19 +108,18 @@ def train(fetch):
dropout_ratio
=
0.1
,
initializer_range
=
0.02
)
loss
=
paddle
.
nn
.
CrossEntropyLoss
()
optimizer
=
paddle
.
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.00001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-08
,
grad_clip
=
None
)
scheduler
=
paddle
.
optimizer
.
lr
.
CosineAnnealingDecay
(
learning_rate
=
0.00001
,
T_max
=
10
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
scheduler
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-08
,
grad_clip
=
None
)
inputs_spec
=
InputSpec
([
batch_size
,
hidden_size
],
'float32'
,
'x'
)
labels_spec
=
InputSpec
([
batch_size
],
'int64'
,
'label'
)
dist_strategy
=
fleet
.
DistributedStrategy
()
dist_strategy
.
amp
=
False
dist_strategy
.
pipeline
=
False
dist_strategy
.
recompute
=
False
dist_strategy
.
semi_auto
=
True
fleet
.
init
(
is_collective
=
True
,
strategy
=
dist_strategy
)
...
...
python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py
浏览文件 @
3649099f
...
...
@@ -145,7 +145,7 @@ def main():
labels_spec
=
labels_spec
,
strategy
=
dist_strategy
)
engine
.
prepare
(
optimizer
=
optimizer
,
loss
=
loss_func
)
res
=
engine
.
fit
(
train_dataset
,
batch_size
=
None
)
engine
.
fit
(
train_dataset
,
batch_size
=
None
)
dist_context
=
engine
.
dist_context
block
=
engine
.
main_program
.
global_block
()
...
...
python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
浏览文件 @
3649099f
...
...
@@ -282,13 +282,16 @@ class TestMLPReshard(unittest.TestCase):
if
op
.
type
==
"gelu_grad"
:
op_need_check
=
op
break
# print_program_with_dist_attr(dist_main_prog, dist_context)
# grad op should have dist attr
self
.
assertTrue
(
check_backward_dist_attr
(
dist_context
,
dist_main_prog
,
op_need_check
))
# clear _g_process_group_map
_g_process_group_map
.
clear
()
_g_process_group_map
[
0
]
=
ProcessGroup
(
0
,
[])
def
test_mlp_pp
(
self
):
global
_global_parallel_strategy
_global_parallel_strategy
=
"pp"
...
...
@@ -305,29 +308,35 @@ class TestMLPReshard(unittest.TestCase):
rank_id
=
1
dist_main_prog
,
dist_startup_prog
,
dist_params_grads
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
)
for
key
in
list
(
_g_process_group_map
.
keys
()):
del
_g_process_group_map
[
key
]
_g_process_group_map
[
0
]
=
ProcessGroup
(
0
,
[])
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
# check send and recv result
self
.
assertTrue
(
check_send_recv_result
(
dist_main_prog
,
rank_id
))
# parameter initialization of every rank should be different in the pipeline scene
self
.
assertTrue
(
check_initialization
(
dist_startup_prog
,
rank_id
))
# clear _g_process_group_map
_g_process_group_map
.
clear
()
_g_process_group_map
[
0
]
=
ProcessGroup
(
0
,
[])
def
test_mlp_pp_diff_process_mesh
(
self
):
global
_global_parallel_strategy
_global_parallel_strategy
=
"pp"
global
_global_process_mesh
_global_process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[
0
,
1
])
global
PP_MESH_0
PP_MESH_0
=
auto
.
ProcessMesh
(
mesh
=
[
0
])
global
PP_MESH_1
PP_MESH_1
=
auto
.
ProcessMesh
(
mesh
=
[
1
])
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
dist_context
=
DistributedContext
()
rank_id
=
1
dist_main_prog
,
dist_startup_prog
,
dist_params_grads
=
get_dist_prog
(
train_program
,
startup_program
,
dist_context
,
rank_id
,
True
)
for
key
in
list
(
_g_process_group_map
.
keys
()):
del
_g_process_group_map
[
key
]
_g_process_group_map
[
0
]
=
ProcessGroup
(
0
,
[])
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
...
...
@@ -335,6 +344,10 @@ class TestMLPReshard(unittest.TestCase):
self
.
assertTrue
(
check_send_recv_result
(
dist_main_prog
,
rank_id
))
self
.
assertTrue
(
check_initialization
(
dist_startup_prog
,
rank_id
))
# clear _g_process_group_map
_g_process_group_map
.
clear
()
_g_process_group_map
[
0
]
=
ProcessGroup
(
0
,
[])
def
test_mlp_dp
(
self
):
global
_global_parallel_strategy
_global_parallel_strategy
=
"dp"
...
...
@@ -350,12 +363,16 @@ class TestMLPReshard(unittest.TestCase):
resharder
=
Resharder
(
dist_main_prog
,
dist_startup_prog
,
rank_id
,
dist_context
,
dist_params_grads
)
resharder
.
reshard
()
# send and recv should not exist in dp scene.
self
.
assertFalse
(
check_send_recv_result
(
dist_main_prog
,
rank_id
))
# all parameters should be initialized in dp scene
self
.
assertTrue
(
check_initialization_for_dp
(
dist_startup_prog
))
# clear _g_process_group_map
_g_process_group_map
.
clear
()
_g_process_group_map
[
0
]
=
ProcessGroup
(
0
,
[])
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录