From 1882ffd5de151c24b41899186832ec98d9878ffc Mon Sep 17 00:00:00 2001 From: pangyoki Date: Wed, 20 Jul 2022 14:18:32 +0800 Subject: [PATCH] transfer block_id to CreateVarNode in multi_devices_graph_pass (#44366) * fix CreateVarNode in multi_devices_graph_pass * Revert "Fix var duplication bug for graph_to_program_pass (#44278)" This reverts commit a2c4c86b1c139ca8242355d673b78e746d189f54. --- paddle/fluid/framework/ir/graph_helper.cc | 15 +------------- .../multi_devices_graph_pass.cc | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index b0a2b6754c..97f486065a 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -579,12 +579,6 @@ void GraphToProgram(const Graph &graph, VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize() << " sub graph"; - - std::unordered_set vars_in_root_block; - for (const proto::VarDesc &var : block->vars()) { - vars_in_root_block.insert(var.name()); - } - for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) { // avoid kRootBlockIndex not 0 if (idx == kRootBlockIndex) continue; @@ -592,14 +586,7 @@ void GraphToProgram(const Graph &graph, block = program_pb.add_blocks(); block->set_idx(idx); block->set_parent_idx(kRootBlockIndex); - - Graph *subgraph = graph.GetSubGraph(idx); - subgraph->SetNotOwned>( - kGraphToProgramVarsToRemove, &vars_in_root_block); - - GraphToBlock(*subgraph, block, sort_kind); - - subgraph->Erase(kGraphToProgramVarsToRemove); + GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind); } } else { GraphToBlock(graph, block, sort_kind); diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc index f1d13c23b1..54657fbcda 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc @@ -111,11 +111,12 @@ details::VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, details::VarHandle *var = nullptr; if (var_holder.empty()) { if (node->Var()) { - var = new details::VarHandle(graph->CreateVarNode(node->Var()), - 0, - place_offset, - node->Name(), - place); + var = new details::VarHandle( + graph->CreateVarNode(node->Var(), node->GetVarNodeBlockId()), + 0, + place_offset, + node->Name(), + place); } else { var = new details::VarHandle( graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), @@ -376,7 +377,8 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result, for (ir::Node *output : node->outputs) { ir::Node *new_node = nullptr; if (output->Var()) { - new_node = result->CreateVarNode(output->Var()); + new_node = + result->CreateVarNode(output->Var(), output->GetVarNodeBlockId()); } else { new_node = result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); @@ -696,7 +698,8 @@ void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp( CreateOpOutput(result, op_handle, - result->CreateVarNode(out_var_node->Var()), + result->CreateVarNode(out_var_node->Var(), + out_var_node->GetVarNodeBlockId()), places_[i], i); } @@ -1225,7 +1228,8 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { p = places_[outvar_dev_id]; ir::Node *new_node = nullptr; if (output->Var()) { - new_node = result->CreateVarNode(output->Var()); + new_node = + result->CreateVarNode(output->Var(), output->GetVarNodeBlockId()); } else { new_node = result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); -- GitLab