diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 1d9f1bd6e417e30f0799f0bbed1739cedb4e8fbf..b0bf641d9d0b54f4788b14e25caf317c8eea3c27 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -103,23 +103,50 @@ void BroadcastOpHandle::RunImpl() { }); } - this->RunAndRecordEvent([&] { - { - platform::NCCLGroupGuard guard; - for (auto &call : broadcast_calls) { - call(); + // FIXME(zcd): a temporary fix for some language model that has sparse + // parameter. + bool use_mutex = true; + if (in_var->IsType()) { + use_mutex = false; + } + if (use_mutex) { + this->RunAndRecordEvent([&] { + { + platform::NCCLGroupGuard guard; + for (auto &call : broadcast_calls) { + call(); + } } - } - if (!out_handle->IsTheSameVar(*in_var_handle)) { - auto out_var = var_scopes.at(in_var_handle->scope_idx_) - ->FindVar(out_var_handles[0]->name_); - paddle::framework::TensorCopy( - in_tensor, in_var_handle->place_, - *(dev_ctxes_.at(in_var_handle->place_)), - &VariableVisitor::GetMutableTensor(out_var)); - } - }); + if (!out_handle->IsTheSameVar(*in_var_handle)) { + auto out_var = var_scopes.at(in_var_handle->scope_idx_) + ->FindVar(out_var_handles[0]->name_); + paddle::framework::TensorCopy( + in_tensor, in_var_handle->place_, + *(dev_ctxes_.at(in_var_handle->place_)), + &VariableVisitor::GetMutableTensor(out_var)); + } + }); + } else { + this->RunAndRecordEventNoMutex([&] { + { + platform::NCCLGroupGuard guard; + for (auto &call : broadcast_calls) { + call(); + } + } + + if (!out_handle->IsTheSameVar(*in_var_handle)) { + auto out_var = var_scopes.at(in_var_handle->scope_idx_) + ->FindVar(out_var_handles[0]->name_); + paddle::framework::TensorCopy( + in_tensor, in_var_handle->place_, + *(dev_ctxes_.at(in_var_handle->place_)), + &VariableVisitor::GetMutableTensor(out_var)); + } + }); + } + #else PADDLE_THROW("CUDA is not enabled."); #endif diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index f79565fe71c4aef140475c922cbbf5a1e0b7fe03..a40a8815087f246996e4601b36304afd5544234e 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function &callback) { #endif } +void OpHandleBase::RunAndRecordEventNoMutex( + const std::function &callback) { +#ifdef PADDLE_WITH_CUDA + if (!events_.empty()) { // Use event + std::function method = callback; + + for (auto &p : dev_ctxes_) { + method = [method, p, this]() { + static_cast(p.second) + ->RecordEventNoMutex( + events_.at(boost::get(p.first).device), + method); + }; + } + method(); + } else { +#endif + callback(); +#ifdef PADDLE_WITH_CUDA + } +#endif +} + void OpHandleBase::RunAndRecordEvent(platform::Place p, const std::function &callback) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index fbd90a3296bca92b097cab925b218b91e7f4752f..775be0233a4a841dd210edbaa2da42dd739eae80 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -85,6 +85,10 @@ class OpHandleBase { protected: void RunAndRecordEvent(const std::function &callback); + // FIXME(zcd): A temporary fix for some language model that has sparse + // parameter. + void RunAndRecordEventNoMutex(const std::function &callback); + void RunAndRecordEvent(platform::Place p, const std::function &callback); diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 7160e346dad0615e2fd32b70c096880af0359e1a..9a626c890fa20b9d69812acbe8d899c3f72b1ca3 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() { } if (pre_in_var->IsType()) { - this->RunAndRecordEvent([&] { + // FIXME(zcd): A temporary fix for some language model that has sparse + // parameter. + this->RunAndRecordEventNoMutex([&] { std::vector in_selected_rows = GetInputValues(in_var_handles, var_scopes); GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p, diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 292ffef1aef12732812b8c5b0020cad73b1d06fc..d37e5ee57859ec90de8a99416a1600b32796f46e 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext { PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); } + // FIXME(zcd): A temporary fix for some language model that has sparse + // parameter. + template + void RecordEventNoMutex(cudaEvent_t ev, Callback callback) { + callback(); + PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); + } + private: CUDAPlace place_;