提交 3dfc41b8 编写于 作者: L liyin

Disable BN fold if conv weights are shared

上级 b159cb1f
......@@ -623,9 +623,9 @@ class Transformer(base_converter.ConverterInterface):
and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0]
input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \
(input_len == 2 or
(input_len == 3 and op.input[-1] in self._consts)):
if (consumer_op.type == MaceOp.BatchNorm.name
and (input_len == 2 or (input_len == 3 and op.input[-1] in self._consts)) # noqa
and len(self._consumers[op.input[1]]) == 1):
print("Fold conv and bn: %s(%s)" % (op.name, op.type))
filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]]
......@@ -678,13 +678,14 @@ class Transformer(base_converter.ConverterInterface):
framework = ConverterUtil.get_arg(
op, MaceKeyword.mace_framework_type_str).i
input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and (
if (consumer_op.type == MaceOp.BatchNorm.name and (
(framework == FrameworkType.CAFFE.value and
(input_len == 2 or (input_len == 3 and
op.input[-1] in self._consts))) or
(framework == FrameworkType.TENSORFLOW.value and
(input_len == 3 or (input_len == 4 and
op.input[-1] in self._consts)))):
op.input[-1] in self._consts))))
and len(self._consumers[op.input[1]]) == 1):
print("Fold deconv and bn: %s(%s)" % (op.name, op.type))
filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]]
......@@ -745,9 +746,9 @@ class Transformer(base_converter.ConverterInterface):
and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0]
input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \
(input_len == 2 or
(input_len == 3 and op.input[-1] in self._consts)):
if (consumer_op.type == MaceOp.BatchNorm.name
and (input_len == 2 or (input_len == 3 and op.input[-1] in self._consts)) # noqa
and len(self._consumers[op.input[1]]) == 1):
print("Fold depthwise conv and bn: %s(%s)"
% (op.name, op.type))
filter = self._consts[op.input[1]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册