未验证 提交 9b3f48d7 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #11616 from chengduoZH/fix_parallel_exe

Enhance Parallel Executor stable
...@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() { ...@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device; int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
std::vector<std::function<void()>> broadcast_calls; std::vector<std::function<void()>> broadcast_calls;
int type = platform::ToNCCLDataType(in_tensor.type());
size_t numel = static_cast<size_t>(in_tensor.numel());
for (auto out_var_handle : out_var_handles) { for (auto out_var_handle : out_var_handles) {
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name_);
...@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() { ...@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
send_recv_buffer = const_cast<void *>(in_tensor.data<void>()); send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
out_handle = out_var_handle; out_handle = out_var_handle;
} else { } else {
send_recv_buffer = send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
VariableVisitor::GetMutableTensor(out_var).mutable_data( .Resize(in_tensor.dims())
out_var_handle->place_); .mutable_data(out_var_handle->place_);
} }
int type = platform::ToNCCLDataType(in_tensor.type());
size_t numel = static_cast<size_t>(in_tensor.numel());
broadcast_calls.emplace_back( broadcast_calls.emplace_back(
[send_recv_buffer, numel, type, root_id, &nccl_ctx] { [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
...@@ -102,23 +103,50 @@ void BroadcastOpHandle::RunImpl() { ...@@ -102,23 +103,50 @@ void BroadcastOpHandle::RunImpl() {
}); });
} }
this->RunAndRecordEvent([&] { // FIXME(zcd): a temporary fix for some language model that has sparse
{ // parameter.
platform::NCCLGroupGuard guard; bool use_mutex = true;
for (auto &call : broadcast_calls) { if (in_var->IsType<paddle::framework::SelectedRows>()) {
call(); use_mutex = false;
}
if (use_mutex) {
this->RunAndRecordEvent([&] {
{
platform::NCCLGroupGuard guard;
for (auto &call : broadcast_calls) {
call();
}
} }
}
if (!out_handle->IsTheSameVar(*in_var_handle)) { if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_) auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_); ->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy( paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_, in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)), *(dev_ctxes_.at(in_var_handle->place_)),
&VariableVisitor::GetMutableTensor(out_var)); &VariableVisitor::GetMutableTensor(out_var));
} }
}); });
} else {
this->RunAndRecordEventNoMutex([&] {
{
platform::NCCLGroupGuard guard;
for (auto &call : broadcast_calls) {
call();
}
}
if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
});
}
#else #else
PADDLE_THROW("CUDA is not enabled."); PADDLE_THROW("CUDA is not enabled.");
#endif #endif
......
...@@ -351,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, ...@@ -351,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
auto var = new VarHandle(vars.size() - 1, i, og, p); auto var = new VarHandle(vars.size(), i, og, p);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
...@@ -447,8 +447,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, ...@@ -447,8 +447,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
} }
auto &vars = result->vars_[dst_dev_id][og]; auto &vars = result->vars_[dst_dev_id][og];
auto var = auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
return var; return var;
......
...@@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { ...@@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#endif #endif
} }
void OpHandleBase::RunAndRecordEventNoMutex(
const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;
for (auto &p : dev_ctxes_) {
method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.second)
->RecordEventNoMutex(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}
method();
} else {
#endif
callback();
#ifdef PADDLE_WITH_CUDA
}
#endif
}
void OpHandleBase::RunAndRecordEvent(platform::Place p, void OpHandleBase::RunAndRecordEvent(platform::Place p,
const std::function<void()> &callback) { const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -85,6 +85,10 @@ class OpHandleBase { ...@@ -85,6 +85,10 @@ class OpHandleBase {
protected: protected:
void RunAndRecordEvent(const std::function<void()> &callback); void RunAndRecordEvent(const std::function<void()> &callback);
// FIXME(zcd): A temporary fix for some language model that has sparse
// parameter.
void RunAndRecordEventNoMutex(const std::function<void()> &callback);
void RunAndRecordEvent(platform::Place p, void RunAndRecordEvent(platform::Place p,
const std::function<void()> &callback); const std::function<void()> &callback);
......
...@@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() { ...@@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() {
} }
if (pre_in_var->IsType<framework::SelectedRows>()) { if (pre_in_var->IsType<framework::SelectedRows>()) {
this->RunAndRecordEvent([&] { // FIXME(zcd): A temporary fix for some language model that has sparse
// parameter.
this->RunAndRecordEventNoMutex([&] {
std::vector<const SelectedRows *> in_selected_rows = std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes); GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p, GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
......
...@@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext { ...@@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
// FIXME(zcd): A temporary fix for some language model that has sparse
// parameter.
template <typename Callback>
void RecordEventNoMutex(cudaEvent_t ev, Callback callback) {
callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}
private: private:
CUDAPlace place_; CUDAPlace place_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册