From adf5615e541455a41051ff47b1651ee2870bd8d9 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 6 Nov 2018 11:29:10 +0800 Subject: [PATCH] clean kGraphOp test=develop --- .../details/fast_threaded_ssa_graph_executor.cc | 5 ++--- .../framework/details/multi_devices_graph_check_pass.cc | 7 ++++--- .../fluid/framework/details/multi_devices_graph_pass.cc | 7 ++++++- .../framework/details/multi_devices_graph_print_pass.cc | 3 ++- paddle/fluid/framework/details/multi_devices_helper.h | 5 ----- paddle/fluid/framework/details/reference_count_pass.cc | 3 ++- .../framework/details/threaded_ssa_graph_executor.cc | 3 ++- paddle/fluid/framework/ir/graph.h | 9 +++++++++ paddle/fluid/framework/ir/graph_helper.h | 9 +++++++++ 9 files changed, 36 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 42849853c..403b055c4 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -16,6 +16,7 @@ #include #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { @@ -32,9 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( pool_(strategy.num_threads_ + 1), // add one more thread for generate op_deps fetch_ctxs_(places) { - auto &ops = graph_->Get("ops"); - - for (auto &op : ops) { + for (auto &op : ir::GetFilteredNodes(*graph_)) { int dep = static_cast(op->NotReadyInputSize()); op_deps_.emplace(op, dep); if (dep == 0) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc index 5bfafa829..220aa88f7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc @@ -45,7 +45,9 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { insert_pending_var(var); } - for (auto &op : graph->Get(kGraphOps)) { + for (ir::Node *node : graph->Nodes()) { + if (!node->IsWrappedBy()) continue; + OpHandleBase *op = &node->Wrapper(); if (op->Inputs().empty()) { ready_ops.insert(op); } else { @@ -89,5 +91,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { REGISTER_PASS(multi_devices_check_pass, paddle::framework::details::SSAGraghBuilderWithChecker) .RequireGraphAttr(paddle::framework::details::kGraphVars) - .RequireGraphAttr(paddle::framework::details::kGraphDepVars) - .RequireGraphAttr(paddle::framework::details::kGraphOps); + .RequireGraphAttr(paddle::framework::details::kGraphDepVars); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index e072e09ec..58b7ea0b9 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -36,6 +36,11 @@ namespace framework { namespace details { namespace { +// all operators. NOTE that even we use a vector here, the operators is +// unordered. +typedef std::vector GraphOps; +const char kGraphOps[] = "ops"; + void PolishGraphToSupportDataHazards(ir::Graph *graph) { for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { @@ -458,7 +463,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - PADDLE_ENFORCE(!ir::HasCircle(result)); + result.Erase(kGraphOps); return graph; } diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc index 361c91dc7..ae50905f7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { @@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, }); size_t op_id = 0; - for (auto &op : graph.Get(kGraphOps)) { + for (auto &op : ir::GetFilteredNodes(graph)) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; diff --git a/paddle/fluid/framework/details/multi_devices_helper.h b/paddle/fluid/framework/details/multi_devices_helper.h index bed2fdb86..5a9e06369 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -43,11 +43,6 @@ const char kGraphVars[] = "vars"; // aux variables to represent dependency. Useful to resolve data hazard. typedef std::unordered_set GraphDepVars; const char kGraphDepVars[] = "dep_vars"; - -// all operators. NOTE that even we use a vector here, the operators is -// unordered. -typedef std::vector GraphOps; -const char kGraphOps[] = "ops"; } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 19b943cb9..42b248650 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/reference_count_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { @@ -156,7 +157,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( } }; - auto &all_ops = graph->Get(kGraphOps); + auto all_ops = ir::GetFilteredNodes(*graph); for (auto &op : all_ops) { auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 781116ccb..05c158210 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { @@ -59,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( InsertPendingVar(&pending_vars, ready_vars.get(), var); } - for (auto &op : graph_->Get(details::kGraphOps)) { + for (auto &op : ir::GetFilteredNodes(*graph_)) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op); } else { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 9d7aa5d32..8830638ec 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -102,6 +102,15 @@ class Graph { attr_dels_[attr_name] = []() {}; } + template + void Erase(const std::string &attr_name) { + PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph", + attr_name); + attr_dels_[attr_name](); + attrs_.erase(attr_name); + attr_dels_.erase(attr_name); + } + const std::unordered_set &Nodes() const { return node_set_; } // Create a normal variable with non-null VarDesc. diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index ec46b38c0..a107aaf7f 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -37,6 +37,15 @@ std::vector TopologySortOperations(const Graph &graph); std::map> BuildOperationAdjList( const Graph &graph); +template +std::vector GetFilteredNodes(const Graph &graph) { + std::vector ret; + for (ir::Node *n : graph.Nodes()) { + if (n->IsWrappedBy()) ret.push_back(&n->Wrapper()); + } + return ret; +} + } // namespace ir } // namespace framework } // namespace paddle -- GitLab