提交 aa4cd525 编写于 作者: S SunAhong1993

fix

上级 842ea89f
...@@ -205,7 +205,7 @@ class ONNXGraph(Graph): ...@@ -205,7 +205,7 @@ class ONNXGraph(Graph):
shape = raw_input( shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: " "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
) )
except: except NameError:
shape = input( shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: " "Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
) )
...@@ -302,7 +302,18 @@ class ONNXGraph(Graph): ...@@ -302,7 +302,18 @@ class ONNXGraph(Graph):
if opt == in_node: if opt == in_node:
self.connect(nd.name, layer_name) self.connect(nd.name, layer_name)
flag = 1 flag = 1
node.which_child[nd.name] = idx if nd.name in node.which_child:
for n_i, n_ipt in enumerate(node.inputs):
if first_i == n_i:
continue
if n_ipt == nd.name:
new_nd_name = "{}/{}".format(nd.name, n_i)
if new_nd_name not in node.which_child:
node.which_child[new_nd_name] = idx
break
else:
first_i = node.inputs.index(nd.name)
node.which_child[nd.name] = idx
self.node_map[nd.name].index = 0 self.node_map[nd.name].index = 0
break break
if flag == 1: if flag == 1:
...@@ -318,11 +329,15 @@ class ONNXGraph(Graph): ...@@ -318,11 +329,15 @@ class ONNXGraph(Graph):
if len(node.which_child) == 0: if len(node.which_child) == 0:
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy) ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
return ipt_node return ipt_node
else: else:
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy) ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
if ipt_node.layer_name in node.which_child: new_ipt_name = "{}/{}".format(ipt_node.layer_name, idx)
ipt_node.index = node.which_child[ipt_node.layer_name] if new_ipt_name in node.which_child:
ipt_node.index = node.which_child[new_ipt_name]
else:
if ipt_node.layer_name in node.which_child:
ipt_node.index = node.which_child[ipt_node.layer_name]
return ipt_node return ipt_node
...@@ -556,4 +571,4 @@ class ONNXDecoder(object): ...@@ -556,4 +571,4 @@ class ONNXDecoder(object):
node.input[i] = self.make_variable_name(node.input[i]) node.input[i] = self.make_variable_name(node.input[i])
for i in range(len(node.output)): for i in range(len(node.output)):
node.output[i] = self.make_variable_name(node.output[i]) node.output[i] = self.make_variable_name(node.output[i])
return model return model
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册