diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 8f05e61c02d496716889f57609fc69c0f642e184..5c806f411661e41ec9b6502617b537dd037a3f67 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -294,7 +294,12 @@ class TensorflowConverter(base_converter.ConverterInterface): if op.type != MaceOp.Deconv2D.name: dilation_arg = op.arg.add() dilation_arg.name = MaceKeyword.mace_dilations_str - dilation_arg.ints.extend(tf_op.get_attr(tf_dilations_str)[1:3]) + try: + dilation_val = tf_op.get_attr(tf_dilations_str)[1:3] + except ValueError: + dilation_val = [1, 1] + + dilation_arg.ints.extend(dilation_val) def convert_elementwise(self, tf_op): op = self.convert_general_op(tf_op) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 0fa5ddd967f026757e886a39f2e84f5a63c975bf..e9952ee806445582f7245340ae310f617fcc7744 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -713,18 +713,21 @@ class Transformer(base_converter.ConverterInterface): # transpose args if op.type == MaceOp.Pad.name: for arg in op.arg: - if arg.name == MaceKeyword.mace_paddings_str and len( - arg.ints) == 4: + if arg.name == MaceKeyword.mace_paddings_str: + mace_check(len(arg.ints) == 8, + "pad dim rank should be 8.") if ConverterUtil.data_format(op) == DataFormat.NHWC \ and self._target_data_format == DataFormat.NCHW: # noqa print("Transpose pad args: %s(%s)" % (op.name, op.type)) - self.transpose_shape(arg.ints, [0, 3, 1, 2]) + self.transpose_shape(arg.ints, + [0, 1, 6, 7, 2, 3, 4, 5]) elif ConverterUtil.data_format(op) == DataFormat.NCHW \ and self._target_data_format == DataFormat.NHWC: # noqa print("Transpose pad args: %s(%s)" % (op.name, op.type)) - self.transpose_shape(arg.ints, [0, 2, 3, 1]) + self.transpose_shape(arg.ints, + [0, 1, 4, 5, 6, 7, 2, 3]) elif op.type == MaceOp.Concat.name or op.type == MaceOp.Slice.name: for arg in op.arg: if arg.name == MaceKeyword.mace_axis_str: