提交 1a795788 编写于 作者: A Alexander Buslaev 提交者: LiuTuo

Add support of upsample operator in onnx convertor (#539)

* add support of upsample operator in onnx convertor

* flake8 fixes

* fix cutting of unnecessary inputs
上级 94ced0b9
...@@ -180,7 +180,7 @@ OnnxSupportedOps = [ ...@@ -180,7 +180,7 @@ OnnxSupportedOps = [
# 'TopK', # 'TopK',
'Transpose', 'Transpose',
'Unsqueeze', 'Unsqueeze',
# 'Upsample', 'Upsample',
# 'Xor', # 'Xor',
] ]
...@@ -392,6 +392,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -392,6 +392,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.TargetRMSNorm: self.convert_target_rms_norm, OnnxOpType.TargetRMSNorm: self.convert_target_rms_norm,
OnnxOpType.Transpose.name: self.convert_transpose, OnnxOpType.Transpose.name: self.convert_transpose,
OnnxOpType.Unsqueeze.name: self.convert_unsqueeze, OnnxOpType.Unsqueeze.name: self.convert_unsqueeze,
OnnxOpType.Upsample.name: self.convert_upsample
} }
self._option = option self._option = option
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
...@@ -454,7 +455,13 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -454,7 +455,13 @@ class OnnxConverter(base_converter.ConverterInterface):
tensor.name = name tensor.name = name
tensor.dims.extend(list(shape)) tensor.dims.extend(list(shape))
tensor.data_type = data_type tensor.data_type = data_type
tensor.float_data.extend(value.flat)
if tensor.data_type == mace_pb2.DT_INT32:
tensor.int32_data.extend(value.astype(np.int32).flat)
elif tensor.data_type == mace_pb2.DT_FLOAT:
tensor.float_data.extend(value.astype(np.float32).flat)
else:
mace_check(False, "Not supported tensor type: %s" % name)
def run(self): def run(self):
graph_def = self._onnx_model.graph graph_def = self._onnx_model.graph
...@@ -1533,3 +1540,26 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -1533,3 +1540,26 @@ class OnnxConverter(base_converter.ConverterInterface):
offset_arg = op.arg.add() offset_arg = op.arg.add()
offset_arg.name = 'offset' offset_arg.name = 'offset'
offset_arg.i = offset offset_arg.i = offset
def convert_upsample(self, node):
op = self.convert_general_op(node)
del op.input[1:] # cut all unnecessary inputs (onnx>=1.5)
output_size = self._graph_shapes_dict[op.output[0]]
output_size = np.array(output_size[-2:]).astype(np.int32)
if node.attrs['mode'] == 'nearest':
op.type = MaceOp.ResizeNearestNeighbor.name
size_tensor_name = op.name + ":size"
self.add_tensor(size_tensor_name, output_size.shape,
mace_pb2.DT_INT32, output_size)
op.input.append(size_tensor_name)
else:
op.type = MaceOp.ResizeBilinear.name
size_arg = op.arg.add()
size_arg.name = MaceKeyword.mace_resize_size_str
size_arg.ints.extend(output_size.tolist())
align_corners_arg = op.arg.add()
align_corners_arg.name = MaceKeyword.mace_align_corners_str
align_corners_arg.i = node.attrs.get('align_corners', 0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册