From 9612c7e599cef15a9641fa69a1cb78809977b549 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Fri, 27 Apr 2018 16:47:27 +0800
Subject: [PATCH] Add comments and polish code

---
 .../details/multi_devices_graph_builder.cc | 58 +++++++++----------
 .../details/multi_devices_graph_builder.h  | 11 +++-
 paddle/fluid/framework/details/ssa_graph.h | 10 ++++
 .../framework/details/ssa_graph_builder.h  |  2 +
 4 files changed, 51 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 3413467b149..c2eb1c31b4f 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
                                                 const OpDesc &op,
-                                                const platform::Place &p,
-                                                const size_t &i) const {
+                                                size_t place_id) const {
+  auto p = places_[place_id];
   auto *op_handle = result->ops_.back().get();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
 
-  auto var_names = op.InputArgumentNames();
-
-  for (auto &each_var_name : var_names) {
-    VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
+  for (auto &each_var_name : op.InputArgumentNames()) {
+    VarHandle *var =
+        CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
     op_handle->AddInput(var);
   }
 
-  var_names = op.OutputArgumentNames();
-
-  for (auto &each_var_name : var_names) {
-    CreateOpOutput(result, op_handle, each_var_name, p, i);
+  for (auto &each_var_name : op.OutputArgumentNames()) {
+    CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
 
@@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
     return false;
   }
 
-  auto checker = [&](const std::vector<std::string> opvars,
-                     const std::vector<std::string> sendvars) -> bool {
-    bool is_dist_train_op = false;
+  /**
+   * Check whether any of opvars contains `.block` and is also in sendvars.
+   */
+  auto checker = [](const std::vector<std::string> &opvars,
+                    const std::vector<std::string> &sendvars) -> bool {
     for (auto &var : opvars) {
       if (var.find(".block") != std::string::npos &&
           std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
-        is_dist_train_op = true;
-        break;
+        return true;
       }
     }
-    return is_dist_train_op;
+    return false;
   };
 
   if (op.Type() == "split") {
@@ -117,13 +115,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       places_.size());
 
   // Find the "send" op first, since "split" ops are in front of it.
-  OpDesc *send_op = nullptr;
-  for (auto *op : program.Block(0).AllOps()) {
-    if (op->Type() == "send") {
-      send_op = op;
-      break;
-    }
-  }
+  OpDesc *send_op = GetSendOpDesc(program);
 
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
@@ -134,6 +126,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     } else if (IsDistTrainOp(*op, send_op)) {
       CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
+      // Users can customize loss@grad if skip_scale_loss_ is set.
       if (!skip_scale_loss_) {
         CreateScaleLossGradOp(&result);
       }
@@ -142,10 +135,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding) {
         // Currently, we assume that once gradient is generated, it can be
-        // broadcast, and each gradient is only broadcast once. But there are no
-        // other cases, for example, we need to adjust the gradient according to
-        // the input when we get the gradient, which is not considered at
-        // present.
+        // broadcast, and each gradient is only broadcast once.
         for (auto &og : op->OutputArgumentNames()) {
           if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
             InsertNCCLAllReduceOp(&result, og);
@@ -175,6 +165,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   return std::unique_ptr<SSAGraph>(graph);
 }
 
+OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
+    const ProgramDesc &program) const {
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send") {
+      return op;
+    }
+  }
+  return nullptr;
+}
+
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
@@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
     result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
-    CreateOpHandleIOs(result, op, p, scope_idx);
+    CreateOpHandleIOs(result, op, scope_idx);
   }
 }
 
@@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
   result->ops_.emplace_back(new SendOpHandle(op, s, p));
   // Create inputs for the outputs on the original place; no SSA output
   // is created for the send op.
-  CreateOpHandleIOs(result, op, p, 0);
+  CreateOpHandleIOs(result, op, 0);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index dc3da70eda2..fa4d31bdc49 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
  private:
   void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
-                         const platform::Place &p, const size_t &i) const;
+                         size_t place_id) const;
 
  private:
   std::string loss_var_name_;
@@ -65,6 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
 
+  /**
+   * Is this operator an endpoint operator before/after the send operator?
+   */
   bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
 
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
@@ -77,6 +80,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
       std::unordered_set<std::string> *og_has_been_broadcast) const;
 
   void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
+
+  /**
+   * Get the send op in the global block of the program.
+   * Returns nullptr if not found.
+   */
+  OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
 };
 }  // namespace details
 }  // namespace framework
diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h
index 72684e7f97f..e996a00c162 100644
--- a/paddle/fluid/framework/details/ssa_graph.h
+++ b/paddle/fluid/framework/details/ssa_graph.h
@@ -25,12 +25,22 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+// An SSA graph used by the parallel executor.
 struct SSAGraph {
+  // All variables on each device.
+  // The outer vector is indexed by device. Each element of this vector is a
+  // map from a variable name to all versions of that variable. Variables
+  // with the same name are distinguished by version: the offset in the inner
+  // `std::vector<std::unique_ptr<VarHandle>>` is the version of the variable.
   std::vector<
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
       vars_;
+  // Aux variables representing dependencies; used to resolve data hazards.
   std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
+
+  // All operators. NOTE that even though we use a vector here, the
+  // operators are unordered.
   std::vector<std::unique_ptr<OpHandleBase>> ops_;
 };
 
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index be1f0460e45..64e5d93081e 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -48,6 +48,8 @@ class SSAGraphBuilder {
                                              const platform::Place &place,
                                              size_t place_offset);
 
+  // Add an output variable (each_var_name, place, place_offset) to op_handle;
+  // the new variable handle is created in graph.
   static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
                              const std::string &each_var_name,
                              const platform::Place &place, size_t place_offset);
--
GitLab
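
Appendix (illustrative, not part of the patch): the new comments in ssa_graph.h
describe a layout where each device keeps a map from a variable name to all
versions of that variable, and a variable's version is its offset in the
per-name vector. The following minimal, self-contained C++ sketch shows that
idea in isolation. VarHandle here is a reduced stand-in struct, and WriteVar is
a hypothetical helper that mimics only the write side of
CreateOrGetLatestVarHandle; none of this is the real Paddle API.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for paddle's VarHandle: just enough state to show that a
// variable's version is its offset inside the per-name vector.
struct VarHandle {
  std::string name;
  size_t version;
  size_t place_id;
};

// One element of SSAGraph::vars_: the per-device map from a variable name to
// every version of that variable.
using PerDeviceVars =
    std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>;

// Hypothetical helper mimicking the write side of CreateOrGetLatestVarHandle:
// every write of `name` appends a fresh version.
VarHandle *WriteVar(PerDeviceVars *vars, const std::string &name,
                    size_t place_id) {
  auto &versions = (*vars)[name];
  versions.emplace_back(new VarHandle{name, versions.size(), place_id});
  return versions.back().get();
}

int main() {
  std::vector<PerDeviceVars> vars(2);  // outer vector indexed by device
  WriteVar(&vars[0], "w@GRAD", 0);     // version 0 on device 0
  WriteVar(&vars[0], "w@GRAD", 0);     // a second write -> version 1
  // The latest definition of a variable is always the back of its vector.
  std::cout << vars[0]["w@GRAD"].back()->version << std::endl;  // prints 1
  return 0;
}

Because every write appends to the per-name vector, the newest definition is
always versions.back(); this is what lets the builder connect each op's inputs
to the latest write of a variable, which is the SSA property the graph
maintains.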