未验证 提交 132fc6cf 编写于 作者: S Shichao Zhang 提交者: GitHub

[fix] add scope_name definition (#689)

* [fix] add scope_name definition

* remove unused variables

* Make code style consistent
上级 bb7ca948
...@@ -1241,7 +1241,6 @@ def aten_contiguous(mapper, graph, node): ...@@ -1241,7 +1241,6 @@ def aten_contiguous(mapper, graph, node):
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
...@@ -1560,7 +1559,6 @@ def aten_detach(mapper, graph, node): ...@@ -1560,7 +1559,6 @@ def aten_detach(mapper, graph, node):
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
...@@ -1645,6 +1643,7 @@ def aten_div_(mapper, graph, node): ...@@ -1645,6 +1643,7 @@ def aten_div_(mapper, graph, node):
%bx_bw.3 (-): 被除数。 %bx_bw.3 (-): 被除数。
%2678 (int): 除数。 %2678 (int): 除数。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -2726,7 +2725,6 @@ def aten_index_select(mapper, graph, node): ...@@ -2726,7 +2725,6 @@ def aten_index_select(mapper, graph, node):
mapper._check_input(graph, inputs_node[1], inputs_name[1], mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name) current_outputs, scope_name)
layer_inputs["axis"] = inputs_name[1] layer_inputs["axis"] = inputs_name[1]
current_inputs.append(inputs_name[1])
# 处理输入2,即%371 # 处理输入2,即%371
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name) scope_name)
...@@ -2737,7 +2735,7 @@ def aten_index_select(mapper, graph, node): ...@@ -2737,7 +2735,7 @@ def aten_index_select(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.index_select", "prim.index_select",
inputs=layer_inputs, inputs=layer_inputs,
outputs=current_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3308,7 +3306,6 @@ def aten_masked_fill_(mapper, graph, node): ...@@ -3308,7 +3306,6 @@ def aten_masked_fill_(mapper, graph, node):
scope_name = mapper.normalize_scope_name(node) scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = [] current_inputs = []
...@@ -3418,7 +3415,6 @@ def aten_masked_fill(mapper, graph, node): ...@@ -3418,7 +3415,6 @@ def aten_masked_fill(mapper, graph, node):
scope_name = mapper.normalize_scope_name(node) scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = [] current_inputs = []
...@@ -4553,7 +4549,6 @@ def aten_rsub(mapper, graph, node): ...@@ -4553,7 +4549,6 @@ def aten_rsub(mapper, graph, node):
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
...@@ -4593,7 +4588,6 @@ def aten_ScalarImplicit(mapper, graph, node): ...@@ -4593,7 +4588,6 @@ def aten_ScalarImplicit(mapper, graph, node):
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
...@@ -4651,7 +4645,7 @@ def aten_select(mapper, graph, node): ...@@ -4651,7 +4645,7 @@ def aten_select(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.select", "prim.select",
inputs=layer_inputs, inputs=layer_inputs,
outputs=current_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -5376,7 +5370,6 @@ def aten_transpose(mapper, graph, node): ...@@ -5376,7 +5370,6 @@ def aten_transpose(mapper, graph, node):
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册