提交 fc4d2464 编写于 作者: C Channingss

fix bug

上级 20571503
...@@ -556,12 +556,6 @@ class PaddleOpMapper(object): ...@@ -556,12 +556,6 @@ class PaddleOpMapper(object):
'SizeTensor' in input_names and 'SizeTensor' in input_names and
len(op.input('SizeTensor')) > 0): len(op.input('SizeTensor')) > 0):
node_list = list() node_list = list()
roi_node = self.make_constant_node(
self.get_name(op.type, 'roi'), onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(
roi_name, onnx_pb.TensorProto.FLOAT, [1, 1, 1, 1, 1, 1, 1, 1])
empty_name = self.get_name(op.type, 'empty') empty_name = self.get_name(op.type, 'empty')
empty_tensor = helper.make_tensor( empty_tensor = helper.make_tensor(
empty_name, empty_name,
...@@ -584,16 +578,8 @@ class PaddleOpMapper(object): ...@@ -584,16 +578,8 @@ class PaddleOpMapper(object):
'Slice', 'Slice',
inputs=[shape_name0, starts_name, ends_name], inputs=[shape_name0, starts_name, ends_name],
outputs=[shape_name1]) outputs=[shape_name1])
node_list.extend([ node_list.extend(
roi_node, empty_node, shape_node0, starts_node, ends_node, [empty_node, shape_node0, starts_node, ends_node, shape_node1])
shape_node1
])
# shape_name2 = self.get_name(op.type, "shape.cast")
# shape_node2 = helper.make_node(
# 'Cast',
# inputs=op.input('OutSize'),
# outputs=[shape_name2],
# to=onnx_pb.TensorProto.INT64)
if 'OutSize' in input_names and len(op.input('OutSize')) > 0: if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
cast_shape_name = self.get_name(op.type, "shape.cast") cast_shape_name = self.get_name(op.type, "shape.cast")
cast_shape_node = helper.make_node( cast_shape_node = helper.make_node(
...@@ -603,7 +589,8 @@ class PaddleOpMapper(object): ...@@ -603,7 +589,8 @@ class PaddleOpMapper(object):
to=onnx_pb.TensorProto.INT64) to=onnx_pb.TensorProto.INT64)
node_list.append(cast_shape_node) node_list.append(cast_shape_node)
else: else:
concat_shape_name = op.output('Out')[0] + "@shape.concat1" concat_shape_name = self.get_name(
op.type, op.output('Out')[0] + "shape.concat")
concat_shape_node = helper.make_node( concat_shape_node = helper.make_node(
"Concat", "Concat",
inputs=op.input('SizeTensor'), inputs=op.input('SizeTensor'),
...@@ -616,7 +603,7 @@ class PaddleOpMapper(object): ...@@ -616,7 +603,7 @@ class PaddleOpMapper(object):
outputs=[cast_shape_name], outputs=[cast_shape_name],
to=onnx_pb.TensorProto.INT64) to=onnx_pb.TensorProto.INT64)
node_list.extend([concat_shape_node, cast_shape_node]) node_list.extend([concat_shape_node, cast_shape_node])
shape_name3 = op.output('Out')[0] + "@shape.concat3" shape_name3 = self.get_name(op.type, "shape.concat")
shape_node3 = helper.make_node( shape_node3 = helper.make_node(
'Concat', 'Concat',
inputs=[shape_name1, cast_shape_name], inputs=[shape_name1, cast_shape_name],
...@@ -635,24 +622,24 @@ class PaddleOpMapper(object): ...@@ -635,24 +622,24 @@ class PaddleOpMapper(object):
vals=[height, width], vals=[height, width],
raw=False)) raw=False))
node_list.append(node_h_w) node_list.append(node_h_w)
outputs_h_w_scales = op.output('Out')[0] + "@out_hw_scales"
node_h_w_scales = helper.make_node(
'Div',
inputs=[shape_name3, name_h_w],
outputs=[outputs_h_w_scales])
node_list.append(node_h_w_scales)
shape_name4 = op.output('Out')[0] + "@shape.concat4" shape_name4 = op.output('Out')[0] + "@shape.concat4"
shape_node4 = helper.make_node( shape_node4 = helper.make_node(
'Concat', 'Concat',
inputs=[shape_name1, outputs_h_w_scales], inputs=[shape_name1, name_h_w],
outputs=[shape_name4], outputs=[shape_name4],
axis=0) axis=0)
node_list.append(shape_node4) node_list.append(shape_node4)
outputs_h_w_scales = op.output('Out')[0] + "@out_hw_scales"
node_h_w_scales = helper.make_node(
'Div',
inputs=[shape_name3, shape_name4],
outputs=[outputs_h_w_scales])
node_list.append(node_h_w_scales)
result_node = helper.make_node( result_node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], shape_name4], inputs=[op.input('X')[0], outputs_h_w_scales],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='linear', ) mode='linear')
node_list.extend([result_node]) node_list.extend([result_node])
return node_list return node_list
elif 'Scale' in input_names and len(op.input('Scale')) > 0: elif 'Scale' in input_names and len(op.input('Scale')) > 0:
...@@ -670,18 +657,12 @@ class PaddleOpMapper(object): ...@@ -670,18 +657,12 @@ class PaddleOpMapper(object):
scale_node = self.make_constant_node(scale_name, scale_node = self.make_constant_node(scale_name,
onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale]) [1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], roi_name, scale_name], inputs=[op.input('X')[0], scale_name],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='nearest', mode='nearest')
coordinate_transformation_mode=coordinate_transformation_mode return [scale_node, node]
)
return [scale_node, roi_node, node]
else: else:
raise Exception("Unexpected situation happend") raise Exception("Unexpected situation happend")
return node return node
...@@ -713,18 +694,14 @@ class PaddleOpMapper(object): ...@@ -713,18 +694,14 @@ class PaddleOpMapper(object):
scale_node = self.make_constant_node(scale_name, scale_node = self.make_constant_node(scale_name,
onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale]) [1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], roi_name, scale_name], inputs=[op.input('X')[0], scale_name],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='nearest', mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode coordinate_transformation_mode=coordinate_transformation_mode
) )
return [scale_node, roi_node, node] return [scale_node, node]
else: else:
raise Exception("Unexpected situation happend") raise Exception("Unexpected situation happend")
return node return node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册