From ecdd1166b80627b652b948d6b8b317307ce0afb0 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 14 Feb 2019 16:44:09 +0800 Subject: [PATCH] cleanup code test=develop --- .../framework/details/parallel_ssa_graph_executor.cc | 8 ++++---- paddle/fluid/framework/ir/graph.h | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index a7cb9adbb..77a3318ff 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -41,14 +41,14 @@ std::vector> SeparateMultiDevicesGraph( auto &dev_ops = graphs[dev_id]->Get(kGraphOps); auto &dev_dummys = graphs[dev_id]->Get(kGraphDepVars); dev_ops.emplace_back(op); - graphs[dev_id]->AddNode(graph->ReleaseNode(op->Node()).release()); + graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release()); for (auto &var : op->Inputs()) { auto dummy_ptr = dynamic_cast(var); if (dummy_ptr) { dev_dummys.insert(var); if (graph->Nodes().count(var->Node())) - graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release()); + graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release()); } } for (auto &var : op->Outputs()) { @@ -56,7 +56,7 @@ std::vector> SeparateMultiDevicesGraph( if (dummy_ptr) { dev_dummys.insert(var); if (graph->Nodes().count(var->Node())) - graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release()); + graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release()); } } #else @@ -72,7 +72,7 @@ std::vector> SeparateMultiDevicesGraph( for (auto &version_pair : name_pair.second) { if (graph->Nodes().count(version_pair->Node())) { graphs[dev_id]->AddNode( - graph->ReleaseNode(version_pair->Node()).release()); + graph->RemoveNode(version_pair->Node()).release()); } } } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 40baae2ff..b55a77451 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -168,7 +168,8 @@ class Graph { return ret; } - std::unique_ptr ReleaseNode(ir::Node *node) { + std::unique_ptr RemoveNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); std::unique_ptr ret; ret.reset(nodes_.at(node).release()); nodes_.erase(node); -- GitLab