diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 1eb858f7263f9bf25030ff183dcd350b387b3896..4e49c96601a2828d3b17b160aa4a4f169746ac2c 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -196,6 +196,7 @@ MaceTransposableDataFormatOps = [MaceOp.Activation, MaceOp.Reduce, MaceOp.Softmax, MaceOp.Split, + MaceOp.Squeeze, MaceOp.SqrDiffMean] diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index e9559861220df330ad55459577b6bbf8ce301e38..c1fd9ee0323daf4be13565855ed8204cf3446ca0 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -1431,6 +1431,18 @@ class Transformer(base_converter.ConverterInterface): if axis_arg.i == 1: axis_arg.i = 3 + elif op.type == MaceOp.Squeeze.name: + for arg in op.arg: + if arg.name == MaceKeyword.mace_axis_str: + if (src_data_format == DataFormat.NCHW + and has_data_format + and len(self._producer[op.input[0]].output_shape[0].dims) == 4 # noqa + and len(op.output_shape[0].dims) == 2 + and arg.ints == [2, 3]): + print("Transpose squeeze args: %s(%s)" + % (op.name, op.type)) + arg.ints[:] = [1, 2] + elif op.type == MaceOp.Reduce.name: for arg in op.arg: if arg.name == MaceKeyword.mace_axis_str: