diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 42849853c81dd9669833e082fbcf631fedbfbc70..403b055c418af6a8fec1f42c59681db0f45f123c 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -16,6 +16,7 @@ #include #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { @@ -32,9 +33,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( pool_(strategy.num_threads_ + 1), // add one more thread for generate op_deps fetch_ctxs_(places) { - auto &ops = graph_->Get("ops"); - - for (auto &op : ops) { + for (auto &op : ir::GetFilteredNodes(*graph_)) { int dep = static_cast(op->NotReadyInputSize()); op_deps_.emplace(op, dep); if (dep == 0) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc index 5bfafa82918d7de56d87e270b2c74637d456f036..220aa88f7b02e4923e869183fd8549d559b76d02 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_check_pass.cc @@ -45,7 +45,9 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { insert_pending_var(var); } - for (auto &op : graph->Get(kGraphOps)) { + for (ir::Node *node : graph->Nodes()) { + if (!node->IsWrappedBy()) continue; + OpHandleBase *op = &node->Wrapper(); if (op->Inputs().empty()) { ready_ops.insert(op); } else { @@ -89,5 +91,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { REGISTER_PASS(multi_devices_check_pass, paddle::framework::details::SSAGraghBuilderWithChecker) .RequireGraphAttr(paddle::framework::details::kGraphVars) - .RequireGraphAttr(paddle::framework::details::kGraphDepVars) - .RequireGraphAttr(paddle::framework::details::kGraphOps); + .RequireGraphAttr(paddle::framework::details::kGraphDepVars); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index e072e09ece627843519be82b9a599a43e72f3fa0..58b7ea0b9e9ec24cb6310d39a791f2e134b4b39a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -36,6 +36,11 @@ namespace framework { namespace details { namespace { +// all operators. NOTE that even we use a vector here, the operators is +// unordered. +typedef std::vector GraphOps; +const char kGraphOps[] = "ops"; + void PolishGraphToSupportDataHazards(ir::Graph *graph) { for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { @@ -458,7 +463,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - PADDLE_ENFORCE(!ir::HasCircle(result)); + result.Erase(kGraphOps); return graph; } diff --git a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc index 361c91dc78c08a2cbf84ee88211d389c1e2312e5..ae50905f7614375f4fd65d61628070e4dc98b5fe 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_print_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include #include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { @@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, }); size_t op_id = 0; - for (auto &op : graph.Get(kGraphOps)) { + for (auto &op : ir::GetFilteredNodes(graph)) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; diff --git a/paddle/fluid/framework/details/multi_devices_helper.h b/paddle/fluid/framework/details/multi_devices_helper.h index bed2fdb864c4c6a92239ac295c74fc70ba3499f2..5a9e06369d9b408d144d91e35417f048e404eacb 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -43,11 +43,6 @@ const char kGraphVars[] = "vars"; // aux variables to represent dependency. Useful to resolve data hazard. typedef std::unordered_set GraphDepVars; const char kGraphDepVars[] = "dep_vars"; - -// all operators. NOTE that even we use a vector here, the operators is -// unordered. -typedef std::vector GraphOps; -const char kGraphOps[] = "ops"; } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 19b943cb9c2a8056482b52ecd2388d090a9fb4e0..42b248650ea03406e8634e5e3ae56640e6ad8903 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/reference_count_pass.h" +#include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { @@ -156,7 +157,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( } }; - auto &all_ops = graph->Get(kGraphOps); + auto all_ops = ir::GetFilteredNodes(*graph); for (auto &op : all_ops) { auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 781116ccba49fa37d11a9f2cb8be771a07ad1316..05c158210cb43b9b31d8b5e5a32c31eaaf4736a4 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { @@ -59,7 +60,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( InsertPendingVar(&pending_vars, ready_vars.get(), var); } - for (auto &op : graph_->Get(details::kGraphOps)) { + for (auto &op : ir::GetFilteredNodes(*graph_)) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op); } else { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 9d7aa5d32deb274fbf29481b0d4754c05d1e21b5..8830638ec8b70c3fcaaa83c2c3c819e2cc8ab795 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -102,6 +102,15 @@ class Graph { attr_dels_[attr_name] = []() {}; } + template + void Erase(const std::string &attr_name) { + PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph", + attr_name); + attr_dels_[attr_name](); + attrs_.erase(attr_name); + attr_dels_.erase(attr_name); + } + const std::unordered_set &Nodes() const { return node_set_; } // Create a normal variable with non-null VarDesc. diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index ec46b38c01b8c369ab37b4fbd5497ec120d8db91..a107aaf7f578dd7f3f6cd08b7e1462c91fb02fb9 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -37,6 +37,15 @@ std::vector TopologySortOperations(const Graph &graph); std::map> BuildOperationAdjList( const Graph &graph); +template +std::vector GetFilteredNodes(const Graph &graph) { + std::vector ret; + for (ir::Node *n : graph.Nodes()) { + if (n->IsWrappedBy()) ret.push_back(&n->Wrapper()); + } + return ret; +} + } // namespace ir } // namespace framework } // namespace paddle