From 2d94697a820f5cac6b3a150d538fd41a0b97594b Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 10 Jun 2018 16:52:04 +0800 Subject: [PATCH] code refine --- .../details/multi_devices_graph_builder.cc | 45 +++++++------------ .../details/multi_devices_graph_builder.h | 3 ++ 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index cd7dda143c..8a5f171ce5 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -297,21 +297,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { - auto &vars = result->vars_.at(i).at(p_name); auto &p = places_[i]; + SetCommunicationContext(op_handle, p); + auto &vars = result->vars_.at(i).at(p_name); auto *out_var = new VarHandle(vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); + } +} + +void MultiDevSSAGraphBuilder::SetCommunicationContext( + OpHandleBase *op_handle, const platform::Place &p) const { #ifdef PADDLE_WITH_CUDA - if (nccl_ctxs_ == nullptr) { - op_handle->SetDeviceContext( - p, platform::DeviceContextPool::Instance().Get(p)); - } -#else + if (nccl_ctxs_ == nullptr) { op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); -#endif } +#else + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); +#endif } void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, @@ -334,24 +339,12 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; + SetCommunicationContext(op_handle, p); auto &vars = result->vars_[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); -#ifdef PADDLE_WITH_CUDA - if (nccl_ctxs_ == nullptr) { - op_handle->SetDeviceContext( - p, platform::DeviceContextPool::Instance().Get(p)); - } -#else - op_handle->SetDeviceContext(p, - platform::DeviceContextPool::Instance().Get(p)); -#endif - - VLOG(4) << "NCCL - - - " << p; - op_handle->DeviceContext(p)->Wait(); - VLOG(4) << "NCCL - - - " << p << " " << op_handle->DeviceContext(p); auto var = new VarHandle(vars.size() - 1, i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); @@ -441,17 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, auto *op_handle = result->ops_.back().get(); for (size_t i = 0; i < places_.size(); ++i) { - auto &vars = result->vars_[i][og]; auto &p = places_[i]; -#ifdef PADDLE_WITH_CUDA - if (nccl_ctxs_ == nullptr) { - op_handle->SetDeviceContext( - p, platform::DeviceContextPool::Instance().Get(p)); - } -#else - op_handle->SetDeviceContext(p, - platform::DeviceContextPool::Instance().Get(p)); -#endif + SetCommunicationContext(op_handle, p); + auto &vars = result->vars_[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 544cbe585c..bcedc9b8b8 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -111,6 +111,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { private: BuildStrategy strategy_; + + void SetCommunicationContext(OpHandleBase *op_handle, + const platform::Place &p) const; }; } // namespace details } // namespace framework -- GitLab