From 3bccc1e6e275412f30baf5a0c5698eb307f90252 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 22 Feb 2019 10:39:42 +0800 Subject: [PATCH] optimize broadcast logic test=develop --- .../details/multi_devices_graph_pass.cc | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index e0246740dd..c0fb3ee833 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -925,18 +925,20 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, } void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { - // only GPU reduce mode need to broadcast parameters to each device. - if (UseGPU()) { - if (need_broadcast_var_ || + // broad cast received parameters when training in parameter server mode. + if (need_broadcast_var_) { + // cpu reduce mode did not need to broadcast received parameters. + if (!UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { - if (strategy_.fuse_broadcast_op_) { - CreateFusedBroadcastOp(result, bcast_var_name_set_); - } else { - for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { - auto &to_bcast_set = bcast_var_name_set_[dev_id]; - for (auto &bcast_name : to_bcast_set) { - CreateBroadcastOp(result, bcast_name, dev_id); - } + return; + } + if (strategy_.fuse_broadcast_op_) { + CreateFusedBroadcastOp(result, bcast_var_name_set_); + } else { + for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set_[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(result, bcast_name, dev_id); } } } -- GitLab