提交 ecdd1166 编写于 作者: Y Yancey1989

cleanup code test=develop

上级 73005ee0
...@@ -41,14 +41,14 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph( ...@@ -41,14 +41,14 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps); auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars); auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
dev_ops.emplace_back(op); dev_ops.emplace_back(op);
graphs[dev_id]->AddNode(graph->ReleaseNode(op->Node()).release()); graphs[dev_id]->AddNode(graph->RemoveNode(op->Node()).release());
for (auto &var : op->Inputs()) { for (auto &var : op->Inputs()) {
auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var); auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
if (dummy_ptr) { if (dummy_ptr) {
dev_dummys.insert(var); dev_dummys.insert(var);
if (graph->Nodes().count(var->Node())) if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release()); graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
} }
} }
for (auto &var : op->Outputs()) { for (auto &var : op->Outputs()) {
...@@ -56,7 +56,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph( ...@@ -56,7 +56,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
if (dummy_ptr) { if (dummy_ptr) {
dev_dummys.insert(var); dev_dummys.insert(var);
if (graph->Nodes().count(var->Node())) if (graph->Nodes().count(var->Node()))
graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release()); graphs[dev_id]->AddNode(graph->RemoveNode(var->Node()).release());
} }
} }
#else #else
...@@ -72,7 +72,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph( ...@@ -72,7 +72,7 @@ std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
if (graph->Nodes().count(version_pair->Node())) { if (graph->Nodes().count(version_pair->Node())) {
graphs[dev_id]->AddNode( graphs[dev_id]->AddNode(
graph->ReleaseNode(version_pair->Node()).release()); graph->RemoveNode(version_pair->Node()).release());
} }
} }
} }
......
...@@ -168,7 +168,8 @@ class Graph { ...@@ -168,7 +168,8 @@ class Graph {
return ret; return ret;
} }
std::unique_ptr<ir::Node> ReleaseNode(ir::Node *node) { std::unique_ptr<ir::Node> RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
std::unique_ptr<ir::Node> ret; std::unique_ptr<ir::Node> ret;
ret.reset(nodes_.at(node).release()); ret.reset(nodes_.at(node).release());
nodes_.erase(node); nodes_.erase(node);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册