diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index fe5b08d6599b849beb2da0d05a7aee9785006166..1a9b7993e4656dbbe901727db148b7be247b2b46 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -1192,7 +1192,11 @@ class Transformer(base_converter.ConverterInterface): and self._producer[op.input[0]].type \ == MaceOp.Reshape.name \ and len(op.output_shape[0].dims) == 2: - should_fold = True + producer = self._producer[op.input[0]] + reshape_input_rank = len(self.get_tensor_shape( + producer.input[0])) + if reshape_input_rank == 4: + should_fold = True if should_fold: print(