提交 2b7931d5 编写于 作者: Q Qiao Longfei

refine code test=develop

上级 3f9263f6
......@@ -133,15 +133,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) {
VLOG(3) << "multi device dist train mode";
VLOG(3) << "multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(3) << "multi device allreduce mode";
VLOG(3) << "multi devices collective mode with allreduce";
multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(3) << "multi device reduce mode";
VLOG(3) << "multi deivces collective mode with reduce";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");
......
......@@ -35,15 +35,6 @@ def _place_obj(place):
return p
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class CompiledProgram(object):
"""
Compiles a Program for execution.
......@@ -120,7 +111,8 @@ class CompiledProgram(object):
self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None:
self._build_strategy = BuildStrategy()
self._build_strategy.is_distribution = _is_pserver_mode(self._program)
self._build_strategy.is_distribution = framework.is_pserver_mode(
self._program)
return self
def with_inference_optimize(self, config):
......
......@@ -85,6 +85,15 @@ def _current_expected_place():
return _imperative_current_expected_place_
def is_pserver_mode(main_program):
main = main_program if main_program \
else default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class NameScope(object):
def __init__(self, name="", parent=None):
self._children = dict()
......
......@@ -29,15 +29,6 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class ParallelExecutor(object):
"""
ParallelExecutor is designed for data parallelism, which focuses on distributing
......@@ -140,7 +131,7 @@ class ParallelExecutor(object):
# FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
# num_trainers is 1, so the current fields of build_strategy doesn't tell if
# it's distributed model.
build_strategy.is_distribution = _is_pserver_mode(
build_strategy.is_distribution = framework.is_pserver_mode(
main_program) or num_trainers > 1
# step4: get main_program, scope, local_scopes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册