diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 8450d8eb8b057d18d2e004964d1e25b32c142823..7c5f5bd80a937bf1a1c891155764833d7b21c5c2 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -46,11 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif
 
 void AllReduceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
 
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 35962ade99c024235d2bb4a203bec52bf4ef2063..4fdab5cd94358d08eac7f8b041bf16d09042f0bd 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -22,11 +22,7 @@ namespace framework {
 namespace details {
 
 void BroadcastOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
+  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
 
   if (places_.size() == 1) return;
 
diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc
index 91f6a42e6eab29ca5aa2ddf0442e6706f66e205f..8eb3568e0549a65b6fb32cb43cc4743260e1bbe8 100644
--- a/paddle/fluid/framework/details/data_balance_op_handle.cc
+++ b/paddle/fluid/framework/details/data_balance_op_handle.cc
@@ -87,12 +87,6 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
 }
 
 void DataBalanceOpHandle::RunImpl() {
-  if (dev_ctxes_.size() > 0UL) {
-    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
-  } else {
-    platform::RecordEvent record_event(Name(), nullptr);
-  }
-
   PADDLE_ENFORCE_GT(places_.size(), 1,
                     "Data balance can only be enabled when the number of "
                     "places to run larger than 1.");
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index cd6c8b50a9370428e6fbd79dec81c409cab8fc64..11b085c5c78ba2e90f7a2dd9f325bb5016399a16 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -431,10 +431,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
           CreateReduceOp(&result, g_name, cur_device_id);
           graph->Get<ShardedVarDevice>(kShardedVarDevice)
               .emplace(g_name, cur_device_id);
-          if (!is_dist_train) {
-            // will send gradients directly when distributed training
-            bcast_var_name_set[cur_device_id].emplace(p_name);
-          }
+          bcast_var_name_set[cur_device_id].emplace(p_name);
           break;
         case BuildStrategy::ReduceStrategy::kAllReduce:
           if (IsSparseGradient(g_name)) {
@@ -461,9 +458,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   use_gpu = nccl_ctxs_ != nullptr;
 #endif
 
-  if ((use_gpu &&
-       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
-      is_dist_train) {
+  if (use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce &&
+      !is_dist_train) {
     // Insert BCast Ops
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
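
Note on the RecordEvent hunks (an aside, not part of the patch): platform::RecordEvent in Paddle's profiler is an RAII guard, opening an event on construction and closing it on destruction. In the removed pattern the guard was declared inside an if/else branch, so the event closed as soon as the branch did and never covered the op body; hoisting a single declaration to function scope fixes that. The stand-alone sketch below illustrates the scoping pitfall; ScopedEvent is a hypothetical stand-in for platform::RecordEvent, not Paddle's actual class.

    #include <iostream>
    #include <string>

    // Stand-in for platform::RecordEvent: begins an event when constructed,
    // ends it when destroyed (RAII).
    struct ScopedEvent {
      std::string name;
      explicit ScopedEvent(std::string n) : name(std::move(n)) {
        std::cout << "begin " << name << "\n";
      }
      ~ScopedEvent() { std::cout << "end " << name << "\n"; }
    };

    void RunImplOld(bool has_ctx) {
      // Old pattern: the guard lives only inside the branch, so the event
      // ends at the closing brace and never covers the real work below.
      if (has_ctx) {
        ScopedEvent ev("op");
      } else {
        ScopedEvent ev("op");
      }
      std::cout << "work\n";  // runs outside the event
    }

    void RunImplNew() {
      // New pattern: one guard at function scope spans the whole body.
      ScopedEvent ev("op");
      std::cout << "work\n";  // runs inside the event
    }

    int main() {
      RunImplOld(true);  // prints: begin op / end op / work
      RunImplNew();      // prints: begin op / work / end op
    }

One trade-off of the simplified form: it dereferences dev_ctxes_.begin() unconditionally, so it assumes every op handle has at least one device context registered when RunImpl is called.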