“897789b16e754aa1c1a5131cae08bff35d477508”上不存在“paddle/fluid/lite/kernels/x86/scale_compute.cc”
提交 3dfc41b8 编写于 作者: L liyin

Disable BN fold if conv weights are shared

上级 b159cb1f
...@@ -623,9 +623,9 @@ class Transformer(base_converter.ConverterInterface): ...@@ -623,9 +623,9 @@ class Transformer(base_converter.ConverterInterface):
and self.consumer_count(op.output[0]) == 1: and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0] consumer_op = self._consumers[op.output[0]][0]
input_len = len(op.input) input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \ if (consumer_op.type == MaceOp.BatchNorm.name
(input_len == 2 or and (input_len == 2 or (input_len == 3 and op.input[-1] in self._consts)) # noqa
(input_len == 3 and op.input[-1] in self._consts)): and len(self._consumers[op.input[1]]) == 1):
print("Fold conv and bn: %s(%s)" % (op.name, op.type)) print("Fold conv and bn: %s(%s)" % (op.name, op.type))
filter = self._consts[op.input[1]] filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]] scale = self._consts[consumer_op.input[1]]
...@@ -678,13 +678,14 @@ class Transformer(base_converter.ConverterInterface): ...@@ -678,13 +678,14 @@ class Transformer(base_converter.ConverterInterface):
framework = ConverterUtil.get_arg( framework = ConverterUtil.get_arg(
op, MaceKeyword.mace_framework_type_str).i op, MaceKeyword.mace_framework_type_str).i
input_len = len(op.input) input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and ( if (consumer_op.type == MaceOp.BatchNorm.name and (
(framework == FrameworkType.CAFFE.value and (framework == FrameworkType.CAFFE.value and
(input_len == 2 or (input_len == 3 and (input_len == 2 or (input_len == 3 and
op.input[-1] in self._consts))) or op.input[-1] in self._consts))) or
(framework == FrameworkType.TENSORFLOW.value and (framework == FrameworkType.TENSORFLOW.value and
(input_len == 3 or (input_len == 4 and (input_len == 3 or (input_len == 4 and
op.input[-1] in self._consts)))): op.input[-1] in self._consts))))
and len(self._consumers[op.input[1]]) == 1):
print("Fold deconv and bn: %s(%s)" % (op.name, op.type)) print("Fold deconv and bn: %s(%s)" % (op.name, op.type))
filter = self._consts[op.input[1]] filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]] scale = self._consts[consumer_op.input[1]]
...@@ -745,9 +746,9 @@ class Transformer(base_converter.ConverterInterface): ...@@ -745,9 +746,9 @@ class Transformer(base_converter.ConverterInterface):
and self.consumer_count(op.output[0]) == 1: and self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0] consumer_op = self._consumers[op.output[0]][0]
input_len = len(op.input) input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \ if (consumer_op.type == MaceOp.BatchNorm.name
(input_len == 2 or and (input_len == 2 or (input_len == 3 and op.input[-1] in self._consts)) # noqa
(input_len == 3 and op.input[-1] in self._consts)): and len(self._consumers[op.input[1]]) == 1):
print("Fold depthwise conv and bn: %s(%s)" print("Fold depthwise conv and bn: %s(%s)"
% (op.name, op.type)) % (op.name, op.type))
filter = self._consts[op.input[1]] filter = self._consts[op.input[1]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册