diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index 1413f7bd9ac515ae7dceee62de8f3bc74e3a2efc..ab7412a19fbd13fa39dbae9af528d158cc9ddbd0 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,8 +96,8 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); - std::unique_ptr<ir::Node> n( - new ir::Node("node0", ir::Node::Type::kOperation)); + std::unique_ptr<ir::Node> n = + ir::CreateNodeForTest("node0", ir::Node::Type::kOperation); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, @@ -115,8 +115,8 @@ struct TestBroadcastOpHandle { #endif } - std::unique_ptr<ir::Node> v( - new ir::Node("node1", ir::Node::Type::kVariable)); + std::unique_ptr<ir::Node> v = + ir::CreateNodeForTest("node1", ir::Node::Type::kVariable); auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); @@ -124,8 +124,8 @@ struct TestBroadcastOpHandle { // add dummy var - std::unique_ptr<ir::Node> v2( - new ir::Node("node2", ir::Node::Type::kVariable)); + std::unique_ptr<ir::Node> v2 = + ir::CreateNodeForTest("node2", ir::Node::Type::kVariable); vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast<DummyVarHandle*>(vars_.back().get()); @@ -136,8 +136,8 @@ struct TestBroadcastOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - std::unique_ptr<ir::Node> v3( - new ir::Node("node3", ir::Node::Type::kVariable)); + std::unique_ptr<ir::Node> v3 = + ir::CreateNodeForTest("node3", ir::Node::Type::kVariable); VarHandle* out_var_handle = new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); @@ -145,8 +145,8 @@ struct TestBroadcastOpHandle { } // add dummy var - std::unique_ptr<ir::Node> v4( - new ir::Node("node4", ir::Node::Type::kVariable)); + std::unique_ptr<ir::Node> v4 
= + ir::CreateNodeForTest("node4", ir::Node::Type::kVariable); vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast<DummyVarHandle*>(vars_.back().get()); diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 7606f2bc06b2ecf07c5649eeae1a2d5587a8880c..6e22fedf1c39428528c00cce4c9a4460dfb95cb3 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -54,7 +54,6 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( paddle::framework::FeedFetchList fetches; fetches.resize(fetch_tensors.size()); std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; - std::vector<std::unique_ptr<ir::Node>> fetch_nodes; std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; for (auto &fetch_var_name : fetch_tensors) { @@ -75,9 +74,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( auto &vars = fetched_var_it->second; - fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); - auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i, - &local_scopes_); + ir::Node *fetch_node = + graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); + auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_); fetch_ops.emplace_back(op); for (auto &p : places_) { @@ -116,9 +115,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( num_complete += num_comp; } // Wait FetchOps. 
- if (!fetch_ops.empty()) { - fetch_ops.clear(); - } + ClearFetchOp(graph_.get(), &fetch_ops); return fetches; } void FastThreadedSSAGraphExecutor::RunOpAsync( diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index c9b94d1e1039df6ff27f9ffe225b2a50c35a5c50..ed67e88ff6a7fe9efd93e5dfd4d7bdf4c43aac2e 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -82,13 +82,15 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation)); + nodes.emplace_back( + ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release()); op_handle_.reset( new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release()); auto* in_var_handle = new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); @@ -96,7 +98,8 @@ struct TestGatherOpHandle { } // add dummy var - nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast<DummyVarHandle*>(vars_.back().get()); @@ -104,14 +107,16 @@ struct TestGatherOpHandle { op_handle_->AddInput(in_dummy_var_handle); // add output - nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release()); auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, "out", 
gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable)); + nodes.emplace_back( + ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release()); vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast<DummyVarHandle*>(vars_.back().get()); diff --git a/paddle/fluid/framework/details/ssa_graph_executor.cc b/paddle/fluid/framework/details/ssa_graph_executor.cc index 09b97bd0d98dc4ad1124dcbc495cff921bf03efc..780da5478ff34ecd7096d0ef62b72bf1088dd221 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/ssa_graph_executor.cc @@ -19,6 +19,19 @@ namespace framework { namespace details { SSAGraphExecutor::~SSAGraphExecutor() {} +void ClearFetchOp(ir::Graph* graph, + std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) { + if (fetch_ops->empty()) return; + + for (auto& op : *fetch_ops) { + for (auto& out_var : op->Node()->outputs) { + graph->RemoveNode(out_var); + } + graph->RemoveNode(op->Node()); + } + fetch_ops->clear(); + } + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 96fffb7d9430cd00b3823ada9fbe9a65a6bd718c..d5cf7737d565c523995e6685b73c57e5a6f0197b 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -18,6 +18,7 @@ #include <string> #include <vector> +#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/ir/graph.h" @@ -36,6 +37,9 @@ class SSAGraphExecutor { virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; }; + +void ClearFetchOp(ir::Graph* graph, + std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops); } // namespace details } // namespace framework } // namespace paddle diff --git 
a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index c9e331ef359f853263f8dad38dd0a2be4d9618ad..31beef3ae829d72570ee7c879dac71ed600cd216 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -69,12 +69,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; - std::vector<std::unique_ptr<ir::Node>> tmp_nodes; std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; FeedFetchList fetch_data(fetch_tensors.size()); - InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies, - &pending_ops, &pending_vars, &ready_vars, &fetch_data); + InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, + &pending_vars, &ready_vars, &fetch_data); auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { for (auto *op : set) { @@ -136,9 +135,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( PADDLE_ENFORCE(ready_ops.empty()); // Wait FetchOps. 
- if (!fetch_ops.empty()) { - fetch_ops.clear(); - } + ClearFetchOp(graph_.get(), &fetch_ops); return fetch_data; } @@ -146,7 +143,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( void ThreadedSSAGraphExecutor::InsertFetchOps( const std::vector<std::string> &fetch_tensors, std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, - std::vector<std::unique_ptr<ir::Node>> *temp_nodes, std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_set<VarHandleBase *> *pending_vars, @@ -171,9 +167,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( auto &vars = fetched_var_it->second; - temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); - auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, - &local_scopes_); + ir::Node *fetch_node = + graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); + auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_); fetch_ops->emplace_back(op); for (auto &p : places_) { @@ -184,8 +180,9 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); - auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); + ir::Node *fetch_var = + graph_->CreateEmptyNode("fetch", ir::Node::Type::kVariable); + auto *fetch_dummy = new DummyVarHandle(fetch_var); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 9135c1f5d435d5e2c60eb90c80803361aa31a3c4..512f8a4ca5a9b82a395dde11722b8db44ea5ec27 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -73,7 +73,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { void InsertFetchOps( const std::vector<std::string> &fetch_tensors, std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, - std::vector<std::unique_ptr<ir::Node>> *temp_nodes, std::unordered_set<std::unique_ptr<VarHandleBase>> 
*fetch_dependencies, std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_set<VarHandleBase *> *pending_vars, diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 2817fcf5320f00affdcba097681c7ab20f0eb227..9277abe8c1b79c5f76f4610d0554bf337f329518 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -19,6 +19,11 @@ namespace framework { namespace ir { constexpr char Node::kControlDepVarName[]; int Node::count_ = 0; + +std::unique_ptr<Node> CreateNodeForTest(const std::string& name, + Node::Type type) { + return std::unique_ptr<Node>(new Node(name, type)); +} } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index d53d789d3ad27b8f9606a396264d91e5f07a9d10..82ab1f40f3a0eee3874a9715f3acf8fdb2c05ece 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -24,32 +24,12 @@ namespace paddle { namespace framework { namespace ir { +// Node should normally created by Graph::CreateXXXNode(). 
class Node { public: enum class Type { kOperation, kVariable }; static constexpr char kControlDepVarName[] = "__control_var"; - explicit Node(const std::string& name, Type type) - : name_(name), - var_desc_(nullptr), - op_desc_(nullptr), - type_(type), - id_(count_++) {} - - explicit Node(VarDesc* var_desc) - : name_(var_desc->Name()), - var_desc_(new VarDesc(*var_desc)), - op_desc_(nullptr), - type_(Type::kVariable), - id_(count_++) {} - - explicit Node(OpDesc* op_desc) - : name_(op_desc->Type()), - var_desc_(nullptr), - op_desc_(new OpDesc(*op_desc, op_desc->Block())), - type_(Type::kOperation), - id_(count_++) {} - Type NodeType() const { return type_; } std::string Name() const { return name_; } @@ -81,11 +61,40 @@ class Node { private: friend class Graph; + friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name, + Node::Type type); + + explicit Node(const std::string& name, Type type) + : name_(name), + var_desc_(nullptr), + op_desc_(nullptr), + type_(type), + id_(count_++) {} + + explicit Node(VarDesc* var_desc) + : name_(var_desc->Name()), + var_desc_(new VarDesc(*var_desc)), + op_desc_(nullptr), + type_(Type::kVariable), + id_(count_++) {} + + explicit Node(OpDesc* op_desc) + : name_(op_desc->Type()), + var_desc_(nullptr), + op_desc_(new OpDesc(*op_desc, op_desc->Block())), + type_(Type::kOperation), + id_(count_++) {} + + Node() = delete; + static int count_; static void ResetId() { count_ = 0; } DISABLE_COPY_AND_ASSIGN(Node); }; +std::unique_ptr<Node> CreateNodeForTest(const std::string& name, + Node::Type type); + } // namespace ir } // namespace framework } // namespace paddle