Commit a7152613 authored by Qiao Longfei

Merge branch 'fix-cpu-broadcast' of ssh://github.com/jacquesqiao/Paddle into add-communicator

@@ -133,12 +133,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   void AppendMultiDevPass(const BuildStrategy &strategy) {
     ir::Pass *multi_devices_pass;
     if (strategy_.is_distribution_) {
+      VLOG(3) << "dist train mode";
       multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
     } else {
       if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+        VLOG(3) << "allreduce mode";
         multi_devices_pass =
             AppendPass("allreduce_mode_multi_devices_pass").get();
       } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
+        VLOG(3) << "reduce mode";
         multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
       } else {
         PADDLE_THROW("Unknown reduce strategy.");
......
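
Note (not part of the diff): which multi-devices pass AppendMultiDevPass appends is controlled from the Python side through BuildStrategy. A minimal sketch, assuming the Fluid 1.x API; running with GLOG_v=3 surfaces the VLOG(3) mode messages added above.

    # Sketch only: map the Python-side reduce strategy to the C++ pass selection above.
    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    # kAllReduce -> "allreduce_mode_multi_devices_pass"
    build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
    # kReduce -> "reduce_mode_multi_devices_pass"
    # build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
    # is_distribution_ (set from compiler.py below) -> "dist_multi_devices_pass"
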
@@ -925,9 +925,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
 }
 
 void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
-  if (need_broadcast_var_ ||
-      (UseGPU() &&
-       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) {
+  if (UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
     if (strategy_.fuse_broadcast_op_) {
       CreateFusedBroadcastOp(result, bcast_var_name_set_);
     } else {
......
@@ -19,6 +19,7 @@ import sys
 from .. import compat as cpt
 from . import core
+from . import framework
 
 __all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy']
@@ -34,6 +35,15 @@ def _place_obj(place):
     return p
 
 
+def _is_pserver_mode(main_program):
+    main = main_program if main_program \
+        else framework.default_main_program()
+    for op in main.global_block().ops:
+        if op.type in ["send", "recv"]:
+            return True
+    return False
+
+
 class CompiledProgram(object):
     """
     Compiles a Program for execution.
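
Note (not part of the diff): the "send"/"recv" ops that _is_pserver_mode scans for are the ones the distribute transpiler inserts into a trainer program in parameter-server mode. A hedged sketch, assuming the Fluid 1.x DistributeTranspiler API and an already-built trainer program with an optimizer:

    import paddle.fluid as fluid

    t = fluid.DistributeTranspiler()
    t.transpile(trainer_id=0,
                pservers="127.0.0.1:6170,127.0.0.1:6171",  # example endpoints
                trainers=2)
    trainer_prog = t.get_trainer_program()
    # trainer_prog now contains "send"/"recv" ops, so _is_pserver_mode(trainer_prog)
    # returns True and is_distribution is switched on automatically.
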
@@ -110,6 +120,8 @@ class CompiledProgram(object):
             self._exec_strategy = ExecutionStrategy()
         if self._build_strategy is None:
             self._build_strategy = BuildStrategy()
+        self._build_strategy.is_distribution = _is_pserver_mode(
+            self._program) or self._build_strategy.num_trainers > 1
         return self
 
     def with_inference_optimize(self, config):
......
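
Usage note (not part of the diff): with this change, is_distribution no longer needs to be set by hand; it is derived when the program is compiled for data-parallel execution. A rough sketch, assuming the Fluid 1.x CompiledProgram API and a `loss` variable defined elsewhere:

    import paddle.fluid as fluid

    compiled_prog = fluid.CompiledProgram(
        fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
    # Inside with_data_parallel, _is_pserver_mode(self._program) or
    # num_trainers > 1 now sets build_strategy.is_distribution, which makes
    # AppendMultiDevPass pick "dist_multi_devices_pass" on the C++ side.
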