From 0bd7a67eaf67cca487381375b29727f58b938e9d Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 19 Sep 2018 13:11:30 +0800 Subject: [PATCH] avoid creating dangling ir::Node. Node should be created by Graph::CreateXXX so that they are managed by graph. --- .../details/broadcast_op_handle_test.cc | 20 ++++---- .../fast_threaded_ssa_graph_executor.cc | 11 ++-- .../details/gather_op_handle_test.cc | 15 ++++-- .../framework/details/ssa_graph_executor.cc | 13 +++++ .../framework/details/ssa_graph_executor.h | 4 ++ .../details/threaded_ssa_graph_executor.cc | 21 ++++---- .../details/threaded_ssa_graph_executor.h | 1 - paddle/fluid/framework/ir/node.cc | 5 ++ paddle/fluid/framework/ir/node.h | 51 +++++++++++-------- 9 files changed, 85 insertions(+), 56 deletions(-) diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 1413f7bd9ac..ab7412a19fb 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,8 +96,8 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - std::unique_ptr n( - new ir::Node("node0", ir::Node::Type::kOperation)); + std::unique_ptr n = + ir::CreateNodeForTest("node0", ir::Node::Type::kOperation); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, @@ -115,8 +115,8 @@ struct TestBroadcastOpHandle { #endif } - std::unique_ptr v( - new ir::Node("node1", ir::Node::Type::kVariable)); + std::unique_ptr v = + ir::CreateNodeForTest("node1", ir::Node::Type::kVariable); auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -124,8 +124,8 @@ struct TestBroadcastOpHandle { // add dummy var - std::unique_ptr v2( - new ir::Node("node2", ir::Node::Type::kVariable)); + std::unique_ptr v2 = + ir::CreateNodeForTest("node2", 
ir::Node::Type::kVariable); vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); @@ -136,8 +136,8 @@ struct TestBroadcastOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - std::unique_ptr v3( - new ir::Node("node3", ir::Node::Type::kVariable)); + std::unique_ptr v3 = + ir::CreateNodeForTest("node3", ir::Node::Type::kVariable); VarHandle* out_var_handle = new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); @@ -145,8 +145,8 @@ struct TestBroadcastOpHandle { } // add dummy var - std::unique_ptr v4( - new ir::Node("node4", ir::Node::Type::kVariable)); + std::unique_ptr v4 = + ir::CreateNodeForTest("node4", ir::Node::Type::kVariable); vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 7606f2bc06b..6e22fedf1c3 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -54,7 +54,6 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( paddle::framework::FeedFetchList fetches; fetches.resize(fetch_tensors.size()); std::unordered_map> fetched_vars; - std::vector> fetch_nodes; std::vector> fetch_ops; for (auto &fetch_var_name : fetch_tensors) { @@ -75,9 +74,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( auto &vars = fetched_var_it->second; - fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); - auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i, - &local_scopes_); + ir::Node *fetch_node = + graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); + auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_); fetch_ops.emplace_back(op); for (auto &p : places_) 
{ @@ -116,9 +115,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( num_complete += num_comp; } // Wait FetchOps. - if (!fetch_ops.empty()) { - fetch_ops.clear(); - } + ClearFetchOp(graph_.get(), &fetch_ops); return fetches; } void FastThreadedSSAGraphExecutor::RunOpAsync( diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index c9b94d1e103..ed67e88ff6a 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -82,13 +82,15 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation)); + nodes.emplace_back( + ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release()); op_handle_.reset( new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release()); auto* in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); @@ -96,7 +98,8 @@ struct TestGatherOpHandle { } // add dummy var - nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); @@ -104,14 +107,16 @@ struct TestGatherOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release()); auto* out_var_handle = new 
VarHandle(nodes.back().get(), 2, input_scope_idx, "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); diff --git a/paddle/fluid/framework/details/ssa_graph_executor.cc b/paddle/fluid/framework/details/ssa_graph_executor.cc index 09b97bd0d98..780da5478ff 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/ssa_graph_executor.cc @@ -19,6 +19,19 @@ namespace framework { namespace details { SSAGraphExecutor::~SSAGraphExecutor() {} +void ClearFetchOp(ir::Graph* graph, + std::vector>* fetch_ops) { + if (fetch_ops->empty()) return; + + for (auto& op : *fetch_ops) { + for (auto& out_var : op->Node()->outputs) { + graph->RemoveNode(out_var); + } + graph->RemoveNode(op->Node()); + } + fetch_ops->clear(); +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 96fffb7d943..d5cf7737d56 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -18,6 +18,7 @@ #include #include +#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/ir/graph.h" @@ -36,6 +37,9 @@ class SSAGraphExecutor { virtual FeedFetchList Run(const std::vector& fetch_tensors) = 0; }; + +void ClearFetchOp(ir::Graph* graph, + std::vector>* fetch_ops); } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc 
b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index c9e331ef359..31beef3ae82 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -69,12 +69,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector> fetch_ops; - std::vector> tmp_nodes; std::unordered_set> fetch_dependencies; FeedFetchList fetch_data(fetch_tensors.size()); - InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies, - &pending_ops, &pending_vars, &ready_vars, &fetch_data); + InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, + &pending_vars, &ready_vars, &fetch_data); auto run_all_ops = [&](std::unordered_set &set) { for (auto *op : set) { @@ -136,9 +135,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( PADDLE_ENFORCE(ready_ops.empty()); // Wait FetchOps. - if (!fetch_ops.empty()) { - fetch_ops.clear(); - } + ClearFetchOp(graph_.get(), &fetch_ops); return fetch_data; } @@ -146,7 +143,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( void ThreadedSSAGraphExecutor::InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, - std::vector> *temp_nodes, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, @@ -171,9 +167,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &vars = fetched_var_it->second; - temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); - auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, - &local_scopes_); + ir::Node *fetch_node = + graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); + auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_); fetch_ops->emplace_back(op); for (auto &p : places_) { @@ -184,8 +180,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); 
- auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); + ir::Node *fetch_var = + graph_->CreateEmptyNode("fetch", ir::Node::Type::kVariable); + auto *fetch_dummy = new DummyVarHandle(fetch_var); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 9135c1f5d43..512f8a4ca5a 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -73,7 +73,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { void InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, - std::vector> *temp_nodes, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 2817fcf5320..9277abe8c1b 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -19,6 +19,11 @@ namespace framework { namespace ir { constexpr char Node::kControlDepVarName[]; int Node::count_ = 0; + +std::unique_ptr CreateNodeForTest(const std::string& name, + Node::Type type) { + return std::unique_ptr(new Node(name, type)); +} } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d53d789d3ad..82ab1f40f3a 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -24,32 +24,12 @@ namespace paddle { namespace framework { namespace ir { +// Node should normally be created by Graph::CreateXXXNode().
class Node { public: enum class Type { kOperation, kVariable }; static constexpr char kControlDepVarName[] = "__control_var"; - explicit Node(const std::string& name, Type type) - : name_(name), - var_desc_(nullptr), - op_desc_(nullptr), - type_(type), - id_(count_++) {} - - explicit Node(VarDesc* var_desc) - : name_(var_desc->Name()), - var_desc_(new VarDesc(*var_desc)), - op_desc_(nullptr), - type_(Type::kVariable), - id_(count_++) {} - - explicit Node(OpDesc* op_desc) - : name_(op_desc->Type()), - var_desc_(nullptr), - op_desc_(new OpDesc(*op_desc, op_desc->Block())), - type_(Type::kOperation), - id_(count_++) {} - Type NodeType() const { return type_; } std::string Name() const { return name_; } @@ -81,11 +61,40 @@ class Node { private: friend class Graph; + friend std::unique_ptr CreateNodeForTest(const std::string& name, + Node::Type type); + + explicit Node(const std::string& name, Type type) + : name_(name), + var_desc_(nullptr), + op_desc_(nullptr), + type_(type), + id_(count_++) {} + + explicit Node(VarDesc* var_desc) + : name_(var_desc->Name()), + var_desc_(new VarDesc(*var_desc)), + op_desc_(nullptr), + type_(Type::kVariable), + id_(count_++) {} + + explicit Node(OpDesc* op_desc) + : name_(op_desc->Type()), + var_desc_(nullptr), + op_desc_(new OpDesc(*op_desc, op_desc->Block())), + type_(Type::kOperation), + id_(count_++) {} + + Node() = delete; + static int count_; static void ResetId() { count_ = 0; } DISABLE_COPY_AND_ASSIGN(Node); }; +std::unique_ptr CreateNodeForTest(const std::string& name, + Node::Type type); + } // namespace ir } // namespace framework } // namespace paddle -- GitLab