未验证 提交 f2db475a 编写于 作者: C chengduo 提交者: GitHub

update ParallelExecutor (#17204)

test=develop
上级 950aec55
......@@ -96,19 +96,31 @@ class ParallelExecutor(object):
 if build_strategy is None:
 build_strategy = BuildStrategy()
-build_strategy.num_trainers = num_trainers
-build_strategy.trainer_id = trainer_id
+# TODO(paddle-dev): trainer_id and num_trainers should be removed from parameter list.
+if num_trainers != 1 and build_strategy.num_trainers != num_trainers:
+sys.stderr.write(
+'The value of build_strategy.num_trainers[%d] is overwritten '
+'by the passed num_trainers[%d].\n' %
+(build_strategy.num_trainers, num_trainers))
+build_strategy.num_trainers = num_trainers
+if trainer_id != 0 and build_strategy.trainer_id != trainer_id:
+sys.stderr.write(
+'The value of build_strategy.trainer_id[%d] is overwritten '
+'by the passed trainer_id[%d].\n' %
+(build_strategy.trainer_id, trainer_id))
+build_strategy.trainer_id = trainer_id
 self._places = framework.cuda_places(
 ) if use_cuda else framework.cpu_places()
 self._scope = scope if scope is not None else executor.global_scope()
 if main_program is not None and main_program._enable_dgc:
-assert num_trainers > 1, "dgc is not useful when num_trainers <= 1"
+assert build_strategy.num_trainers > 1, "dgc is not useful when num_trainers <= 1"
 assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "dgc \
 only used for allreduce"
-assert num_trainers * len(
+assert build_strategy.num_trainers * len(
 self._places) > 1, "dgc is not useful for single card training"
 assert use_cuda, "dgc only used under cuda"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册