提交 2b7931d5 编写于 作者: Q Qiao Longfei

refine code test=develop

上级 3f9263f6
...@@ -133,15 +133,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -133,15 +133,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendMultiDevPass(const BuildStrategy &strategy) { void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass; ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) { if (strategy_.is_distribution_) {
VLOG(3) << "multi device dist train mode"; VLOG(3) << "multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else { } else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(3) << "multi device allreduce mode"; VLOG(3) << "multi devices collective mode with allreduce";
multi_devices_pass = multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get(); AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(3) << "multi device reduce mode"; VLOG(3) << "multi deivces collective mode with reduce";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else { } else {
PADDLE_THROW("Unknown reduce strategy."); PADDLE_THROW("Unknown reduce strategy.");
......
...@@ -35,15 +35,6 @@ def _place_obj(place): ...@@ -35,15 +35,6 @@ def _place_obj(place):
return p return p
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class CompiledProgram(object): class CompiledProgram(object):
""" """
Compiles a Program for execution. Compiles a Program for execution.
...@@ -120,7 +111,8 @@ class CompiledProgram(object): ...@@ -120,7 +111,8 @@ class CompiledProgram(object):
self._exec_strategy = ExecutionStrategy() self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None: if self._build_strategy is None:
self._build_strategy = BuildStrategy() self._build_strategy = BuildStrategy()
self._build_strategy.is_distribution = _is_pserver_mode(self._program) self._build_strategy.is_distribution = framework.is_pserver_mode(
self._program)
return self return self
def with_inference_optimize(self, config): def with_inference_optimize(self, config):
......
...@@ -85,6 +85,15 @@ def _current_expected_place(): ...@@ -85,6 +85,15 @@ def _current_expected_place():
return _imperative_current_expected_place_ return _imperative_current_expected_place_
def is_pserver_mode(main_program):
main = main_program if main_program \
else default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class NameScope(object): class NameScope(object):
def __init__(self, name="", parent=None): def __init__(self, name="", parent=None):
self._children = dict() self._children = dict()
......
...@@ -29,15 +29,6 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy ...@@ -29,15 +29,6 @@ ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy BuildStrategy = core.ParallelExecutor.BuildStrategy
def _is_pserver_mode(main_program):
main = main_program if main_program \
else framework.default_main_program()
for op in main.global_block().ops:
if op.type in ["send", "recv"]:
return True
return False
class ParallelExecutor(object): class ParallelExecutor(object):
""" """
ParallelExecutor is designed for data parallelism, which focuses on distributing ParallelExecutor is designed for data parallelism, which focuses on distributing
...@@ -140,7 +131,7 @@ class ParallelExecutor(object): ...@@ -140,7 +131,7 @@ class ParallelExecutor(object):
# FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, # FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
# num_trainers is 1, so the current fields of build_strategy doesn't tell if # num_trainers is 1, so the current fields of build_strategy doesn't tell if
# it's distributed model. # it's distributed model.
build_strategy.is_distribution = _is_pserver_mode( build_strategy.is_distribution = framework.is_pserver_mode(
main_program) or num_trainers > 1 main_program) or num_trainers > 1
# step4: get main_program, scope, local_scopes # step4: get main_program, scope, local_scopes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册