diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index dc69238513976500097196235a6e07315bef10c5..33d4633635528b94a3d8d0ed108398368572a36c 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -628,14 +628,13 @@ class Transformer(base_converter.ConverterInterface): framework = ConverterUtil.get_arg( op, MaceKeyword.mace_framework_type_str).i input_len = len(op.input) - if consumer_op.type == MaceOp.BatchNorm.name and \ + if consumer_op.type == MaceOp.BatchNorm.name and ( (framework == FrameworkType.CAFFE.value and - (input_len == 2 or - (input_len == 3 and - op.input[-1] in self._consts))) or \ + (input_len == 2 or (input_len == 3 and + op.input[-1] in self._consts))) or (framework == FrameworkType.TENSORFLOW.value and (input_len == 3 or (input_len == 4 and - op.input[-1] in self._consts))): + op.input[-1] in self._consts)))): print("Fold deconv and bn: %s(%s)" % (op.name, op.type)) filter = self._consts[op.input[1]] scale = self._consts[consumer_op.input[1]]