From 3dfc41b83429681cadc607e77e3340faedaff76c Mon Sep 17 00:00:00 2001 From: liyin Date: Mon, 8 Jul 2019 16:11:49 +0800 Subject: [PATCH] Disable BN fold if conv weights are shared --- mace/python/tools/converter_tool/transformer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index f4d461ed..41f32c6d 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -623,9 +623,9 @@ class Transformer(base_converter.ConverterInterface): and self.consumer_count(op.output[0]) == 1: consumer_op = self._consumers[op.output[0]][0] input_len = len(op.input) - if consumer_op.type == MaceOp.BatchNorm.name and \ - (input_len == 2 or - (input_len == 3 and op.input[-1] in self._consts)): + if (consumer_op.type == MaceOp.BatchNorm.name + and (input_len == 2 or (input_len == 3 and op.input[-1] in self._consts)) # noqa + and len(self._consumers[op.input[1]]) == 1): print("Fold conv and bn: %s(%s)" % (op.name, op.type)) filter = self._consts[op.input[1]] scale = self._consts[consumer_op.input[1]] @@ -678,13 +678,14 @@ class Transformer(base_converter.ConverterInterface): framework = ConverterUtil.get_arg( op, MaceKeyword.mace_framework_type_str).i input_len = len(op.input) - if consumer_op.type == MaceOp.BatchNorm.name and ( + if (consumer_op.type == MaceOp.BatchNorm.name and ( (framework == FrameworkType.CAFFE.value and (input_len == 2 or (input_len == 3 and op.input[-1] in self._consts))) or (framework == FrameworkType.TENSORFLOW.value and (input_len == 3 or (input_len == 4 and - op.input[-1] in self._consts)))): + op.input[-1] in self._consts)))) + and len(self._consumers[op.input[1]]) == 1): print("Fold deconv and bn: %s(%s)" % (op.name, op.type)) filter = self._consts[op.input[1]] scale = self._consts[consumer_op.input[1]] @@ -745,9 +746,9 @@ class Transformer(base_converter.ConverterInterface): and self.consumer_count(op.output[0]) == 1: consumer_op = self._consumers[op.output[0]][0] input_len = len(op.input) - if consumer_op.type == MaceOp.BatchNorm.name and \ - (input_len == 2 or - (input_len == 3 and op.input[-1] in self._consts)): + if (consumer_op.type == MaceOp.BatchNorm.name + and (input_len == 2 or (input_len == 3 and op.input[-1] in self._consts)) # noqa + and len(self._consumers[op.input[1]]) == 1): print("Fold depthwise conv and bn: %s(%s)" % (op.name, op.type)) filter = self._consts[op.input[1]] -- GitLab