From c1ede468328b2108d6e2d91534e78218c9fe520c Mon Sep 17 00:00:00 2001 From: liutuo Date: Mon, 28 Jan 2019 17:29:26 +0800 Subject: [PATCH] fix fold deconv and bn --- mace/python/tools/converter_tool/transformer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index cb79eec2..cd8a94ab 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -628,14 +628,13 @@ class Transformer(base_converter.ConverterInterface): framework = ConverterUtil.get_arg( op, MaceKeyword.mace_framework_type_str).i input_len = len(op.input) - if consumer_op.type == MaceOp.BatchNorm.name and \ + if consumer_op.type == MaceOp.BatchNorm.name and ( (framework == FrameworkType.CAFFE.value and - (input_len == 2 or - (input_len == 3 and - op.input[-1] in self._consts))) or \ + (input_len == 2 or (input_len == 3 and + op.input[-1] in self._consts))) or (framework == FrameworkType.TENSORFLOW.value and (input_len == 3 or (input_len == 4 and - op.input[-1] in self._consts))): + op.input[-1] in self._consts)))): print("Fold deconv and bn: %s(%s)" % (op.name, op.type)) filter = self._consts[op.input[1]] scale = self._consts[consumer_op.input[1]] -- GitLab