From 25f1e6d57672c7ef5df8924d56e0b239f4b2170f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Fri, 25 May 2018 10:40:27 +0800 Subject: [PATCH] Fix pad and conv dilation issues --- .../tools/converter_tool/tensorflow_converter.py | 7 ++++++- mace/python/tools/converter_tool/transformer.py | 11 +++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 8f05e61c..5c806f41 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -294,7 +294,12 @@ class TensorflowConverter(base_converter.ConverterInterface): if op.type != MaceOp.Deconv2D.name: dilation_arg = op.arg.add() dilation_arg.name = MaceKeyword.mace_dilations_str - dilation_arg.ints.extend(tf_op.get_attr(tf_dilations_str)[1:3]) + try: + dilation_val = tf_op.get_attr(tf_dilations_str)[1:3] + except ValueError: + dilation_val = [1, 1] + + dilation_arg.ints.extend(dilation_val) def convert_elementwise(self, tf_op): op = self.convert_general_op(tf_op) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 0fa5ddd9..e9952ee8 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -713,18 +713,21 @@ class Transformer(base_converter.ConverterInterface): # transpose args if op.type == MaceOp.Pad.name: for arg in op.arg: - if arg.name == MaceKeyword.mace_paddings_str and len( - arg.ints) == 4: + if arg.name == MaceKeyword.mace_paddings_str: + mace_check(len(arg.ints) == 8, + "pad dim rank should be 8.") if ConverterUtil.data_format(op) == DataFormat.NHWC \ and self._target_data_format == DataFormat.NCHW: # noqa print("Transpose pad args: %s(%s)" % (op.name, op.type)) - self.transpose_shape(arg.ints, [0, 3, 1, 2]) + self.transpose_shape(arg.ints, + [0, 1, 6, 7, 2, 3, 4, 5]) elif ConverterUtil.data_format(op) == DataFormat.NCHW \ and self._target_data_format == DataFormat.NHWC: # noqa print("Transpose pad args: %s(%s)" % (op.name, op.type)) - self.transpose_shape(arg.ints, [0, 2, 3, 1]) + self.transpose_shape(arg.ints, + [0, 1, 4, 5, 6, 7, 2, 3]) elif op.type == MaceOp.Concat.name or op.type == MaceOp.Slice.name: for arg in op.arg: if arg.name == MaceKeyword.mace_axis_str: -- GitLab