diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index d5ca061944f33939cea59a5275e691b1966194fa..b0bf641d9d0b54f4788b14e25caf317c8eea3c27 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() { int root_id = boost::get(in_tensor.place()).device; std::vector> broadcast_calls; + int type = platform::ToNCCLDataType(in_tensor.type()); + size_t numel = static_cast(in_tensor.numel()); + for (auto out_var_handle : out_var_handles) { Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) ->FindVar(out_var_handle->name_); @@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() { send_recv_buffer = const_cast(in_tensor.data()); out_handle = out_var_handle; } else { - send_recv_buffer = - VariableVisitor::GetMutableTensor(out_var).mutable_data( - out_var_handle->place_); + send_recv_buffer = VariableVisitor::GetMutableTensor(out_var) + .Resize(in_tensor.dims()) + .mutable_data(out_var_handle->place_); } - int type = platform::ToNCCLDataType(in_tensor.type()); - size_t numel = static_cast(in_tensor.numel()); broadcast_calls.emplace_back( [send_recv_buffer, numel, type, root_id, &nccl_ctx] { PADDLE_ENFORCE(platform::dynload::ncclBcast( @@ -102,23 +103,50 @@ void BroadcastOpHandle::RunImpl() { }); } - this->RunAndRecordEvent([&] { - { - platform::NCCLGroupGuard guard; - for (auto &call : broadcast_calls) { - call(); + // FIXME(zcd): a temporary fix for some language model that has sparse + // parameter. + bool use_mutex = true; + if (in_var->IsType()) { + use_mutex = false; + } + if (use_mutex) { + this->RunAndRecordEvent([&] { + { + platform::NCCLGroupGuard guard; + for (auto &call : broadcast_calls) { + call(); + } } - } - if (!out_handle->IsTheSameVar(*in_var_handle)) { - auto out_var = var_scopes.at(in_var_handle->scope_idx_) - ->FindVar(out_var_handles[0]->name_); - paddle::framework::TensorCopy( - in_tensor, in_var_handle->place_, - *(dev_ctxes_.at(in_var_handle->place_)), - &VariableVisitor::GetMutableTensor(out_var)); - } - }); + if (!out_handle->IsTheSameVar(*in_var_handle)) { + auto out_var = var_scopes.at(in_var_handle->scope_idx_) + ->FindVar(out_var_handles[0]->name_); + paddle::framework::TensorCopy( + in_tensor, in_var_handle->place_, + *(dev_ctxes_.at(in_var_handle->place_)), + &VariableVisitor::GetMutableTensor(out_var)); + } + }); + } else { + this->RunAndRecordEventNoMutex([&] { + { + platform::NCCLGroupGuard guard; + for (auto &call : broadcast_calls) { + call(); + } + } + + if (!out_handle->IsTheSameVar(*in_var_handle)) { + auto out_var = var_scopes.at(in_var_handle->scope_idx_) + ->FindVar(out_var_handles[0]->name_); + paddle::framework::TensorCopy( + in_tensor, in_var_handle->place_, + *(dev_ctxes_.at(in_var_handle->place_)), + &VariableVisitor::GetMutableTensor(out_var)); + } + }); + } + #else PADDLE_THROW("CUDA is not enabled."); #endif diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 99dcaf27134e879fa85f57e5a675382442e9edf2..a6fe64fa80d6bf036893d49de56d7274d49a3b30 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -351,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); - auto var = new VarHandle(vars.size() - 1, i, og, p); + auto var = new VarHandle(vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -447,8 +447,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->vars_[dst_dev_id][og]; - auto var = - new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]); + auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); return var; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index f79565fe71c4aef140475c922cbbf5a1e0b7fe03..a40a8815087f246996e4601b36304afd5544234e 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function &callback) { #endif } +void OpHandleBase::RunAndRecordEventNoMutex( + const std::function &callback) { +#ifdef PADDLE_WITH_CUDA + if (!events_.empty()) { // Use event + std::function method = callback; + + for (auto &p : dev_ctxes_) { + method = [method, p, this]() { + static_cast(p.second) + ->RecordEventNoMutex( + events_.at(boost::get(p.first).device), + method); + }; + } + method(); + } else { +#endif + callback(); +#ifdef PADDLE_WITH_CUDA + } +#endif +} + void OpHandleBase::RunAndRecordEvent(platform::Place p, const std::function &callback) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index fbd90a3296bca92b097cab925b218b91e7f4752f..775be0233a4a841dd210edbaa2da42dd739eae80 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -85,6 +85,10 @@ class OpHandleBase { protected: void RunAndRecordEvent(const std::function &callback); + // FIXME(zcd): A temporary fix for some language model that has sparse + // parameter. + void RunAndRecordEventNoMutex(const std::function &callback); + void RunAndRecordEvent(platform::Place p, const std::function &callback); diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 7160e346dad0615e2fd32b70c096880af0359e1a..9a626c890fa20b9d69812acbe8d899c3f72b1ca3 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() { } if (pre_in_var->IsType()) { - this->RunAndRecordEvent([&] { + // FIXME(zcd): A temporary fix for some language model that has sparse + // parameter. + this->RunAndRecordEventNoMutex([&] { std::vector in_selected_rows = GetInputValues(in_var_handles, var_scopes); GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p, diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 292ffef1aef12732812b8c5b0020cad73b1d06fc..d37e5ee57859ec90de8a99416a1600b32796f46e 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext { PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } + // FIXME(zcd): A temporary fix for some language model that has sparse + // parameter. + template + void RecordEventNoMutex(cudaEvent_t ev, Callback callback) { + callback(); + PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); + } + private: CUDAPlace place_;