提交 c1ede468 编写于 作者: L liutuo

fix fold deconv and bn

上级 97355004
......@@ -628,14 +628,13 @@ class Transformer(base_converter.ConverterInterface):
framework = ConverterUtil.get_arg(
op, MaceKeyword.mace_framework_type_str).i
input_len = len(op.input)
if consumer_op.type == MaceOp.BatchNorm.name and \
if consumer_op.type == MaceOp.BatchNorm.name and (
(framework == FrameworkType.CAFFE.value and
(input_len == 2 or
(input_len == 3 and
op.input[-1] in self._consts))) or \
(input_len == 2 or (input_len == 3 and
op.input[-1] in self._consts))) or
(framework == FrameworkType.TENSORFLOW.value and
(input_len == 3 or (input_len == 4 and
op.input[-1] in self._consts))):
op.input[-1] in self._consts)))):
print("Fold deconv and bn: %s(%s)" % (op.name, op.type))
filter = self._consts[op.input[1]]
scale = self._consts[consumer_op.input[1]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册