diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md index a33b5a9c9312c93247a1e1f3431061a5aad6c884..c29337cba1fe859e4968cb800e4e7d9ff6a54d31 100644 --- a/doc/fluid/design/ir/draft.md +++ b/doc/fluid/design/ir/draft.md @@ -64,6 +64,41 @@ can also contain other things that describe some properties of the `Graph` or `Graph` nodes. `Attribute` can be passed across `Pass`. However, it should be used with care. +```cpp +class Graph { + public: + explicit Graph(const ProgramDesc &program); + + bool Has(const std::string &attr_name) const; + + template + AttrType &Get(const std::string &attr_name) const; + + template + void Set(const std::string &attr_name, AttrType *attr); + const std::unordered_set &Nodes() const; + + // Create a normal variable with non-null VarDesc. + ir::Node *CreateVarNode(VarDesc *var_desc); + + // Create a normal runnable operator with OpDesc. + ir::Node *CreateOpNode(OpDesc *op_desc); + + // Create a control dependency var that connects 2 operations. The + // var doesn't hold any data. Other than that, it's no different from + // other var, considering dependency analysis. + ir::Node *CreateControlDepVar(); + + // A more free style way of creating a graph node. Mostly use for test + // or "copy" from another node. Avoid using it if possible. + ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type); + + // Clear all node information of the graph and return the ownership of the + // nodes. + std::vector> ReleaseNodes(); +}; +``` + #### Pass `Pass` represents a transformation of `Graph`. Its input @@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example, a `Pass` can simply print out the `Graph`. A `Pass` can also fuse some `Graph`'s `Node`s. +```cpp +class Pass { + public: + + std::unique_ptr Apply(std::unique_ptr graph) const { + // Some correctness check. + auto new_graph = ApplyImpl(std::move(graph)); + // Some correctness check. 
+ return new_graph; + } + + // Get a reference to the attribute previously set. + template + AttrType &Get(const std::string &attr_name) const; + + // Set a pointer to the attribute. Pass takes ownership of the attribute. + template + void Set(const std::string &attr_name, AttrType *attr); + + // Set a pointer to the attribute. Pass doesn't take ownership. Caller + // should delete the attribute. + template + void SetNotOwned(const std::string &attr_name, AttrType *attr); + + protected: + virtual std::unique_ptr ApplyImpl(std::unique_ptr graph) const = 0; +}; + +// In my_pass.cc +class MyPass : public Pass { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const override { + // do something. + return graph; + } +}; +REGISTER_PASS(my_pass, MyPass) +.RequirePassAttr("places") +.RequireGraphAttr("dep_vars"); + + +// To use the pass. +auto my_pass = ir::PassRegistry::Instance().Get("my_pass"); +graph = my_pass->Apply(std::move(graph)); +// Note: to force link my_pass.cc, in the code: +USE_PASS(my_pass); +``` + #### Optimize `Optimize` contains a series of `Pass` with defined order. @@ -86,4 +169,17 @@ maintaining the original modeling logic. * Graph is transformed from raw model logic to a form that is efficient to execute. -Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor +``` +// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor +auto graph = Graph(program); +graph = PassRegistry::Instance().Get("op_fuse_pass")->Apply(std::move(graph)); +// For more complex Pass, Optimize Process can provide Pass attributes. 
+auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass"); +mem_opt_pass->SetNotOwned("optimize_level", 1); +mem_opt_pass->Apply(std::move(graph)); +graph = PassRegistry::Instance().Get("multi_device_pass")->Apply(std::move(graph)); +graph = PassRegistry::Instance().Get("multi_device_check_pass")->Apply(std::move(graph)); +Executor exe; +exe.Run(graph); + +``` diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index df2a7bf90d9be480c514d9dc70571c7f56fd8db2..139411f3e0d945f9265d19a28487c05d06722d69 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -99,7 +99,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 9df7df1f42886d40210b16aa2ae5823e3310bfe7..5d652d37307d0a55ffee14930ae180dcd3e27841 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) - -cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) - 
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index b7b67916205689753bc3f9fe844945ee3e78eeb4..5ca2ed8f96244a11925dfa6af8e48458cf334ecd 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -34,30 +34,22 @@ namespace paddle { namespace framework { namespace details { +static const char kLossVarName[] = "loss_var_name"; +static const char kPlaces[] = "places"; +static const char kParams[] = "params"; +static const char kLocalScopes[] = "local_scopes"; +static const char kStrategy[] = "strategy"; + +void MultiDevSSAGraphBuilder::Init() const { + loss_var_name_ = Get(kLossVarName); + places_ = Get>(kPlaces); + local_scopes_ = Get>(kLocalScopes); + strategy_ = Get(kStrategy); #ifdef PADDLE_WITH_CUDA -MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( - const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy) - : loss_var_name_(loss_var_name), - places_(places), - local_scopes_(local_scopes), - nccl_ctxs_(nccl_ctxs), - strategy_(strategy) { -#else -MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( - const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, const BuildStrategy &strategy) - : loss_var_name_(loss_var_name), - places_(places), - local_scopes_(local_scopes), - strategy_(strategy) { + nccl_ctxs_ = &Get("nccl_ctxs"); #endif - for (auto &p : params) { + + for (auto &p : Get>(kParams)) { grad_names_.insert(GradVarName(p)); } 
balance_vars_.resize(places_.size(), 0); @@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ir::Node *node, size_t place_id) const { auto p = places_[place_id]; - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); @@ -239,8 +231,9 @@ std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { return sorted_ret; } -std::unique_ptr MultiDevSSAGraphBuilder::Apply( +std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( std::unique_ptr graph) const { + Init(); // Give the topology sort order and rebuild the graph structure. std::vector sorted_ops = SortOpsAndDelayOptimizeOp(*graph); auto nodes = graph->ReleaseNodes(); @@ -254,9 +247,10 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 - result.Set("vars", new GraphVars(places_.size())); - result.Set("dep_vars", new GraphDepVars); - result.Set("ops", new GraphOps); + result.Set(kGraphVars, new GraphVars(places_.size())); + result.Set(kGraphDepVars, new GraphDepVars); + result.Set(kGraphOps, new GraphOps); + result.Set(kShardedVarDevice, new ShardedVarDevice); // find send/recv vars so that we can place the distributed training // related op in the place 0 @@ -289,11 +283,12 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( // the block. is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(node); + int op_dev_id = GetOpDeviceID(result, node); if (op_dev_id != -1) { // This op only runs on one specific device. 
CreateComputationalOp(&result, node, op_dev_id); for (ir::Node *n : node->outputs) { - var_name_on_devices_.emplace(n->Name(), op_dev_id); + graph->Get(kShardedVarDevice) + .emplace(n->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's @@ -330,7 +325,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( case BuildStrategy::ReduceStrategy::kReduce: cur_device_id = GetAppropriateDeviceID({g_name}); CreateReduceOp(&result, g_name, cur_device_id); - var_name_on_devices_.emplace(g_name, cur_device_id); + graph->Get(kShardedVarDevice) + .emplace(g_name, cur_device_id); bcast_var_name_set[cur_device_id].emplace(p_name); break; case BuildStrategy::ReduceStrategy::kAllReduce: @@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), local_scopes_, places_); #endif - result->Get("ops").emplace_back(op_handle); + result->Get(kGraphOps).emplace_back(op_handle); auto *in = - result->Get("vars").at(src_dev_id).at(p_name).back().get(); + result->Get(kGraphVars).at(src_dev_id).at(p_name).back().get(); op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->Get("vars").at(i).at(p_name); + auto &vars = result->Get(kGraphVars).at(i).at(p_name); auto *out_var = new VarHandle( result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), i, p_name, p); @@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ir::Node *node, int dev_id) const { - result->Get("ops").emplace_back( + result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, node, dev_id); @@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph 
*result, void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new AllReduceOpHandle( + result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back(new AllReduceOpHandle( + result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), local_scopes_, places_)); #endif - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->Get("vars")[i][og]; + auto &vars = result->Get(kGraphVars)[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); @@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ir::Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), local_scopes_, places_)); #endif - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); for (const std::string &d_name : datas) { - auto &vars = result->Get("vars")[i][d_name]; + auto &vars = 
result->Get(kGraphVars)[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); auto var = new VarHandle( @@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( return is_pg_once; } -int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { +int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, + ir::Node *node) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } @@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); - int dev_id = GetVarDeviceID(param_grad[1]); + int dev_id = GetVarDeviceID(graph, param_grad[1]); PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", node->Op()->Type(), param_grad[0], param_grad[1]); return dev_id; } -int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { - auto got = var_name_on_devices_.find(varname); - return got == var_name_on_devices_.end() ? -1 : got->second; +int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, + const std::string &varname) const { + auto &sharded_var_device = graph.Get(kShardedVarDevice); + auto got = sharded_var_device.find(varname); + return got == sharded_var_device.end() ? -1 : got->second; } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { @@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); - result->Get("ops").emplace_back(op_handle); + result->Get(kGraphOps).emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. 
@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get("ops").emplace_back( + result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); CreateOpHandleIOs(result, node, scope_idx); } @@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back(new ReduceOpHandle( + result->Get(kGraphOps).emplace_back(new ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back(new ReduceOpHandle( + result->Get(kGraphOps).emplace_back(new ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), local_scopes_, places_)); #endif - auto *op_handle = result->Get("ops").back().get(); + auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->Get("vars")[i][og]; + auto &vars = result->Get(kGraphVars)[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); } - auto &vars = result->Get("vars")[dst_dev_id][og]; + auto &vars = result->Get(kGraphVars)[dst_dev_id][og]; auto var = new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), vars.size(), dst_dev_id, og, places_[dst_dev_id]); @@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, // on it. 
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { - for (auto &prev_op : result->Get("ops")) { + for (auto &prev_op : result->Get(kGraphOps)) { if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(result->CreateControlDepVar()); prev_op->AddOutput(dep_var); - result->Get("dep_vars").emplace(dep_var); + result->Get(kGraphDepVars).emplace(dep_var); op->AddInput(dep_var); } } @@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, if (node->Op()->Type() == "split_byref" || node->Op()->Type() == "split_selected_rows") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(input_var_names[0]); + op_dev_id = GetVarDeviceID(*result, input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } else if (node->Op()->Type() == "concat") { - op_dev_id = GetVarDeviceID(input_var_names[0]); + op_dev_id = GetVarDeviceID(*result, input_var_names[0]); for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } else { PADDLE_ENFORCE( @@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, CreateComputationalOp(result, node, op_dev_id); if (node->Op()->Type() == "concat") { - ConnectOp(result, result->Get("ops").back().get(), + ConnectOp(result, result->Get(kGraphOps).back().get(), "fetch_barrier"); } } @@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, int op_dev_id = -1; 
if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); + op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name()); PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by @@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, } op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } } else if (node->Op()->Type() == "recv") { @@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, } op_dev_id = GetAppropriateDeviceID(output_var_names); for (auto &varname : output_var_names) { - var_name_on_devices_.emplace(varname, op_dev_id); + result->Get(kShardedVarDevice) + .emplace(varname, op_dev_id); } } else { // send_barrier and fetch_barrier op can be scheduled on device 0 @@ -711,18 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", node->Op()->Type()); - result->Get("ops").emplace_back(new RPCOpHandle( + result->Get(kGraphOps).emplace_back(new RPCOpHandle( result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], node->Op()->Type(), places_[op_dev_id])); // TODO(panyx0718): This might not be needed anymore. 
if (node->Op()->Type() == "send_barrier") { - ConnectOp(result, result->Get("ops").back().get(), "send"); + ConnectOp(result, result->Get(kGraphOps).back().get(), "send"); } else if (node->Op()->Type() == "recv") { - ConnectOp(result, result->Get("ops").back().get(), + ConnectOp(result, result->Get(kGraphOps).back().get(), "send_barrier"); } else if (node->Op()->Type() == "fetch_barrier") { - ConnectOp(result, result->Get("ops").back().get(), "recv"); + ConnectOp(result, result->Get(kGraphOps).back().get(), "recv"); } else if (node->Op()->Type() == "send") { // do nothing } else { @@ -744,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_pass, + paddle::framework::details::MultiDevSSAGraphBuilder) + .RequirePassAttr(paddle::framework::details::kLossVarName) + .RequirePassAttr(paddle::framework::details::kPlaces) + .RequirePassAttr(paddle::framework::details::kParams) + .RequirePassAttr(paddle::framework::details::kLocalScopes) + .RequirePassAttr(paddle::framework::details::kStrategy); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 55076f227b5ab56d66b5053173c9e915da23b15f..099dbe5abef6458c4613c9f680440734f59cb6e2 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -31,39 +31,27 @@ class Scope; namespace details { class MultiDevSSAGraphBuilder : public SSAGraphBuilder { - public: -#ifdef PADDLE_WITH_CUDA - MultiDevSSAGraphBuilder(const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - platform::NCCLContextMap *nccl_ctxs, - const BuildStrategy &strategy); -#else - MultiDevSSAGraphBuilder(const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, 
- const std::vector &local_scopes, - const BuildStrategy &strategy); -#endif - std::unique_ptr Apply( + protected: + std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; - int GetVarDeviceID(const std::string &varname) const override; private: void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, size_t device_id) const; + void Init() const; private: - std::string loss_var_name_; - const std::vector &places_; - const std::vector &local_scopes_; - std::unordered_set grad_names_; + mutable std::string loss_var_name_; + mutable std::vector places_; + mutable std::vector local_scopes_; + mutable std::unordered_set grad_names_; #ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap *nccl_ctxs_; + mutable platform::NCCLContextMap *nccl_ctxs_; #endif + int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const; + bool IsScaleLossOp(ir::Node *node) const; void CreateRPCOp(ir::Graph *result, ir::Node *node) const; @@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::string &og, std::unordered_set *og_has_been_broadcast) const; - int GetOpDeviceID(ir::Node *node) const; + int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; @@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector &var_names) const; private: - BuildStrategy strategy_; + mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; - mutable std::unordered_map var_name_on_devices_; mutable std::vector balance_vars_; void SetCommunicationContext(OpHandleBase *op_handle, diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index cbfbcb1c0cd24f16773f9633310166371600790c..1b188aec5995edb73835bcf5b851952db0f95f48 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ 
b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ExecutionStrategy strategy, std::vector local_scopes, std::vector var_infos, std::vector places, std::unique_ptr&& underlying_executor); + + const ir::Graph& Graph() const { return underlying_executor_->Graph(); } + FeedFetchList Run(const std::vector& fetch_tensors) override; private: diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 506e7eb35cd977869424223cb863dd64dbaa9d30..575532540a624afde5f6dab25b11e9eac93c6448 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -18,7 +18,7 @@ namespace paddle { namespace framework { namespace details { void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { - for (auto &var_map : graph->Get("vars")) { + for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { if (name_pair.second.size() <= 1) { continue; @@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); - graph->Get("dep_vars").emplace(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); } } } @@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { - auto &var_holders = graph->Get("vars")[place_offset]; + auto &var_holders = graph->Get(kGraphVars)[place_offset]; auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { @@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, ir::Node *new_node, const platform::Place 
&place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][new_node->Name()]; + auto &vars = + graph->Get(kGraphVars)[place_offset][new_node->Name()]; size_t version = vars.size(); auto var = new VarHandle(new_node, version, place_offset, new_node->Name(), place); @@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, } void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { - for (auto &op : graph->Get("ops")) { + for (auto &op : graph->Get(kGraphOps)) { if (!op->Outputs().empty()) { continue; } auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get("dep_vars").emplace(dummy_leaf); + graph->Get(kGraphDepVars).emplace(dummy_leaf); op->AddOutput(dummy_leaf); } } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 2b4f31f2ff3444f909e3be5eb810ae6737e237b2..53a4ad003d51a27a044d7a142434545eca0d5965 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -39,21 +39,25 @@ namespace details { typedef std::vector< std::unordered_map>>> GraphVars; +const char kGraphVars[] = "vars"; // aux variables to represent dependency. Useful to resolve data hazard. typedef std::unordered_set> GraphDepVars; +const char kGraphDepVars[] = "dep_vars"; // all operators. NOTE that even we use a vector here, the operators is // unordered. 
typedef std::vector> GraphOps; +const char kGraphOps[] = "ops"; + +typedef std::unordered_map ShardedVarDevice; +const char kShardedVarDevice[] = "sharded_var_device"; class SSAGraphBuilder : public ir::Pass { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual int GetVarDeviceID(const std::string &var_name) const = 0; - DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc deleted file mode 100644 index b4b49d3de6da2e5fd7836668619e42d10bb6b35a..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" -#include -#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" -#include "paddle/fluid/framework/details/ssa_graph_checker.h" -#include "paddle/fluid/framework/details/ssa_graph_printer.h" - -namespace paddle { -namespace framework { -namespace details { -std::unique_ptr SSAGraphBuilderFactory::Create() { - std::unique_ptr res( -#ifdef PADDLE_WITH_CUDA - new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, - local_scopes_, nccl_ctxs_, strategy_) -#else - new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, - local_scopes_, strategy_) -#endif - ); // NOLINT - - if (!strategy_.debug_graphviz_path_.empty()) { - std::unique_ptr fout( - new std::ofstream(strategy_.debug_graphviz_path_)); - PADDLE_ENFORCE(fout->good()); - std::unique_ptr graphviz_printer( - new GraphvizSSAGraphPrinter()); - res.reset(new SSAGraghBuilderWithPrinter( - std::move(fout), std::move(graphviz_printer), std::move(res))); - } - res.reset(new SSAGraghBuilderWithChecker(std::move(res))); - - return res; -} -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.h b/paddle/fluid/framework/details/ssa_graph_builder_factory.h deleted file mode 100644 index 91a119de83ed3d1573803e48faf86c874eed98d6..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include "paddle/fluid/framework/details/build_strategy.h" -#include "paddle/fluid/framework/details/ssa_graph_builder.h" -#include "paddle/fluid/platform/place.h" - -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/nccl_helper.h" -#endif - -namespace paddle { -namespace framework { -class Scope; -namespace details { - -class SSAGraphBuilderFactory { - public: - SSAGraphBuilderFactory(const std::vector& places, - const std::string& loss_var_name, - const std::unordered_set& param_names, - const std::vector& local_scopes, - const BuildStrategy& strategy) - : places_(places), - loss_var_name_(loss_var_name), - param_names_(param_names), - local_scopes_(local_scopes), - strategy_(strategy) { -#ifdef PADDLE_WITH_CUDA - nccl_ctxs_ = nullptr; -#endif - } - -#ifdef PADDLE_WITH_CUDA - void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) { - nccl_ctxs_ = nccl_ctxs; - } -#endif - - std::unique_ptr Create(); - - private: - std::vector places_; - std::string loss_var_name_; - std::unordered_set param_names_; - std::vector local_scopes_; - BuildStrategy strategy_; - -#ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap* nccl_ctxs_; -#endif -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index 0438b096109a287366610d06ef2bd14c765a8f43..b9e1cda1f24810009bc74a7abdf0156f723a1755 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ 
b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } }; - for (auto &var_map : graph->Get("vars")) { + for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { insert_pending_var(version_pair.get()); @@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } } - for (auto &var : graph->Get("dep_vars")) { + for (auto &var : graph->Get(kGraphDepVars)) { insert_pending_var(var.get()); } - for (auto &op : graph->Get("ops")) { + for (auto &op : graph->Get(kGraphOps)) { if (op->Inputs().empty()) { ready_ops.insert(op.get()); } else { @@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_check_pass, + paddle::framework::details::SSAGraghBuilderWithChecker) + .RequireGraphAttr(paddle::framework::details::kGraphVars) + .RequireGraphAttr(paddle::framework::details::kGraphDepVars) + .RequireGraphAttr(paddle::framework::details::kGraphOps) + .RequireGraphAttr(paddle::framework::details::kShardedVarDevice); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 51ce6e5ecad755613551aa6525b5cfbe4a8933ef..0e861ecb236361992d9883e3bd0e679f7563b539 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -23,26 +23,14 @@ namespace framework { namespace details { class SSAGraghBuilderWithChecker : public SSAGraphBuilder { - public: - explicit SSAGraghBuilderWithChecker( - std::unique_ptr&& builder) - : builder_(std::move(builder)) {} - - std::unique_ptr Apply( + protected: + std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { - auto new_graph = builder_->Apply(std::move(graph)); - 
PADDLE_ENFORCE(IsValidGraph(new_graph.get())); - return new_graph; - } - - int GetVarDeviceID(const std::string& var_name) const override { - return builder_->GetVarDeviceID(var_name); + PADDLE_ENFORCE(IsValidGraph(graph.get())); + return graph; } bool IsValidGraph(const ir::Graph* graph) const; - - private: - std::unique_ptr builder_; }; } // namespace details diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 8815ec89b23bc874471eefde5fa855cd2a4bde1f..96fffb7d9430cd00b3823ada9fbe9a65a6bd718c 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -32,7 +32,9 @@ class SSAGraphExecutor { virtual ~SSAGraphExecutor(); - virtual FeedFetchList Run(const std::vector &fetch_tensors) = 0; + virtual const ir::Graph& Graph() const = 0; + + virtual FeedFetchList Run(const std::vector& fetch_tensors) = 0; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 20aab1464400aa9bb1bd6af11c06269c688a8308..ec3f31ab8d135efd2c77018e90cec46b25ca5e66 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -22,7 +22,7 @@ namespace details { template static inline void IterAllVar(const ir::Graph &graph, Callback callback) { - for (auto &each : graph.Get("vars")) { + for (auto &each : graph.Get(kGraphVars)) { for (auto &pair1 : each) { for (auto &pair2 : pair1.second) { callback(*pair2); @@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) { } } - for (auto &var : graph.Get("dep_vars")) { + for (auto &var : graph.Get(kGraphDepVars)) { callback(*var); } } @@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, }); size_t op_id = 0; - for (auto &op : graph.Get("ops")) { + for (auto &op : 
graph.Get(kGraphOps)) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; @@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, } // namespace details } // namespace framework } // namespace paddle + +REGISTER_PASS(multi_device_print_pass, + paddle::framework::details::SSAGraghBuilderWithPrinter); diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index a77c1bad3f15bca9064ded860696eb68b033b090..5eafd1805c3102dbd3cdfa68ee1495631c182b51 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include #include #include "paddle/fluid/framework/details/ssa_graph_builder.h" @@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { }; class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { - public: - SSAGraghBuilderWithPrinter(std::ostream& sout, - std::unique_ptr&& printer, - std::unique_ptr&& builder) - : printer_(std::move(printer)), - builder_(std::move(builder)), - stream_ref_(sout) {} - - SSAGraghBuilderWithPrinter(std::unique_ptr&& sout, - std::unique_ptr&& printer, - std::unique_ptr&& builder) - : printer_(std::move(printer)), - builder_(std::move(builder)), - stream_ptr_(std::move(sout)), - stream_ref_(*stream_ptr_) {} - - std::unique_ptr Apply( + protected: + std::unique_ptr ApplyImpl( std::unique_ptr graph) const override { - auto new_graph = builder_->Apply(std::move(graph)); - printer_->Print(*new_graph, stream_ref_); - return new_graph; + std::unique_ptr fout( + new std::ofstream(Get("debug_graphviz_path"))); + PADDLE_ENFORCE(fout->good()); + Get("graph_printer").Print(*graph, *fout); + return graph; } - - int GetVarDeviceID(const std::string& var_name) const override { - return builder_->GetVarDeviceID(var_name); - } - - private: - std::unique_ptr 
printer_; - std::unique_ptr builder_; - std::unique_ptr stream_ptr_; - std::ostream& stream_ref_; }; } // namespace details diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index c19f74476f9a1498a7d61f5faf204e9966aea155..eec405073377b2782d7636c08e6eb3a7bd41202d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( std::unordered_set delayed_ops; // Transform SSAGraph to pending_ops & pending_vars - for (auto &var_map : graph_->Get("vars")) { + for (auto &var_map : graph_->Get(details::kGraphVars)) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); } } } - for (auto &var : graph_->Get("dep_vars")) { + for (auto &var : graph_->Get(details::kGraphDepVars)) { InsertPendingVar(&pending_vars, &ready_vars, var.get()); } - for (auto &op : graph_->Get("ops")) { + for (auto &op : graph_->Get(details::kGraphOps)) { if (op->Inputs().empty()) { // Special case, Op has no input. 
ready_ops.insert(op.get()); } else { @@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> fetched_vars; for (auto &fetch_var_name : fetch_tensors) { - for (auto &var_map : graph_->Get("vars")) { + for (auto &var_map : graph_->Get(details::kGraphVars)) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 3d67daa45e20fdea52689684397ad01f2f4cd783..82d6b5272aba161bb19067ebef054bc4bbb8701c 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { const std::vector &places, std::unique_ptr &&graph); + const ir::Graph &Graph() const { return *graph_; } // Run a SSAGraph by a thread pool // Use topological sort algorithm FeedFetchList Run(const std::vector &fetch_tensors) override; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6447452ae58344273fe569c91168c7c95a901c8d..bf7d76a8a6e173e648cea5aaba9b7202d787173b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -1,6 +1,9 @@ cc_library(node SRCS node.cc DEPS proto_desc) cc_library(graph SRCS graph.cc DEPS node) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) -cc_library(pass SRCS pass.cc DEPS graph node) -cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry) -cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry) +cc_library(pass SRCS pass.cc DEPS graph node graph_helper) +cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) + +cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) +cc_test(graph_test SRCS graph_test.cc DEPS 
graph graph_helper op_registry) +cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 4f59ec82a7d1217621c95d9a4a433a9af43e95da..c9d55fbf525a1a476ac469e8e57462169a7db2da 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -40,14 +40,21 @@ class Graph { attr_dels_.clear(); } + bool Has(const std::string &attr_name) const { + return attrs_.find(attr_name) != attrs_.end(); + } + template AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.", + attr_name); return *boost::any_cast(attrs_.at(attr_name)); } template void Set(const std::string &attr_name, AttrType *attr) { - PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the graph", + attr_name); attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(3) << "deleting " << attr_name; diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cb812d1388bf74d173a4dc7a99561e730f8e95a --- /dev/null +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +#include "paddle/fluid/framework/ir/graph_viz_pass.h" + +namespace paddle { +namespace framework { +namespace ir { +static const char kGraphVizPath[] = "graph_viz_path"; + +std::unique_ptr GraphVizPass::ApplyImpl( + std::unique_ptr graph) const { + const std::string graph_viz_path = Get(kGraphVizPath); + std::unique_ptr fout(new std::ofstream(graph_viz_path)); + PADDLE_ENFORCE(fout->good()); + std::ostream& sout = *fout; + + size_t var_id = 0; + std::unordered_map vars; + + sout << "digraph G {\n"; + + for (const ir::Node* n : graph->Nodes()) { + if (n->NodeType() != ir::Node::Type::kVariable) continue; + size_t cur_var_id = var_id++; + vars[n] = cur_var_id; + + sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]" + << std::endl; + } + + size_t op_id = 0; + for (const ir::Node* n : graph->Nodes()) { + if (n->NodeType() != ir::Node::Type::kOperation) continue; + std::string op_name = "op_" + std::to_string(op_id++); + sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]" + << std::endl; + for (auto in : n->inputs) { + std::string var_name = "var_" + std::to_string(vars[in]); + sout << var_name << " -> " << op_name << std::endl; + } + + for (auto out : n->outputs) { + std::string var_name = "var_" + std::to_string(vars[out]); + sout << op_name << " -> " << var_name << std::endl; + } + } + + sout << "}\n"; + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass) + .RequirePassAttr(paddle::framework::ir::kGraphVizPath); diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..1fd8c8a26e9581ccf605d4271a49ec2e90d8b997 --- /dev/null +++ b/paddle/fluid/framework/ir/graph_viz_pass.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class GraphVizPass : public Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index c05d7d0bb54c8ba5938e08f7e8dace8f607d7b89..d7158eba62686be57499df697466797e4034ea8f 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -13,7 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { -namespace framework {} // namespace framework +namespace framework { +namespace ir { +std::unique_ptr Pass::Apply(std::unique_ptr graph) const { + PADDLE_ENFORCE(!applied_, "Pass can only Apply() once."); + PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty."); + for (const std::string& attr : required_pass_attrs_) { + PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(), + "Required pass atrribute %s not set.", attr); + } + for (const std::string& attr : required_graph_attrs_) { + PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.", + attr); + } + auto applied_graph = ApplyImpl(std::move(graph)); + // TODO(panyx0718): Add more verifications. + PADDLE_ENFORCE(!HasCircle(*applied_graph), + "Illegal Pass. Generated graph shouldn't has cycle."); + applied_ = true; + return applied_graph; +} + +PassRegistry& PassRegistry::Instance() { + static PassRegistry g_pass_info_map; + return g_pass_info_map; +} +} // namespace ir +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index f52ba788d55ddb9ed27baa3f6ff0a97e52370fe0..0f14083d259172f5b5f1ed80c7d38312d711beb5 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -14,21 +14,187 @@ limitations under the License. 
*/ #pragma once +#include +#include +#include + #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { namespace ir { +template +struct PassRegistrar; class Pass { public: Pass() = default; - virtual ~Pass() {} + virtual ~Pass() { + for (auto &attr : attrs_) { + if (attr_dels_.find(attr.first) != attr_dels_.end()) { + attr_dels_[attr.first](); + } + } + attrs_.clear(); + attr_dels_.clear(); + } + + std::unique_ptr Apply(std::unique_ptr graph) const; + + // Get a reference to the attributed previously set. + template + AttrType &Get(const std::string &attr_name) const { + PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), + "%s attr not registered for pass.", attr_name); + return *boost::any_cast(attrs_.at(attr_name)); + } + + // Set a pointer to the attribute. Pass takes ownership of the attribute. + template + void Set(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass", + attr_name); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; + } + + // Set a pointer to the attribute. Pass doesn't take ownership. Caller + // should delete the attribute. 
+ template + void SetNotOwned(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + attrs_[attr_name] = attr; + } + + protected: + virtual std::unique_ptr ApplyImpl( + std::unique_ptr graph) const = 0; + + private: + template + friend struct PassRegistrar; + + void RegisterRequiredPassAttrs(const std::unordered_set &attrs) { + required_pass_attrs_.insert(attrs.begin(), attrs.end()); + } + + void RegisterRequiredGraphAttrs( + const std::unordered_set &attrs) { + required_graph_attrs_.insert(attrs.begin(), attrs.end()); + } + + mutable bool applied_{false}; + std::unordered_set required_pass_attrs_; + std::unordered_set required_graph_attrs_; + std::map attrs_; + std::map> attr_dels_; +}; + +using PassCreator = std::function()>; + +class Registrar { + public: + // In our design, various kinds of passes, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_PASS macros to + // call this method. So, as long as the callee code calls USE_PASS, the global + // registrar variable won't be removed by the linker. 
+ void Touch() {} +}; - virtual std::unique_ptr Apply(std::unique_ptr graph) const = 0; +class PassRegistry { + public: + static PassRegistry &Instance(); + + bool Has(const std::string &pass_type) const { + return map_.find(pass_type) != map_.end(); + } + + void Insert(const std::string &pass_type, const PassCreator &pass_creator) { + PADDLE_ENFORCE(!Has(pass_type), "Pass %s has been registered", pass_type); + map_.insert({pass_type, pass_creator}); + } + + std::unique_ptr Get(const std::string &pass_type) const { + PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered", + pass_type); + return map_.at(pass_type)(); + } + + private: + PassRegistry() = default; + std::unordered_map map_; + + DISABLE_COPY_AND_ASSIGN(PassRegistry); }; + +template +struct PassRegistrar : public Registrar { + explicit PassRegistrar(const char *pass_type) { + PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), + "'%s' is registered more than once.", pass_type); + PassRegistry::Instance().Insert( + pass_type, [this]() -> std::unique_ptr { + std::unique_ptr pass(new PassType()); + pass->RegisterRequiredPassAttrs(this->required_pass_attrs_); + pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_); + return pass; + }); + } + + PassRegistrar &RequirePassAttr(const std::string &attr) { + required_pass_attrs_.insert(attr); + return *this; + } + + PassRegistrar &RequireGraphAttr(const std::string &attr) { + required_graph_attrs_.insert(attr); + return *this; + } + + private: + std::unordered_set required_pass_attrs_; + std::unordered_set required_graph_attrs_; +}; + +#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +// Register a new pass that can be applied on the IR. 
+#define REGISTER_PASS(pass_type, pass_class) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __reg_pass__##pass_type, \ + "REGISTER_PASS must be called in global namespace"); \ + static ::paddle::framework::ir::PassRegistrar \ + __pass_registrar_##pass_type##__(#pass_type); \ + int TouchPassRegistrar_##pass_type() { \ + __pass_registrar_##pass_type##__.Touch(); \ + return 0; \ + } \ + static ::paddle::framework::ir::PassRegistrar \ + &__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \ + __pass_registrar_##pass_type##__ + +#define USE_PASS(pass_type) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __use_pass_itself_##pass_type, \ + "USE_PASS must be called in global namespace"); \ + extern int TouchPassRegistrar_##pass_type(); \ + static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \ + TouchPassRegistrar_##pass_type() + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b5011412ed39e033a7a65921e9c64ce2d54c638 --- /dev/null +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/ir/pass.h" +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { +void BuildCircleGraph(Graph* g) { + ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation); + ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation); + ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable); + ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable); + + o1->outputs.push_back(v1); + o2->inputs.push_back(v1); + v1->inputs.push_back(o1); + v1->outputs.push_back(o2); + + o2->outputs.push_back(v2); + o1->inputs.push_back(v2); + v2->inputs.push_back(o2); + v2->outputs.push_back(o1); +} + +class TestPass : public Pass { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const { + graph->Set("copy_test_pass_attr", new int); + graph->Set("copy_test_graph_attr", new int); + + int test_pass_attr = this->Get("test_pass_attr"); + graph->Get("copy_test_pass_attr") = test_pass_attr + 1; + + int test_graph_attr = graph->Get("test_graph_attr"); + graph->Get("copy_test_graph_attr") = test_graph_attr + 1; + return graph; + } +}; + +TEST(PassTest, TestPassAttrCheck) { + ProgramDesc prog; + auto pass = PassRegistry::Instance().Get("test_pass"); + std::unique_ptr graph(new Graph(prog)); + std::string exception; + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("test_pass_attr not set") != exception.npos); + + int val = 1; + graph.reset(new Graph(prog)); + pass->SetNotOwned("test_pass_attr", &val); + + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("test_graph_attr not set") != exception.npos); + + graph.reset(new Graph(prog)); + graph->Set("test_graph_attr", new int); + 
graph->Get("test_graph_attr") = 1; + graph = pass->Apply(std::move(graph)); + ASSERT_EQ(graph->Get("copy_test_pass_attr"), 2); + ASSERT_EQ(graph->Get("copy_test_graph_attr"), 2); + + try { + graph = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos); + + pass = PassRegistry::Instance().Get("test_pass"); + pass->SetNotOwned("test_pass_attr", &val); + graph.reset(new Graph(prog)); + BuildCircleGraph(graph.get()); + graph->Set("test_graph_attr", new int); + graph->Get("test_graph_attr") = 2; + try { + auto tmp = pass->Apply(std::move(graph)); + } catch (paddle::platform::EnforceNotMet e) { + exception = std::string(e.what()); + } + ASSERT_TRUE(exception.find("shouldn't has cycle") != exception.npos); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(test_pass, paddle::framework::ir::TestPass) + .RequirePassAttr("test_pass_attr") + .RequireGraphAttr("test_graph_attr"); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 02c836bea194553bb9c4bc5677cc408dd302e9ce..b5f01a9a2b76472063658f1a051a2ee3c65559b7 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -19,19 +19,80 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" #endif #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" -#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h" +#include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include "paddle/fluid/framework/details/ssa_graph_printer.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { namespace framework { +std::unique_ptr ApplyParallelExecutorPass( + const ProgramDesc &main_program, const std::vector &places, + const std::string &loss_var_name, + const std::unordered_set ¶m_names, + const std::vector &local_scopes, const bool use_cuda, +#ifdef PADDLE_WITH_CUDA + const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) { +#else + const BuildStrategy &strategy) { +#endif + // Convert the program to graph. + std::unique_ptr graph(new ir::Graph(main_program)); + + // Apply a graph viz pass to record a graph. + if (!strategy.debug_graphviz_path_.empty()) { + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); + } + + // Convert graph to run on multi-devices. 
+ auto multi_device_pass = + ir::PassRegistry::Instance().Get("multi_device_pass"); + multi_device_pass->SetNotOwned>("places", + &places); + multi_device_pass->SetNotOwned("loss_var_name", + &loss_var_name); + multi_device_pass->SetNotOwned>( + "params", ¶m_names); + multi_device_pass->SetNotOwned>("local_scopes", + &local_scopes); + multi_device_pass->SetNotOwned("strategy", &strategy); + +#ifdef PADDLE_WITH_CUDA + platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; + multi_device_pass->SetNotOwned("nccl_ctxs", nctx); +#endif + graph = multi_device_pass->Apply(std::move(graph)); + + // Apply a graph print pass to record a graph with device info. + if (!strategy.debug_graphviz_path_.empty()) { + auto multi_device_print_pass = + ir::PassRegistry::Instance().Get("multi_device_print_pass"); + multi_device_print_pass->SetNotOwned( + "debug_graphviz_path", &strategy.debug_graphviz_path_); + multi_device_print_pass->Set( + "graph_printer", new details::GraphvizSSAGraphPrinter); + graph = multi_device_print_pass->Apply(std::move(graph)); + } + + // Verify that the graph is correct for multi-device executor. + auto multi_device_check_pass = + ir::PassRegistry::Instance().Get("multi_device_check_pass"); + graph = multi_device_check_pass->Apply(std::move(graph)); + return graph; +} + class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(const std::vector &places) @@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor( var_infos.back().persistable_ = var->Persistable(); } - // Step 3. Convert main_program to SSA form and dependency graph. Also, insert - // ncclOp - details::SSAGraphBuilderFactory builder_factory( - member_->places_, loss_var_name, params, member_->local_scopes_, - build_strategy); - if (member_->use_cuda_) { +// Step 3. Convert main_program to SSA form and dependency graph. 
Also, insert +// ncclOp #ifdef PADDLE_WITH_CUDA - builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get()); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy, + member_->nccl_ctxs_.get()); #else - PADDLE_THROW("Not compiled with CUDA."); + std::unique_ptr graph = ApplyParallelExecutorPass( + main_program, member_->places_, loss_var_name, params, + member_->local_scopes_, member_->use_cuda_, build_strategy); #endif - } - builder_ = builder_factory.Create(); - std::unique_ptr graph(new ir::Graph(main_program)); - graph = builder_->Apply(std::move(graph)); + member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( @@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices( // the initializing bcast, all vars would be bcast from device(0), // otherwise // bcast from the specified device. - bool initializing = builder_.get() == nullptr ? true : false; - + bool initializing = member_->executor_ ? false : true; for (auto &var : vars) { - int var_dev_id = - builder_.get() == nullptr ? 
-1 : builder_->GetVarDeviceID(var); + int var_dev_id = -1; + if (member_->executor_) { + auto &sharded_var_device = + member_->executor_->Graph().Get( + details::kShardedVarDevice); + if (sharded_var_device.find(var) != sharded_var_device.end()) { + var_dev_id = sharded_var_device.at(var); + } + } + if (!initializing && var_dev_id == -1) continue; framework::Variable *main_var = nullptr; @@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle + +USE_PASS(graph_viz_pass); +USE_PASS(multi_device_pass); +USE_PASS(multi_device_check_pass); +USE_PASS(multi_device_print_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index ffb9934a2d702b2bf6db7ad75a6bf9867e1e9901..d624956acde86cefc4ec1dec80df3738bcf1d8be 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -70,7 +70,6 @@ class ParallelExecutor { private: ParallelExecutorPrivate *member_; - std::unique_ptr builder_; }; } // namespace framework