提交 f558f281 编写于 作者: C Channingss

fix bug of Resize(opset11)

上级 cafb4c66
...@@ -93,16 +93,13 @@ class OpSet11(OpSet10): ...@@ -93,16 +93,13 @@ class OpSet11(OpSet10):
else: else:
coordinate_transformation_mode = 'half_pixel' coordinate_transformation_mode = 'half_pixel'
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name, onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
if ('OutSize' in input_names and len(op.input('OutSize')) > 0) or ( if ('OutSize' in input_names and len(op.input('OutSize')) > 0) or (
'SizeTensor' in input_names and 'SizeTensor' in input_names and
len(op.input('SizeTensor')) > 0): len(op.input('SizeTensor')) > 0):
node_list = list() node_list = list()
roi_node = self.make_constant_node(
self.get_name(op.type, 'roi'), onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(
roi_name, onnx_pb.TensorProto.FLOAT, [1, 1, 1, 1, 1, 1, 1, 1])
empty_name = self.get_name(op.type, 'empty') empty_name = self.get_name(op.type, 'empty')
empty_tensor = helper.make_tensor( empty_tensor = helper.make_tensor(
empty_name, empty_name,
...@@ -168,7 +165,7 @@ class OpSet11(OpSet10): ...@@ -168,7 +165,7 @@ class OpSet11(OpSet10):
elif 'Scale' in input_names and len(op.input('Scale')) > 0: elif 'Scale' in input_names and len(op.input('Scale')) > 0:
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], op.input('Scale')[0]], inputs=[op.input('X')[0], roi_name, op.input('Scale')[0]],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='linear', mode='linear',
coordinate_transformation_mode=coordinate_transformation_mode) coordinate_transformation_mode=coordinate_transformation_mode)
...@@ -180,10 +177,6 @@ class OpSet11(OpSet10): ...@@ -180,10 +177,6 @@ class OpSet11(OpSet10):
scale_node = self.make_constant_node(scale_name, scale_node = self.make_constant_node(scale_name,
onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale]) [1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], roi_name, scale_name], inputs=[op.input('X')[0], roi_name, scale_name],
...@@ -194,7 +187,7 @@ class OpSet11(OpSet10): ...@@ -194,7 +187,7 @@ class OpSet11(OpSet10):
return [scale_node, roi_node, node] return [scale_node, roi_node, node]
else: else:
raise Exception("Unexpected situation happend") raise Exception("Unexpected situation happend")
return node return [roi_node, node]
def nearest_interp(self, op, block): def nearest_interp(self, op, block):
input_names = op.input_names input_names = op.input_names
...@@ -204,17 +197,20 @@ class OpSet11(OpSet10): ...@@ -204,17 +197,20 @@ class OpSet11(OpSet10):
coordinate_transformation_mode = 'align_corners' coordinate_transformation_mode = 'align_corners'
else: else:
coordinate_transformation_mode = 'asymmetric' coordinate_transformation_mode = 'asymmetric'
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name, onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
if 'OutSize' in input_names and len(op.input('OutSize')) > 0: if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], '', op.input('OutSize')[0]], inputs=[op.input('X')[0], roi_name, op.input('OutSize')[0]],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='nearest', mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode) coordinate_transformation_mode=coordinate_transformation_mode)
elif 'Scale' in input_names and len(op.input('Scale')) > 0: elif 'Scale' in input_names and len(op.input('Scale')) > 0:
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], op.input('Scale')[0]], inputs=[op.input('X')[0], roi_name, op.input('Scale')[0]],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='nearest', mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode) coordinate_transformation_mode=coordinate_transformation_mode)
...@@ -226,10 +222,6 @@ class OpSet11(OpSet10): ...@@ -226,10 +222,6 @@ class OpSet11(OpSet10):
scale_node = self.make_constant_node(scale_name, scale_node = self.make_constant_node(scale_name,
onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale]) [1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT,
[1, 1, 1, 1, 1, 1, 1, 1])
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], roi_name, scale_name], inputs=[op.input('X')[0], roi_name, scale_name],
...@@ -240,7 +232,7 @@ class OpSet11(OpSet10): ...@@ -240,7 +232,7 @@ class OpSet11(OpSet10):
return [scale_node, roi_node, node] return [scale_node, roi_node, node]
else: else:
raise Exception("Unexpected situation happend") raise Exception("Unexpected situation happend")
return node return [roi_node, node]
def hard_swish(self, op, block): def hard_swish(self, op, block):
min_name = self.get_name(op.type, 'min') min_name = self.get_name(op.type, 'min')
......
...@@ -562,6 +562,15 @@ class OpSet9(object): ...@@ -562,6 +562,15 @@ class OpSet9(object):
keepdims=op.attr('keep_dim')) keepdims=op.attr('keep_dim'))
return node return node
def cast(self, op, block):
dtype = op.attr('out_dtype')
node = helper.make_node(
'Cast',
inputs=op.input('X'),
outputs=op.output('Out'),
to=self.paddle_onnx_dtype_map[dtype])
return node
def bilinear_interp(self, op, block): def bilinear_interp(self, op, block):
input_names = op.input_names input_names = op.input_names
input_shape = block.vars[op.input('X')[0]].shape input_shape = block.vars[op.input('X')[0]].shape
...@@ -673,7 +682,7 @@ class OpSet9(object): ...@@ -673,7 +682,7 @@ class OpSet9(object):
input_names = op.input_names input_names = op.input_names
if op.attr('align_corners'): if op.attr('align_corners'):
raise Exception( raise Exception(
"Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opest 11" "Resize in onnx(opset<=10) only support coordinate_transformation_mode: 'asymmetric', Try converting with --onnx_opset 11"
) )
if 'OutSize' in input_names and len(op.input('OutSize')) > 0: if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
node = helper.make_node( node = helper.make_node(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册