From 107b956a5b5643537b45948750d0f58c7ca320f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8F=B6=E5=89=91=E6=AD=A6?= Date: Sat, 19 Oct 2019 16:46:10 +0800 Subject: [PATCH] fix pad converter and support padv2 --- tools/python/transform/tensorflow_converter.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index dfb41a25..1ad82f21 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -102,6 +102,7 @@ TFSupportedOps = [ 'DepthToSpace', 'SpaceToDepth', 'Pad', + 'PadV2', 'ConcatV2', 'Mean', 'Prod', @@ -252,6 +253,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.DepthToSpace.name: self.convert_space_depth, TFOpType.SpaceToDepth.name: self.convert_space_depth, TFOpType.Pad.name: self.convert_pad, + TFOpType.PadV2.name: self.convert_pad, TFOpType.ConcatV2.name: self.convert_concat, TFOpType.Const.name: self.convert_nop, TFOpType.Gather.name: self.convert_gather, @@ -785,13 +787,19 @@ class TensorflowConverter(base_converter.ConverterInterface): pad_type_arg = op.arg.add() pad_type_arg.name = MaceKeyword.mace_pad_type_str - if tf_op.type == TFOpType.Pad: + if tf_op.type == TFOpType.Pad or tf_op.type == TFOpType.PadV2: if len(tf_op.inputs) == 3: constant_value_arg = op.arg.add() constant_value_arg.name = MaceKeyword.mace_constant_value_str - constant_value = tf_op.inputs[2].eval().astype(np.int32) \ - .flat[0] - constant_value_arg.i = constant_value + constant_value = tf_op.inputs[2].eval().flat[0] + tf_dt = tf_op.inputs[2].dtype + if tf_dt == tf.float32: + constant_value_arg.f = constant_value + elif tf_dt == tf.int32: + constant_value_arg.i = constant_value + else: + mace_check(False, + "Unsupported data type: %s" % tf_dt.name) self._skip_tensor.add(tf_op.inputs[2].name) pad_type_arg.i = PadType.CONSTANT.value -- GitLab