diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index cb79eec2bd81dd85f1a0233a654e07bc8a87b243..cd8a94ab3db165204f1acd13b12701fc33877b0b 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -628,14 +628,13 @@ class Transformer(base_converter.ConverterInterface): framework = ConverterUtil.get_arg( op, MaceKeyword.mace_framework_type_str).i input_len = len(op.input) - if consumer_op.type == MaceOp.BatchNorm.name and \ + if consumer_op.type == MaceOp.BatchNorm.name and ( (framework == FrameworkType.CAFFE.value and - (input_len == 2 or - (input_len == 3 and - op.input[-1] in self._consts))) or \ + (input_len == 2 or (input_len == 3 and + op.input[-1] in self._consts))) or (framework == FrameworkType.TENSORFLOW.value and (input_len == 3 or (input_len == 4 and - op.input[-1] in self._consts))): + op.input[-1] in self._consts)))): print("Fold deconv and bn: %s(%s)" % (op.name, op.type)) filter = self._consts[op.input[1]] scale = self._consts[consumer_op.input[1]]