From 19089a52b542ed9cd28390d34138dc1372da0dd6 Mon Sep 17 00:00:00 2001 From: Bin Li Date: Wed, 15 Jul 2020 15:54:07 +0800 Subject: [PATCH] Fix ONNX output shape bug --- mace/core/memory_optimizer.cc | 9 +++-- tools/python/transform/onnx_converter.py | 4 ++- tools/python/transform/transformer.py | 45 ++++++++++++++---------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/mace/core/memory_optimizer.cc b/mace/core/memory_optimizer.cc index a6aacc2c..38948b5c 100644 --- a/mace/core/memory_optimizer.cc +++ b/mace/core/memory_optimizer.cc @@ -114,9 +114,12 @@ void MemoryOptimizer::Optimize( const mace::OperatorDef *op_def, const std::unordered_map *mem_types) { MACE_LATENCY_LOGGER(2, "Optimize memory"); - MACE_CHECK(op_def->output_size() == op_def->output_shape_size(), - op_def->name(), "The number of output shapes is" - " not equal to the number of outputs"); + if (op_def->output_size() != op_def->output_shape_size()) { + VLOG(1) << op_def->name() + << ": the number of output shape " + << "is not equal to the number of output"; + return; + } auto device = static_cast(op_def->device_type()); DataType op_dtype = static_cast(ProtoArgHelper::GetOptionalArg( diff --git a/tools/python/transform/onnx_converter.py b/tools/python/transform/onnx_converter.py index 8f0b485b..e4217a53 100644 --- a/tools/python/transform/onnx_converter.py +++ b/tools/python/transform/onnx_converter.py @@ -38,7 +38,7 @@ import numpy as np import onnx import onnx.utils -from onnx import mapping, numpy_helper, TensorProto +from onnx import mapping, numpy_helper, shape_inference, TensorProto from numbers import Number IS_PYTHON3 = sys.version_info > (3,) @@ -420,6 +420,8 @@ class OnnxConverter(base_converter.ConverterInterface): onnx.checker.check_model(onnx_model) + onnx_model = shape_inference.infer_shapes(onnx_model) + self._isKaldi = False polish_available = True diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 4a27465b..c35973a0 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -358,7 +358,7 @@ class Transformer(base_converter.ConverterInterface): self.safe_remove_node(op, self._producer.get(op.input[0], None)) return True - elif op.type == 'Reshape' and \ + elif op.type == 'Reshape' and len(op.output_shape) == 1 and \ op.output_shape[0].dims == \ self.get_tensor_shape(op.input[0]): print("Remove useless reshape: %s(%s)" % (op.name, op.type)) @@ -1417,28 +1417,35 @@ class Transformer(base_converter.ConverterInterface): if op.type == MaceOp.Reshape: input_op = self._producer[op.input[0]] - input_dims = input_op.output_shape[0].dims - output_dims = op.output_shape[0].dims - if len(input_op.output_shape) != 1 or \ - len(input_dims) != 4 or len(output_dims) != 4: + if len(input_op.output_shape) == 0 or len(op.output_shape) == 0: transposable = False else: - in_b, in_h, in_w, in_c = self.sort_feature_map_shape( - input_dims, ConverterUtil.data_format(input_op)) - ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape( - output_dims, ConverterUtil.data_format(op)) - transposable = (in_b == ou_b and in_c == ou_c) + input_dims = input_op.output_shape[0].dims + output_dims = op.output_shape[0].dims + if len(input_op.output_shape) != 1 or \ + len(input_dims) != 4 or len(output_dims) != 4: + transposable = False + else: + in_b, in_h, in_w, in_c = self.sort_feature_map_shape( + input_dims, ConverterUtil.data_format(input_op)) + ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape( + output_dims, ConverterUtil.data_format(op)) + transposable = (in_b == ou_b and in_c == ou_c) elif op.type == MaceOp.Squeeze: - input_dims = self._producer[op.input[0]].output_shape[0].dims - output_dims = op.output_shape[0].dims - src_df = ConverterUtil.data_format(self._model) - arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str) - if len(input_dims) == 4 and len(output_dims) == 2 and \ - ((src_df == DataFormat.NCHW and arg.ints == [2, 3]) or - (src_df == DataFormat.NHWC and arg.ints == [1, 2])): - transposable = True - else: + input_op = self._producer[op.input[0]] + if len(input_op.output_shape) == 0 or len(op.output_shape) == 0: transposable = False + else: + input_dims = input_op.output_shape[0].dims + output_dims = op.output_shape[0].dims + src_df = ConverterUtil.data_format(self._model) + arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str) + if len(input_dims) == 4 and len(output_dims) == 2 and \ + ((src_df == DataFormat.NCHW and arg.ints == [2, 3]) or + (src_df == DataFormat.NHWC and arg.ints == [1, 2])): + transposable = True + else: + transposable = False if op.type in MaceTransposableDataFormatOps and not transposable: print("%s(%s) is not a transposable op in this model." -- GitLab