未验证 提交 1882ffd5 编写于 作者: P pangyoki 提交者: GitHub

transfer block_id to CreateVarNode in multi_devices_graph_pass (#44366)

* fix CreateVarNode in multi_devices_graph_pass

* Revert "Fix var duplication bug for graph_to_program_pass (#44278)"

This reverts commit a2c4c86b.
上级 54c7dfa6
......@@ -579,12 +579,6 @@ void GraphToProgram(const Graph &graph,
VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize()
<< " sub graph";
std::unordered_set<std::string> vars_in_root_block;
for (const proto::VarDesc &var : block->vars()) {
vars_in_root_block.insert(var.name());
}
for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) {
// avoid kRootBlockIndex not 0
if (idx == kRootBlockIndex) continue;
......@@ -592,14 +586,7 @@ void GraphToProgram(const Graph &graph,
block = program_pb.add_blocks();
block->set_idx(idx);
block->set_parent_idx(kRootBlockIndex);
Graph *subgraph = graph.GetSubGraph(idx);
subgraph->SetNotOwned<std::unordered_set<std::string>>(
kGraphToProgramVarsToRemove, &vars_in_root_block);
GraphToBlock(*subgraph, block, sort_kind);
subgraph->Erase(kGraphToProgramVarsToRemove);
GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind);
}
} else {
GraphToBlock(graph, block, sort_kind);
......
......@@ -111,11 +111,12 @@ details::VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph,
details::VarHandle *var = nullptr;
if (var_holder.empty()) {
if (node->Var()) {
var = new details::VarHandle(graph->CreateVarNode(node->Var()),
0,
place_offset,
node->Name(),
place);
var = new details::VarHandle(
graph->CreateVarNode(node->Var(), node->GetVarNodeBlockId()),
0,
place_offset,
node->Name(),
place);
} else {
var = new details::VarHandle(
graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable),
......@@ -376,7 +377,8 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
for (ir::Node *output : node->outputs) {
ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
new_node =
result->CreateVarNode(output->Var(), output->GetVarNodeBlockId());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
......@@ -696,7 +698,8 @@ void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp(
CreateOpOutput(result,
op_handle,
result->CreateVarNode(out_var_node->Var()),
result->CreateVarNode(out_var_node->Var(),
out_var_node->GetVarNodeBlockId()),
places_[i],
i);
}
......@@ -1225,7 +1228,8 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
p = places_[outvar_dev_id];
ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
new_node =
result->CreateVarNode(output->Var(), output->GetVarNodeBlockId());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册