From 1e1b6622fdce1b704c7753e2c16656bdc97ac24e Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 13 Sep 2018 10:44:39 +0800 Subject: [PATCH] update by comment --- paddle/fluid/framework/details/all_reduce_op_handle.cc | 6 +----- paddle/fluid/framework/details/broadcast_op_handle.cc | 6 +----- .../fluid/framework/details/data_balance_op_handle.cc | 6 ------ .../framework/details/multi_devices_graph_pass.cc | 10 +++------- 4 files changed, 5 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 8450d8eb8b..7c5f5bd80a 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -46,11 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, #endif void AllReduceOpHandle::RunImpl() { - if (dev_ctxes_.size() > 0UL) { - platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); - } else { - platform::RecordEvent record_event(Name(), nullptr); - } + platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); if (NoDummyInputSize() == 1) { return; // No need to all reduce when GPU count = 1; diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 35962ade99..4fdab5cd94 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -22,11 +22,7 @@ namespace framework { namespace details { void BroadcastOpHandle::RunImpl() { - if (dev_ctxes_.size() > 0UL) { - platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); - } else { - platform::RecordEvent record_event(Name(), nullptr); - } + platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); if (places_.size() == 1) return; diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index 91f6a42e6e..8eb3568e05 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -87,12 +87,6 @@ std::vector> DataBalanceOpHandle::GetBalancePlan( } void DataBalanceOpHandle::RunImpl() { - if (dev_ctxes_.size() > 0UL) { - platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); - } else { - platform::RecordEvent record_event(Name(), nullptr); - } - PADDLE_ENFORCE_GT(places_.size(), 1, "Data balance can only be enabled when the number of " "places to run larger than 1."); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index cd6c8b50a9..11b085c5c7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -431,10 +431,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( CreateReduceOp(&result, g_name, cur_device_id); graph->Get(kShardedVarDevice) .emplace(g_name, cur_device_id); - if (!is_dist_train) { - // will send gradients directly when distributed training - bcast_var_name_set[cur_device_id].emplace(p_name); - } + bcast_var_name_set[cur_device_id].emplace(p_name); break; case BuildStrategy::ReduceStrategy::kAllReduce: if (IsSparseGradient(g_name)) { @@ -461,9 +458,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( use_gpu = nccl_ctxs_ != nullptr; #endif - if ((use_gpu && - strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) || - is_dist_train) { + if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce && + !is_dist_train) { // Insert BCast Ops for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { auto &to_bcast_set = bcast_var_name_set[dev_id]; -- GitLab