diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index bbc94f5cee0ba9dd609f25e224d8740b2adb0d0d..881c5357b4e21545d9a0a759eae839a5a66b05a4 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -1411,24 +1411,37 @@ class Transformer(base_converter.ConverterInterface): return False def is_transposable_data_format_ops(self, op): + transposable = op.type in MaceTransposableDataFormatOps + if op.type == MaceOp.Reshape: input_op = self._producer[op.input[0]] input_dims = input_op.output_shape[0].dims output_dims = op.output_shape[0].dims - tranposable = True if len(input_op.output_shape) != 1 or \ len(input_dims) != 4 or len(output_dims) != 4: - tranposable = False + transposable = False else: in_b, in_h, in_w, in_c = self.sort_feature_map_shape( input_dims, ConverterUtil.data_format(input_op)) ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape( output_dims, ConverterUtil.data_format(op)) - tranposable = (in_b == ou_b and in_c == ou_c) - if not tranposable: - print("In this model, reshape is not transposable op.") - return tranposable - return op.type in MaceTransposableDataFormatOps + transposable = (in_b == ou_b and in_c == ou_c) + elif op.type == MaceOp.Squeeze: + input_dims = self._producer[op.input[0]].output_shape[0].dims + output_dims = op.output_shape[0].dims + src_df = ConverterUtil.data_format(self._model) + arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str) + if len(input_dims) == 4 and len(output_dims) == 2 and \ + ((src_df == DataFormat.NCHW and arg.ints == [2, 3]) or + (src_df == DataFormat.NHWC and arg.ints == [1, 2])): + transposable = True + else: + transposable = False + + if op.type in MaceTransposableDataFormatOps and not transposable: + print("%s(%s) is not a transposable op in this model." + % (op.name, op.type)) + return transposable def update_data_format(self): print("update data format")