From 7b5a9af2e0db141581ae6a1aefff449dbf94e708 Mon Sep 17 00:00:00 2001 From: yejianwu Date: Fri, 31 May 2019 15:32:34 +0800 Subject: [PATCH] fix gpu squeeze --- mace/python/tools/converter_tool/base_converter.py | 1 + mace/python/tools/converter_tool/transformer.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 1eb858f7..4e49c966 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -196,6 +196,7 @@ MaceTransposableDataFormatOps = [MaceOp.Activation, MaceOp.Reduce, MaceOp.Softmax, MaceOp.Split, + MaceOp.Squeeze, MaceOp.SqrDiffMean] diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index e9559861..c1fd9ee0 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -1431,6 +1431,18 @@ class Transformer(base_converter.ConverterInterface): if axis_arg.i == 1: axis_arg.i = 3 + elif op.type == MaceOp.Squeeze.name: + for arg in op.arg: + if arg.name == MaceKeyword.mace_axis_str: + if (src_data_format == DataFormat.NCHW + and has_data_format + and len(self._producer[op.input[0]].output_shape[0].dims) == 4 # noqa + and len(op.output_shape[0].dims) == 2 + and arg.ints == [2, 3]): + print("Transpose squeeze args: %s(%s)" + % (op.name, op.type)) + arg.ints[:] = [1, 2] + elif op.type == MaceOp.Reduce.name: for arg in op.arg: if arg.name == MaceKeyword.mace_axis_str: -- GitLab