diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index a7cb9adbbf631b3dfb877dcbbd74563400326466..77a3318ff9e7bb2024dd081ede3bde63fc8ac8eb 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -41,14 +41,14 @@ std::vector> SeparateMultiDevicesGraph( auto &dev_ops = graphs[dev_id]->Get(kGraphOps); auto &dev_dummys = graphs[dev_id]->Get(kGraphDepVars); dev_ops.emplace_back(op); - graphs[dev_id]->AddNode(graph->ReleaseNode(op->Node()).release()); + graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release()); for (auto &var : op->Inputs()) { auto dummy_ptr = dynamic_cast(var); if (dummy_ptr) { dev_dummys.insert(var); if (graph->Nodes().count(var->Node())) - graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release()); + graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release()); } } for (auto &var : op->Outputs()) { @@ -56,7 +56,7 @@ std::vector> SeparateMultiDevicesGraph( if (dummy_ptr) { dev_dummys.insert(var); if (graph->Nodes().count(var->Node())) - graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release()); + graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release()); } } #else @@ -72,7 +72,7 @@ std::vector> SeparateMultiDevicesGraph( for (auto &version_pair : name_pair.second) { if (graph->Nodes().count(version_pair->Node())) { graphs[dev_id]->AddNode( - graph->ReleaseNode(version_pair->Node()).release()); + graph->RemoveNode(version_pair->Node()).release()); } } } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 40baae2ffdd6f2d31901bf2c44a3e3fb6d8ad329..b55a77451371b2dc3764eedea33126204b5d0997 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -168,7 +168,8 @@ class Graph { return ret; } - std::unique_ptr ReleaseNode(ir::Node *node) { + std::unique_ptr RemoveNode(ir::Node *node) { + PADDLE_ENFORCE(node_set_.find(node) != node_set_.end()); std::unique_ptr ret; ret.reset(nodes_.at(node).release()); nodes_.erase(node);