From ab72d28a5ec3efd4243df8c7cd3370b9354e009f Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Thu, 26 Jul 2018 15:02:51 +0800
Subject: [PATCH] clean up and correctness check

---
 doc/fluid/design/ir/draft.md                        |  10 +-
 .../details/multi_devices_graph_builder.cc          | 103 ++++++++++--------
 .../details/multi_devices_graph_builder.h           |   4 +-
 .../framework/details/ssa_graph_builder.cc          |  13 ++-
 .../framework/details/ssa_graph_builder.h           |   4 +
 .../framework/details/ssa_graph_checker.cc          |  12 +-
 .../framework/details/ssa_graph_checker.h           |   4 +-
 .../framework/details/ssa_graph_printer.cc          |   6 +-
 .../framework/details/ssa_graph_printer.h           |   4 +-
 .../details/threaded_ssa_graph_executor.cc          |   8 +-
 paddle/fluid/framework/ir/graph.h                   |   8 +-
 paddle/fluid/framework/ir/graph_viz_pass.cc         |   8 +-
 paddle/fluid/framework/ir/graph_viz_pass.h          |   4 +-
 paddle/fluid/framework/ir/pass.cc                   |  16 +++
 paddle/fluid/framework/ir/pass.h                    |  70 +++++++++---
 paddle/fluid/framework/parallel_executor.cc         |   2 +-
 16 files changed, 184 insertions(+), 92 deletions(-)

diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md
index 65bfaea6a1d..e141ce09591 100644
--- a/doc/fluid/design/ir/draft.md
+++ b/doc/fluid/design/ir/draft.md
@@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s.
 
 class Pass {
  public:
-  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
+  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
+    // Some correctness check.
+    auto new_graph = ApplyImpl(std::move(graph));
+    // Some correctness check.
+    return new_graph;
+  }
 
   // Get a reference to the attribute previously set.
   template <typename AttrType>
@@ -89,6 +94,9 @@ class Pass {
   // should delete the attribute.
   template <typename AttrType>
   void SetNotOwned(const std::string &attr_name, AttrType *attr);
+
+ protected:
+  virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
 };
 
 // In my_pass.cc
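For concreteness, a pass under this revised design overrides the protected hook instead of Apply itself, and Apply performs the bracketing checks. Below is a minimal sketch against the interface above; the pass name MyPass and the attribute "my_attr" are illustrative, not part of this patch.

// Minimal sketch of a user-defined pass under the new interface.
// MyPass and "my_attr" are hypothetical names used only for illustration.
class MyPass : public Pass {
 protected:
  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override {
    // Pass::Apply has already run its correctness checks before this point.
    const auto &limit = Get<int>("my_attr");
    // ... analyze or rewrite *graph using `limit` here ...
    return graph;
  }
};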
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index ff90f31cdbc..b63c2f695ac 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -34,16 +34,22 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+static const char kLossVarName[] = "loss_var_name";
+static const char kPlaces[] = "places";
+static const char kParams[] = "params";
+static const char kLocalScopes[] = "local_scopes";
+static const char kStrategy[] = "strategy";
+
 void MultiDevSSAGraphBuilder::Init() const {
-  loss_var_name_ = Get<std::string>("loss_var_name");
-  places_ = Get<std::vector<platform::Place>>("places");
-  local_scopes_ = Get<std::vector<Scope *>>("local_scopes");
-  strategy_ = Get<BuildStrategy>("strategy");
+  loss_var_name_ = Get<std::string>(kLossVarName);
+  places_ = Get<std::vector<platform::Place>>(kPlaces);
+  local_scopes_ = Get<std::vector<Scope *>>(kLocalScopes);
+  strategy_ = Get<BuildStrategy>(kStrategy);
 #ifdef PADDLE_WITH_CUDA
   nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
 #endif
 
-  for (auto &p : Get<std::vector<std::string>>("params")) {
+  for (auto &p : Get<std::vector<std::string>>(kParams)) {
     grad_names_.insert(GradVarName(p));
   }
   balance_vars_.resize(places_.size(), 0);
@@ -58,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
                                                 ir::Node *node,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
 
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
@@ -225,7 +231,7 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
   return sorted_ret;
 }
 
-std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
+std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   Init();
   // Give the topology sort order and rebuild the graph structure.
@@ -241,10 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
   std::unordered_set<std::string> og_has_been_broadcast;
 
   // We cannot invoke resize. It is a bug of GCC 4.8
-  result.Set("vars", new GraphVars(places_.size()));
-  result.Set("dep_vars", new GraphDepVars);
-  result.Set("ops", new GraphOps);
-  result.Set("sharded_var_device", new ShardedVarDevice);
+  result.Set(kGraphVars, new GraphVars(places_.size()));
+  result.Set(kGraphDepVars, new GraphDepVars);
+  result.Set(kGraphOps, new GraphOps);
+  result.Set(kShardedVarDevice, new ShardedVarDevice);
 
   // find send/recv vars so that we can place the distributed training
   // related op in the place 0
@@ -281,7 +287,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
     if (op_dev_id != -1) {  // This op only runs on one specific device.
       CreateComputationalOp(&result, node, op_dev_id);
       for (ir::Node *n : node->outputs) {
-        graph->Get<ShardedVarDevice>("sharded_var_device")
+        graph->Get<ShardedVarDevice>(kShardedVarDevice)
             .emplace(n->Name(), op_dev_id);
       }
     } else {
@@ -319,7 +325,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
           case BuildStrategy::ReduceStrategy::kReduce:
             cur_device_id = GetAppropriateDeviceID({g_name});
             CreateReduceOp(&result, g_name, cur_device_id);
-            graph->Get<ShardedVarDevice>("sharded_var_device")
+            graph->Get<ShardedVarDevice>(kShardedVarDevice)
                 .emplace(g_name, cur_device_id);
             bcast_var_name_set[cur_device_id].emplace(p_name);
             break;
@@ -406,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
       result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
       local_scopes_, places_);
 #endif
-  result->Get<GraphOps>("ops").emplace_back(op_handle);
+  result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
   auto *in =
-      result->Get<GraphVars>("vars").at(src_dev_id).at(p_name).back().get();
+      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
   op_handle->AddInput(in);
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>("vars").at(i).at(p_name);
+    auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
     auto *out_var = new VarHandle(
         result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
         i, p_name, p);
@@ -427,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
 void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
                                                     ir::Node *node,
                                                     int dev_id) const {
-  result->Get<GraphOps>("ops").emplace_back(
+  result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
                               local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, node, dev_id);
@@ -436,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
 void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
       result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
      local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
       result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>("vars")[i][og];
+    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
@@ -465,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
     ir::Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
       result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
       local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new DataBalanceOpHandle(
       result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
-      auto &vars = result->Get<GraphVars>("vars")[i][d_name];
+      auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
       auto var = new VarHandle(
@@ -524,7 +530,7 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
 
 int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
                                             const std::string &varname) const {
-  auto &sharded_var_device = graph.Get<ShardedVarDevice>("sharded_var_device");
+  auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
   auto got = sharded_var_device.find(varname);
   return got == sharded_var_device.end() ? -1 : got->second;
 }
@@ -544,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
         result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
         local_scopes_.size(), local_scopes_[i], places_[i],
         communication_dev_ctx);
-    result->Get<GraphOps>("ops").emplace_back(op_handle);
+    result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only uses device_count as scale
     // factor. So it does not depend on any other operators.
@@ -565,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->Get<GraphOps>("ops").emplace_back(
+    result->Get<GraphOps>(kGraphOps).emplace_back(
         new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p));
     CreateOpHandleIOs(result, node, scope_idx);
   }
@@ -575,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
       local_scopes_, places_, nccl_ctxs_));
 #else
-  result->Get<GraphOps>("ops").emplace_back(new ReduceOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>("ops").back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>("vars")[i][og];
+    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
   }
-  auto &vars = result->Get<GraphVars>("vars")[dst_dev_id][og];
+  auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
   auto var =
       new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                     vars.size(), dst_dev_id, og, places_[dst_dev_id]);
@@ -606,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
 // on it.
 void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
-  for (auto &prev_op : result->Get<GraphOps>("ops")) {
+  for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
     if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
       prev_op->AddOutput(dep_var);
-      result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
+      result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
       op->AddInput(dep_var);
     }
   }
@@ -635,18 +641,18 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
-        result->Get<ShardedVarDevice>("sharded_var_device")
+        result->Get<ShardedVarDevice>(kShardedVarDevice)
            .emplace(varname, op_dev_id);
      }
    }
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
    }
   } else if (node->Op()->Type() == "concat") {
     op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
    }
   } else {
@@ -661,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
   CreateComputationalOp(result, node, op_dev_id);
   if (node->Op()->Type() == "concat") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
               "fetch_barrier");
   }
 }
@@ -687,7 +693,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
     }
     op_dev_id = GetAppropriateDeviceID(input_var_names);
     for (auto &varname : input_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
          .emplace(varname, op_dev_id);
    }
  }
@@ -698,7 +704,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
     }
     op_dev_id = GetAppropriateDeviceID(output_var_names);
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>("sharded_var_device")
+      result->Get<ShardedVarDevice>(kShardedVarDevice)
           .emplace(varname, op_dev_id);
     }
   } else {
@@ -709,17 +715,17 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "cannot find the right place for rpc op: %s",
                  node->Op()->Type());
 
-  result->Get<GraphOps>("ops").emplace_back(new RPCOpHandle(
+  result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
       result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
       node->Op()->Type(), places_[op_dev_id]));
 
   if (node->Op()->Type() == "send_barrier") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "send");
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "send");
   } else if (node->Op()->Type() == "recv") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(),
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(),
               "send_barrier");
   } else if (node->Op()->Type() == "fetch_barrier") {
-    ConnectOp(result, result->Get<GraphOps>("ops").back().get(), "recv");
+    ConnectOp(result, result->Get<GraphOps>(kGraphOps).back().get(), "recv");
   } else if (node->Op()->Type() == "send") {
     // do nothing
   } else {
@@ -743,4 +749,9 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
 }  // namespace paddle
 
 REGISTER_PASS(multi_device_pass,
-              paddle::framework::details::MultiDevSSAGraphBuilder);
+              paddle::framework::details::MultiDevSSAGraphBuilder)
+    .RequirePassAttr(paddle::framework::details::kLossVarName)
+    .RequirePassAttr(paddle::framework::details::kPlaces)
+    .RequirePassAttr(paddle::framework::details::kParams)
+    .RequirePassAttr(paddle::framework::details::kLocalScopes)
+    .RequirePassAttr(paddle::framework::details::kStrategy);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index baea091af30..099dbe5abef 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -31,8 +31,8 @@ class Scope;
 namespace details {
 
 class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
- public:
-  std::unique_ptr<ir::Graph> Apply(
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
       std::unique_ptr<ir::Graph> graph) const override;
 
  private:
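How a caller satisfies these declarations is worth spelling out. A hedged sketch follows; the creation call (PassRegistry::Instance().Get) and the local variable names are assumptions for illustration, since this patch itself only shows the registry's Insert and Has methods.

// Sketch: wiring the five attributes required by multi_device_pass.
// PassRegistry::Instance().Get(...) and the local variables are assumed.
std::unique_ptr<ir::Pass> pass =
    ir::PassRegistry::Instance().Get("multi_device_pass");
pass->SetNotOwned<std::string>("loss_var_name", &loss_var_name);
pass->SetNotOwned<std::vector<platform::Place>>("places", &places);
pass->SetNotOwned<std::vector<std::string>>("params", &params);
pass->SetNotOwned<std::vector<Scope *>>("local_scopes", &local_scopes);
pass->SetNotOwned<BuildStrategy>("strategy", &strategy);
// Apply() first enforces every name declared via RequirePassAttr, then
// delegates to MultiDevSSAGraphBuilder::ApplyImpl.
graph = pass->Apply(std::move(graph));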
graph->Get("vars")[place_offset]; + auto &var_holders = graph->Get(kGraphVars)[place_offset]; auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { @@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, ir::Node *new_node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->Get("vars")[place_offset][new_node->Name()]; + auto &vars = + graph->Get(kGraphVars)[place_offset][new_node->Name()]; size_t version = vars.size(); auto var = new VarHandle(new_node, version, place_offset, new_node->Name(), place); @@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle, } void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { - for (auto &op : graph->Get("ops")) { + for (auto &op : graph->Get(kGraphOps)) { if (!op->Outputs().empty()) { continue; } auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get("dep_vars").emplace(dummy_leaf); + graph->Get(kGraphDepVars).emplace(dummy_leaf); op->AddOutput(dummy_leaf); } } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index e0ad0273150..53a4ad003d5 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -39,15 +39,19 @@ namespace details { typedef std::vector< std::unordered_map>>> GraphVars; +const char kGraphVars[] = "vars"; // aux variables to represent dependency. Useful to resolve data hazard. typedef std::unordered_set> GraphDepVars; +const char kGraphDepVars[] = "dep_vars"; // all operators. NOTE that even we use a vector here, the operators is // unordered. typedef std::vector> GraphOps; +const char kGraphOps[] = "ops"; typedef std::unordered_map ShardedVarDevice; +const char kShardedVarDevice[] = "sharded_var_device"; class SSAGraphBuilder : public ir::Pass { public: diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index 2994329f485..b9e1cda1f24 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } }; - for (auto &var_map : graph->Get("vars")) { + for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { insert_pending_var(version_pair.get()); @@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } } - for (auto &var : graph->Get("dep_vars")) { + for (auto &var : graph->Get(kGraphDepVars)) { insert_pending_var(var.get()); } - for (auto &op : graph->Get("ops")) { + for (auto &op : graph->Get(kGraphOps)) { if (op->Inputs().empty()) { ready_ops.insert(op.get()); } else { @@ -87,4 +87,8 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { } // namespace paddle REGISTER_PASS(multi_device_check_pass, - paddle::framework::details::SSAGraghBuilderWithChecker); + paddle::framework::details::SSAGraghBuilderWithChecker) + .RequireGraphAttr(paddle::framework::details::kGraphVars) + .RequireGraphAttr(paddle::framework::details::kGraphDepVars) + .RequireGraphAttr(paddle::framework::details::kGraphOps) + .RequireGraphAttr(paddle::framework::details::kShardedVarDevice); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h 
diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h
index 25891cf74da..0e861ecb236 100644
--- a/paddle/fluid/framework/details/ssa_graph_checker.h
+++ b/paddle/fluid/framework/details/ssa_graph_checker.h
@@ -23,8 +23,8 @@ namespace framework {
 namespace details {
 
 class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
- public:
-  std::unique_ptr<ir::Graph> Apply(
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
       std::unique_ptr<ir::Graph> graph) const override {
     PADDLE_ENFORCE(IsValidGraph(graph.get()));
     return graph;
diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc
index 95d0641d72c..ec3f31ab8d1 100644
--- a/paddle/fluid/framework/details/ssa_graph_printer.cc
+++ b/paddle/fluid/framework/details/ssa_graph_printer.cc
@@ -22,7 +22,7 @@ namespace details {
 
 template <typename Callback>
 static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
-  for (auto &each : graph.Get<GraphVars>("vars")) {
+  for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
     for (auto &pair1 : each) {
       for (auto &pair2 : pair1.second) {
         callback(*pair2);
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
     }
   }
 
-  for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
+  for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
     callback(*var);
   }
 }
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
       });
 
   size_t op_id = 0;
-  for (auto &op : graph.Get<GraphOps>("ops")) {
+  for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
     std::string op_name = "op_" + std::to_string(op_id++);
     sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
          << std::endl;
diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h
index bd4498c0612..5eafd1805c3 100644
--- a/paddle/fluid/framework/details/ssa_graph_printer.h
+++ b/paddle/fluid/framework/details/ssa_graph_printer.h
@@ -36,8 +36,8 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
 };
 
 class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
- public:
-  std::unique_ptr<ir::Graph> Apply(
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
       std::unique_ptr<ir::Graph> graph) const override {
     std::unique_ptr<std::ostream> fout(
         new std::ofstream(Get<std::string>("debug_graphviz_path")));
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index c19f74476f9..eec40507337 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   std::unordered_set<OpHandleBase *> delayed_ops;
 
   // Transform SSAGraph to pending_ops & pending_vars
-  for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
+  for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
     for (auto &name_pair : var_map) {
       for (auto &version_pair : name_pair.second) {
         InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
       }
     }
   }
-  for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
+  for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
     InsertPendingVar(&pending_vars, &ready_vars, var.get());
   }
 
-  for (auto &op : graph_->Get<details::GraphOps>("ops")) {
+  for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
     if (op->Inputs().empty()) {  // Special case, Op has no input.
       ready_ops.insert(op.get());
     } else {
@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
 
   for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
+    for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
         fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index 49f39df4b96..78094e46fb1 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -40,10 +40,14 @@ class Graph {
     attr_dels_.clear();
   }
 
+  bool Has(const std::string &attr_name) const {
+    return attrs_.find(attr_name) != attrs_.end();
+  }
+
   template <typename AttrType>
   AttrType &Get(const std::string &attr_name) const {
-    PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
-                   "%s attr not registered for graph.", attr_name);
+    PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
+                   attr_name);
     return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
   }
 
diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc
index 7d1cff71785..8cb812d1388 100644
--- a/paddle/fluid/framework/ir/graph_viz_pass.cc
+++ b/paddle/fluid/framework/ir/graph_viz_pass.cc
@@ -20,10 +20,11 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
+static const char kGraphVizPath[] = "graph_viz_path";
 
-std::unique_ptr<ir::Graph> GraphVizPass::Apply(
+std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
-  const std::string graph_viz_path = Get<std::string>("graph_viz_path");
+  const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
   std::unique_ptr<std::ofstream> fout(new std::ofstream(graph_viz_path));
   PADDLE_ENFORCE(fout->good());
   std::ostream& sout = *fout;
@@ -67,4 +68,5 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass);
+REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
+    .RequirePassAttr(paddle::framework::ir::kGraphVizPath);
diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h
index 04c0c35d12f..1fd8c8a26e9 100644
--- a/paddle/fluid/framework/ir/graph_viz_pass.h
+++ b/paddle/fluid/framework/ir/graph_viz_pass.h
@@ -28,8 +28,8 @@ namespace framework {
 namespace ir {
 
 class GraphVizPass : public Pass {
- public:
-  std::unique_ptr<ir::Graph> Apply(
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
       std::unique_ptr<ir::Graph> graph) const override;
 };
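The new Graph::Has also lets callers guard optional attributes instead of letting Get fail. A small sketch of the pattern, mirroring the parallel_executor.cc usage further below (the surrounding variables are illustrative):

// Sketch: guarding an optional graph attribute with the new Graph::Has.
if (graph.Has(details::kShardedVarDevice)) {
  auto &map =
      graph.Get<details::ShardedVarDevice>(details::kShardedVarDevice);
  // ... consult the variable-to-device sharding map ...
}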
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index 0e68ecb56fb..2ebc3c7430c 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -17,6 +17,22 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
+std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
+  for (const std::string& attr : required_pass_attrs_) {
+    PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
+                   "Required pass attribute %s not registered.", attr);
+  }
+  for (const std::string& attr : required_graph_attrs_) {
+    PADDLE_ENFORCE(graph->Has(attr),
+                   "Required graph attribute %s does not exist.", attr);
+  }
+  auto applied_graph = ApplyImpl(std::move(graph));
+  // TODO(panyx0718): Add more verifications.
+  PADDLE_ENFORCE(!HasCircle(*applied_graph),
+                 "Illegal pass: the generated graph must not contain a cycle.");
+  return applied_graph;
+}
+
 PassRegistry& PassRegistry::Instance() {
   static PassRegistry g_pass_info_map;
   return g_pass_info_map;
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index f254ef62df5..3f65794fab2 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/variant.h"
@@ -26,6 +27,8 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
+template <typename PassType>
+struct PassRegistrar;
 
 class Pass {
  public:
@@ -40,7 +43,7 @@ class Pass {
     attr_dels_.clear();
   }
 
-  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
+  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
 
   // Get a reference to the attribute previously set.
   template <typename AttrType>
@@ -69,7 +72,25 @@ class Pass {
     attrs_[attr_name] = attr;
   }
 
+ protected:
+  virtual std::unique_ptr<Graph> ApplyImpl(
+      std::unique_ptr<Graph> graph) const = 0;
+
 private:
+  template <typename PassType>
+  friend struct PassRegistrar;
+
+  void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
+    required_pass_attrs_.insert(attrs.begin(), attrs.end());
+  }
+
+  void RegisterRequiredGraphAttrs(
+      const std::unordered_set<std::string> &attrs) {
+    required_graph_attrs_.insert(attrs.begin(), attrs.end());
+  }
+
+  std::unordered_set<std::string> required_pass_attrs_;
+  std::unordered_set<std::string> required_graph_attrs_;
   std::map<std::string, boost::any> attrs_;
   std::map<std::string, std::function<void(void)>> attr_dels_;
 };
@@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar {
   explicit PassRegistrar(const char *pass_type) {
     PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
                    "'%s' is registered more than once.", pass_type);
-    PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> {
-      return std::unique_ptr<Pass>(new PassType());
-    });
+    PassRegistry::Instance().Insert(
+        pass_type, [this]() -> std::unique_ptr<Pass> {
+          std::unique_ptr<Pass> pass(new PassType());
+          pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
+          pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
+          return pass;
+        });
   }
+
+  PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
+    required_pass_attrs_.insert(attr);
+    return *this;
+  }
+
+  PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
+    required_graph_attrs_.insert(attr);
+    return *this;
+  }
+
+ private:
+  std::unordered_set<std::string> required_pass_attrs_;
+  std::unordered_set<std::string> required_graph_attrs_;
 };
 
 #define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg)                   \
@@ -132,16 +171,19 @@ struct PassRegistrar : public Registrar {
                 msg)
 
 // Register a new pass that can be applied on the IR.
-#define REGISTER_PASS(pass_type, pass_class)                \
-  STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(                      \
-      __reg_pass__##pass_type,                              \
-      "REGISTER_PASS must be called in global namespace");  \
-  static ::paddle::framework::ir::PassRegistrar<pass_class> \
-      __pass_registrar_##pass_type##__(#pass_type);         \
-  int TouchPassRegistrar_##pass_type() {                    \
-    __pass_registrar_##pass_type##__.Touch();               \
-    return 0;                                               \
-  }
+#define REGISTER_PASS(pass_type, pass_class)                          \
+  STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(                                \
+      __reg_pass__##pass_type,                                        \
+      "REGISTER_PASS must be called in global namespace");            \
+  static ::paddle::framework::ir::PassRegistrar<pass_class>           \
+      __pass_registrar_##pass_type##__(#pass_type);                   \
+  int TouchPassRegistrar_##pass_type() {                              \
+    __pass_registrar_##pass_type##__.Touch();                         \
+    return 0;                                                         \
+  }                                                                   \
+  static ::paddle::framework::ir::PassRegistrar<pass_class>           \
+      &__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
+          __pass_registrar_##pass_type##__
 
 #define USE_PASS(pass_type)                           \
   STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(                \
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 77bed5c9994..112b48ca31e 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
       if (member_->executor_) {
         auto &sharded_var_device =
             member_->executor_->Graph().Get<details::ShardedVarDevice>(
-                "sharded_var_device");
+                details::kShardedVarDevice);
         if (sharded_var_device.find(var) != sharded_var_device.end()) {
           var_dev_id = sharded_var_device.at(var);
         }
-- 
GitLab
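To make the chained registration above concrete: after this patch, REGISTER_PASS evaluates to a reference to the static registrar, so a trailing .RequirePassAttr(...) call simply becomes part of the unused reference's initializer. Roughly, with abridged hypothetical names and an approximate expansion:

// Approximate expansion of
//   REGISTER_PASS(my_pass, MyPass).RequirePassAttr("my_attr");
static ::paddle::framework::ir::PassRegistrar<MyPass>
    __pass_registrar_my_pass__("my_pass");
int TouchPassRegistrar_my_pass() {
  __pass_registrar_my_pass__.Touch();
  return 0;
}
// RequirePassAttr returns PassRegistrar&, so the chained call records the
// required attribute and still yields a reference to bind here.
static ::paddle::framework::ir::PassRegistrar<MyPass>
    &__pass_tmp_registrar_my_pass__ __attribute__((unused)) =
        __pass_registrar_my_pass__.RequirePassAttr("my_attr");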