diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 6f5d4471a97cc4efc73b9df68040ab9eccde0b1c..da0272d48e9c9f3f5e9e332fd03ff35ed65e42db 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -25,6 +25,7 @@ #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" @@ -66,11 +67,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, - const OpDesc &op, +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op, size_t place_id) const { auto p = places_[place_id]; - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); @@ -169,18 +170,21 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { + std::unique_ptr graph(new Graph); for (auto *var : program.Block(0).AllVars()) { all_vars_.emplace(var->Name(), var); } - auto graph = new SSAGraph(); - SSAGraph &result = *graph; + Graph &result = *graph; std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 - result.vars_ = std::vector< + result.attrs["vars"] = new std::vector< std::unordered_map>>>( places_.size()); + result.attrs["dep_vars"] = + new std::unordered_set>(); + result.attrs["ops"] = new std::vector>(); // find send/recv vars so that we can place the distributed training // realted op in the place 0 @@ -303,7 +307,15 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( */ AddOutputToLeafOps(&result); - return std::unique_ptr(graph); + std::unique_ptr ssa_graph(new SSAGraph); + ssa_graph->vars_ = + std::move(*boost::any_cast(graph->attrs["vars"])); + ssa_graph->ops_ = + std::move(*boost::any_cast(graph->attrs["ops"])); + ssa_graph->dep_vars_ = + std::move(*boost::any_cast(graph->attrs["dep_vars"])); + + return std::move(ssa_graph); } bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { @@ -327,7 +339,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( #endif } -void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { #ifdef PADDLE_WITH_CUDA @@ -336,42 +348,50 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); #endif - result->ops_.emplace_back(op_handle); - auto *in = result->vars_.at(src_dev_id).at(p_name).back().get(); + boost::any_cast(result->attrs["ops"])->emplace_back(op_handle); + auto *in = boost::any_cast(result->attrs["vars"]) + ->at(src_dev_id) + .at(p_name) + .back() + .get(); op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_.at(i).at(p_name); + auto &vars = + boost::any_cast(result->attrs["vars"])->at(i).at(p_name); auto *out_var = new VarHandle(vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } } -void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const { - result->ops_.emplace_back( - new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); + boost::any_cast(result->attrs["ops"]) + ->emplace_back( + new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, op, dev_id); } -void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new AllReduceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_[i][og]; + auto &vars = (*boost::any_cast(result->attrs["vars"]))[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); @@ -383,19 +403,23 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, } void MultiDevSSAGraphBuilder::InsertDataBalanceOp( - SSAGraph *result, const std::vector &datas) const { + Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back( + new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); for (const std::string &d_name : datas) { - auto &vars = result->vars_[i][d_name]; + auto &vars = + (*boost::any_cast(result->attrs["vars"]))[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); auto var = new VarHandle(vars.size(), i, d_name, p); @@ -441,7 +465,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { return got == var_name_on_devices_.end() ? -1 : got->second; } -void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { +void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle #ifdef PADDLE_WITH_CUDA @@ -456,7 +480,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { auto *op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); - result->ops_.emplace_back(op_handle); + boost::any_cast(result->attrs["ops"])->emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. @@ -469,37 +493,41 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { } } -void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, const OpDesc &op, size_t num_places) const { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new ComputationOpHandle(op, s, p)); CreateOpHandleIOs(result, op, scope_idx); } } -VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, +VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_)); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new ReduceOpHandle(local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = + boost::any_cast(result->attrs["ops"])->back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_[i][og]; + auto &vars = (*boost::any_cast(result->attrs["vars"]))[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); } - auto &vars = result->vars_[dst_dev_id][og]; + auto &vars = + (*boost::any_cast(result->attrs["vars"]))[dst_dev_id][og]; auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); @@ -508,19 +536,20 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, // Find the first occurence of `prev_op_name` and make current `op` depend // on it. -void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, +void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { - for (auto &prev_op : result->ops_) { + for (auto &prev_op : (*boost::any_cast(result->attrs["ops"]))) { if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(); prev_op->AddOutput(dep_var); - result->dep_vars_.emplace(dep_var); + boost::any_cast(result->attrs["dep_vars"]) + ->emplace(dep_var); op->AddInput(dep_var); } } } -void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, const OpDesc &op) const { int op_dev_id = -1; if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") { @@ -550,12 +579,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, CreateComputationalOp(result, op, op_dev_id); if (op.Type() == "concat") { - ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "fetch_barrier"); } } // Create RPC related op handles that connects its in ops and out ops. -void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, const OpDesc &op) const { int op_dev_id = -1; if (op.Type() == "send") { @@ -584,15 +615,22 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", op.Type()); - result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], - op.Type(), places_[op_dev_id])); + boost::any_cast(result->attrs["ops"]) + ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(), + places_[op_dev_id])); if (op.Type() == "send_barrier") { - ConnectOp(result, result->ops_.back().get(), "send"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "send"); } else if (op.Type() == "recv") { - ConnectOp(result, result->ops_.back().get(), "send_barrier"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "send_barrier"); } else if (op.Type() == "fetch_barrier") { - ConnectOp(result, result->ops_.back().get(), "recv"); + ConnectOp(result, + boost::any_cast(result->attrs["ops"])->back().get(), + "recv"); } else if (op.Type() == "send") { // do nothing } else { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index a964e024885e56693224a6199e00ff30beaa1df4..3d7642f522e941571c5bf05142ab327195aa73d7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace platform { @@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { int GetVarDeviceID(const std::string &varname) const override; private: - void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, + void CreateOpHandleIOs(Graph *result, const OpDesc &op, size_t device_id) const; private: @@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { bool IsScaleLossOp(const OpDesc &op) const; - void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; - void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; + void CreateRPCOp(Graph *result, const OpDesc &op) const; + void CreateDistTrainOp(Graph *result, const OpDesc &op) const; /** * Is this operator as the end-point operator before/after send operator. @@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { std::vector FindDistTrainRecvVars( const ProgramDesc &program) const; - void ConnectOp(SSAGraph *result, OpHandleBase *op, + void ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const; - void CreateComputationalOps(SSAGraph *result, const OpDesc &op, + void CreateComputationalOps(Graph *result, const OpDesc &op, size_t num_places) const; - void CreateScaleLossGradOp(SSAGraph *result) const; - VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og, + void CreateScaleLossGradOp(Graph *result) const; + VarHandle *CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const; - void CreateComputationalOp(SSAGraph *result, const OpDesc &op, - int dev_id) const; + void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const; bool IsParameterGradientOnce( const std::string &og, @@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { int GetOpDeviceID(const OpDesc &op) const; - void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; + void InsertAllReduceOp(Graph *result, const std::string &og) const; - void InsertDataBalanceOp(SSAGraph *result, + void InsertDataBalanceOp(Graph *result, const std::vector &datas) const; - void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, + void CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const; bool IsSparseGradient(const std::string &og) const; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 88a21f48879a15450051ad94ed76e1c48bf23014..2c0873cc87805bb0c9eefbc8e17c8c10cdd83fff 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -17,8 +17,8 @@ namespace paddle { namespace framework { namespace details { -void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { - for (auto &var_map : graph->vars_) { +void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { + for (auto &var_map : *boost::any_cast(graph->attrs["vars"])) { for (auto &name_pair : var_map) { if (name_pair.second.size() <= 1) { continue; @@ -40,7 +40,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { auto *dep_var = new DummyVarHandle(); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); - graph->dep_vars_.emplace(dep_var); + boost::any_cast(graph->attrs["dep_vars"]) + ->emplace(dep_var); } } } @@ -48,9 +49,10 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { } VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( - SSAGraph *graph, const std::string &each_var_name, + Graph *graph, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { - auto &var_holders = graph->vars_[place_offset]; + auto &var_holders = + (*boost::any_cast(graph->attrs["vars"]))[place_offset]; auto &var_holder = var_holders[each_var_name]; VarHandle *var = nullptr; if (var_holder.empty()) { @@ -62,24 +64,29 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( return var; } -void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, +void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, const std::string &each_var_name, const platform::Place &place, size_t place_offset) { - auto &vars = graph->vars_[place_offset][each_var_name]; + auto &vars = + (*boost::any_cast(graph->attrs["vars"]))[place_offset] + [each_var_name]; size_t version = vars.size(); auto var = new VarHandle(version, place_offset, each_var_name, place); vars.emplace_back(var); op_handle->AddOutput(var); } -void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { - for (auto &op : graph->ops_) { +void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { + GraphOps &all_ops = *boost::any_cast(graph->attrs["ops"]); + + for (auto &op : all_ops) { if (!op->Outputs().empty()) { continue; } auto *dummy_leaf = new DummyVarHandle(); - graph->dep_vars_.emplace(dummy_leaf); + boost::any_cast(graph->attrs["dep_vars"]) + ->emplace(dummy_leaf); op->AddOutput(dummy_leaf); } } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 18612c3c1b62cf4c2ebdc221c301c59ec81c2da7..d5aabb9fd1ac112be31739fd98266fc52313bfcd 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -16,15 +16,24 @@ #include #include +#include #include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/framework/ir/graph.h" + namespace paddle { namespace framework { namespace details { +typedef std::vector< + std::unordered_map>>> + GraphVars; +typedef std::unordered_set> GraphDepVars; +typedef std::vector> GraphOps; + class SSAGraphBuilder { public: SSAGraphBuilder() {} @@ -42,20 +51,20 @@ class SSAGraphBuilder { * * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) */ - static void PolishGraphToSupportDataHazards(SSAGraph *graph); + static void PolishGraphToSupportDataHazards(Graph *graph); - static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, + static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, const std::string &each_var_name, const platform::Place &place, size_t place_offset); // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph - static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, + static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, const std::string &each_var_name, const platform::Place &place, size_t place_offset); - static void AddOutputToLeafOps(SSAGraph *graph); + static void AddOutputToLeafOps(Graph *graph); }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index d1805d7434270096bdf6eb48f090c62ca30e16ba..8996c2d43aef90a469e80c1fb92d5a9c5b042e7c 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -27,7 +27,7 @@ namespace framework { class Graph { public: - std::map> attrs; + std::map attrs; std::vector inputs; std::vector outputs; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 6f4bb172c6ed530ff09fc8953ef5c57825efbf79..087ebb870984a03b16ec4fcdec3f7c9aa1a0480a 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -14,6 +14,27 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" + namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +class Pass { + public: + Pass() = default; + virtual ~Pass() {} + virtual std::unique_ptr Apply(std::unique_ptr graph) { + return std::move(graph); + } +}; + +std::unique_ptr ProgramToGraph(const ProgramDesc& program) { + std::unique_ptr g(new Graph); + + return std::move(g); +} + +} // namespace framework } // namespace paddle