提交 86e20086 编写于 作者: S SunAhong1993

fix the onnx bug

上级 377d9b8d
...@@ -311,7 +311,6 @@ class ONNXGraph(Graph): ...@@ -311,7 +311,6 @@ class ONNXGraph(Graph):
if new_nd_name not in node.which_child: if new_nd_name not in node.which_child:
node.which_child[new_nd_name] = idx node.which_child[new_nd_name] = idx
break break
print(node.which_child)
else: else:
first_i = node.inputs.index(nd.name) first_i = node.inputs.index(nd.name)
node.which_child[nd.name] = idx node.which_child[nd.name] = idx
...@@ -334,13 +333,10 @@ class ONNXGraph(Graph): ...@@ -334,13 +333,10 @@ class ONNXGraph(Graph):
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy) ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
new_ipt_name = "{}/{}".format(ipt_node.layer_name, idx) new_ipt_name = "{}/{}".format(ipt_node.layer_name, idx)
if new_ipt_name in node.which_child: if new_ipt_name in node.which_child:
print(new_ipt_name)
ipt_node.index = node.which_child[new_ipt_name] ipt_node.index = node.which_child[new_ipt_name]
print("ipt_node.index", ipt_node.index)
else: else:
if ipt_node.layer_name in node.which_child: if ipt_node.layer_name in node.which_child:
ipt_node.index = node.which_child[ipt_node.layer_name] ipt_node.index = node.which_child[ipt_node.layer_name]
print("ipt_node.index", ipt_node.index)
return ipt_node return ipt_node
......
...@@ -255,10 +255,16 @@ class OpSet9(): ...@@ -255,10 +255,16 @@ class OpSet9():
if len(node.layer.input) == 2: if len(node.layer.input) == 2:
# opset 10 # opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale_factor'] = val_scales.name # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 3: elif len(node.layer.input) == 3:
# opset 11 # opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True) val_scales = self.graph.get_input_node(node, idx=2, copy=True)
# TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:] attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 4: elif len(node.layer.input) == 4:
# opset 11 # opset 11
...@@ -291,7 +297,7 @@ class OpSet9(): ...@@ -291,7 +297,7 @@ class OpSet9():
return return
elif node.layer_type == 'Upsample': elif node.layer_type == 'Upsample':
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales inputs['scale_factor'] = val_scales
mode = node.get_attr('mode', 'nearest') mode = node.get_attr('mode', 'nearest')
attrs.update({"align_corners": False, attrs.update({"align_corners": False,
...@@ -1097,10 +1103,6 @@ class OpSet9(): ...@@ -1097,10 +1103,6 @@ class OpSet9():
dtypes = set() dtypes = set()
for i in range(len(node.layer.input)): for i in range(len(node.layer.input)):
ipt = self.graph.get_input_node(node, idx=i, copy=True) ipt = self.graph.get_input_node(node, idx=i, copy=True)
try:
print(ipt.index)
except:
pass
inputs_list.append(ipt.name) inputs_list.append(ipt.name)
dtypes.add(ipt.dtype) dtypes.add(ipt.dtype)
if len(dtypes) > 1: if len(dtypes) > 1:
......
...@@ -245,10 +245,16 @@ class OpSet9(): ...@@ -245,10 +245,16 @@ class OpSet9():
if len(node.layer.input) == 2: if len(node.layer.input) == 2:
# opset 10 # opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale_factor'] = val_scales.name # TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 3: elif len(node.layer.input) == 3:
# opset 11 # opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True) val_scales = self.graph.get_input_node(node, idx=2, copy=True)
# TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:] attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 4: elif len(node.layer.input) == 4:
# opset 11 # opset 11
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册