From 93401c98e1b8a3dbfde494ee6b0e766a5b38b6a1 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 6 Jun 2018 12:40:39 +0800 Subject: [PATCH] overlap rpc op memcpy in distributed training --- .../details/multi_devices_graph_builder.cc | 58 ++++++++++++++++--- .../details/multi_devices_graph_builder.h | 14 ++++- paddle/fluid/framework/parallel_executor.cc | 27 ++++++--- paddle/fluid/framework/parallel_executor.h | 3 + 4 files changed, 82 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 17baacd13ee..635faafe4c6 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -191,15 +191,54 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( }; bool is_forwarding = true; + std::unordered_map rpc_var_device_mapping; + int rpc_op_device_id = 0; + auto schedule_rpc_op = [&]() -> void { + rpc_op_device_id++; + if (rpc_op_device_id >= static_cast(places_.size())) { + rpc_op_device_id = 0; + } + }; + for (auto *op : program.Block(0).AllOps()) { if (boost::get( op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { // append rpc op if program is distributed trainer main program. 
// always use the first device - CreateRPCOp(&result, *op); + if (op->Type() == "send_vars") { + auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]); + if (got == remote_vars_devices_.end()) { + schedule_rpc_op(); + } else { + rpc_op_device_id = got->second; + } + CreateRPCOp(&result, *op, rpc_op_device_id); + } else if (op->Type() == "recv") { + schedule_rpc_op(); + for (auto &varname : op->OutputArgumentNames()) { + remote_vars_devices_.insert({varname, rpc_op_device_id}); + } + CreateRPCOp(&result, *op, rpc_op_device_id); + } else { + CreateRPCOp(&result, *op, 0); + } } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { - CreateDistTrainOp(&result, *op); + if (op->Type() == "split_byref") { + schedule_rpc_op(); + for (auto &varname : op->OutputArgumentNames()) { + remote_vars_devices_.insert({varname, rpc_op_device_id}); + } + CreateDistTrainOp(&result, *op, rpc_op_device_id); + } + else if (op->Type() == "concat") { + auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]); + PADDLE_ENFORCE(got != remote_vars_devices_.end(), + "can not find right place to concat received var."); + CreateDistTrainOp(&result, *op, got->second); + } else { + CreateDistTrainOp(&result, *op, 0); + } } else if (IsScaleLossOp(*op)) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != @@ -464,17 +503,18 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, } void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, - const OpDesc &op) const { - CreateComputationalOp(result, op, 0); + const OpDesc &op, + int place_id) const { + CreateComputationalOp(result, op, place_id); if (op.Type() == "concat") { ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); } } -void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, - const OpDesc &op) const { - auto &p = places_[0]; - auto *s = local_scopes_[0]; +void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op, + 
int place_id) const { + auto &p = places_[place_id]; + auto *s = local_scopes_[place_id]; result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type())); if (op.Type() == "send_barrier") { @@ -493,7 +533,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, // TODO(Yancey1989): schedule rpc op on different place may // increate throughput - CreateOpHandleIOs(result, op, 0); + CreateOpHandleIOs(result, op, place_id); } bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 544cbe585c7..79c2b79a3ff 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -48,6 +48,14 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { std::unique_ptr Build(const ProgramDesc &program) const override; + int GetRemoteVarDevice(const std::string &var_name) const { + auto got = remote_vars_devices_.find(var_name); + if (got != remote_vars_devices_.end()) { + return got->second; + } + return -1; + } + private: void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, size_t place_id) const; @@ -64,8 +72,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; - void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; - void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; + void CreateRPCOp(SSAGraph *result, const OpDesc &op, int place_id) const; + void CreateDistTrainOp(SSAGraph *result, const OpDesc &op, + int place_id) const; /** * Is this operator as the end-point operator before/after send operator. 
@@ -111,6 +120,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { private: BuildStrategy strategy_; + mutable std::unordered_map remote_vars_devices_; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 50c3468d556..0c1e8f3f956 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,7 +22,6 @@ limitations under the License. */ #include "paddle/fluid/platform/nccl_helper.h" #endif -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" @@ -97,15 +96,17 @@ ParallelExecutor::ParallelExecutor( // Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp #ifdef PADDLE_WITH_CUDA - details::MultiDevSSAGraphBuilder builder( + builder_.reset(new details::MultiDevSSAGraphBuilder( member_->places_, loss_var_name, params, member_->local_scopes_, - member_->nccl_ctxs_.get(), build_strategy); + member_->nccl_ctxs_.get(), build_strategy)); + #else - details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, - params, member_->local_scopes_, - build_strategy); + builder_.reset(new details::MultiDevSSAGraphBuilder( + member_->places_, loss_var_name, params, member_->local_scopes_, + build_strategy)); + #endif - auto graph = builder.Build(main_program); + auto graph = builder_.get()->Build(main_program); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); @@ -146,8 +147,16 @@ void ParallelExecutor::BCastParamsToGPUs( buffer = t->mutable_data(place, main_tensor.type()); } auto &nccl_ctx = member_->nccl_ctxs_->at(place); - platform::dynload::ncclBcast(buffer, numel, data_type, 0, - nccl_ctx.comm_, nccl_ctx.stream()); + + if (builder_.get() != nullptr && + 
builder_->GetRemoteVarDevice(var) != -1) { + int place_id = builder_->GetRemoteVarDevice(var); + platform::dynload::ncclBcast(buffer, numel, data_type, place_id, + nccl_ctx.comm_, nccl_ctx.stream()); + } else { + platform::dynload::ncclBcast(buffer, numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); + } } } else { platform::CPUPlace cpu; diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 5247e790649..b71a440d6a0 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -19,12 +19,14 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/details/execution_strategy.h" +#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" + namespace paddle { namespace framework { @@ -68,6 +70,7 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; + std::unique_ptr builder_; }; } // namespace framework -- GitLab