From 62f1248ff5bf7aafe57bcc4be0068529330604cb Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 21 Feb 2019 13:51:53 +0800 Subject: [PATCH] fix use gpu test=develop --- .../details/multi_devices_graph_pass.cc | 20 +++++++++++-------- .../details/multi_devices_graph_pass.h | 1 + 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 24977aabdac..e0246740dd7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -731,6 +731,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, } } insert_op = true; + need_broadcast_var_ = true; } else if (OpHaveRole(*node, OpRole::kDist)) { int op_dev_id = CreateDistTrainOp(result, node); if (node->Op()->Type() == "concat") { @@ -925,14 +926,17 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { // only GPU reduce mode need to broadcast parameters to each device. - if (UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { - if (strategy_.fuse_broadcast_op_) { - CreateFusedBroadcastOp(result, bcast_var_name_set_); - } else { - for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { - auto &to_bcast_set = bcast_var_name_set_[dev_id]; - for (auto &bcast_name : to_bcast_set) { - CreateBroadcastOp(result, bcast_name, dev_id); + if (UseGPU()) { + if (need_broadcast_var_ || + strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { + if (strategy_.fuse_broadcast_op_) { + CreateFusedBroadcastOp(result, bcast_var_name_set_); + } else { + for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set_[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(result, bcast_name, dev_id); + } } } } diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 21f85dc8286..6d4386538ea 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -174,6 +174,7 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; mutable std::vector> bcast_var_name_set_; + mutable bool need_broadcast_var_{false}; }; std::unordered_set &MultiDevSSAGraphBuilder(); -- GitLab