From afc440b614566d01139aaad3b139bc41e245d059 Mon Sep 17 00:00:00 2001 From: liutuo Date: Tue, 16 Apr 2019 16:24:10 +0800 Subject: [PATCH] update onnx converter add support splite with num_splits=1 --- .../tools/converter_tool/onnx_converter.py | 28 +++++++++++++++++-- .../converter_tool/tensorflow_converter.py | 25 +++++++++-------- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py index 805cbd27..99ae2a79 100644 --- a/mace/python/tools/converter_tool/onnx_converter.py +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -120,9 +120,8 @@ OnnxSupportedOps = [ # 'OneHot', # 'Or', 'PRelu', - # 'Pad', + 'Pad', 'PadContext', - 'Padding', 'PNorm', 'Pow', # 'RNN', @@ -346,8 +345,8 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.Neg.name: self.convert_eltwise, OnnxOpType.Normalize: self.convert_normalize, OnnxOpType.Offset.name: self.convert_identity, + OnnxOpType.Pad.name: self.convert_pad, OnnxOpType.PadContext.name: self.convert_pad_context, - OnnxOpType.Padding.name: self.convert_identity, OnnxOpType.PNorm.name: self.convert_pnorm, OnnxOpType.Pow.name: self.convert_eltwise, OnnxOpType.PRelu.name: self.convert_activation, @@ -982,6 +981,29 @@ class OnnxConverter(base_converter.ConverterInterface): op = self.convert_general_op(node) op.type = MaceOp.BatchNorm.name + def convert_pad(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Pad.name + if 'mode' in node.attrs: + mode = node.attrs['mode'] + padding_type_arg = op.arg.add() + padding_type_arg.name = MaceKeyword.mace_padding_type_str + if mode == 'reflect': + padding_type_arg.i = PadType.REFLECT + elif mode == 'edge': + padding_type_arg.i = PadType.SYMMETRIC + else: + padding_type_arg.i = PadType.CONSTANT + if 'pads' in node.attrs: + paddings_arg = op.arg.add() + paddings_arg.name = MaceKeyword.mace_paddings_str + paddings_value = node.attrs['pads'] + paddings_arg.ints.extend(paddings_value) + if 'value' in node.attrs: + constant_value_arg = op.arg.add() + constant_value_arg.name = MaceKeyword.mace_constant_value_str + constant_value_arg.f = node.attrs['value'] + def convert_pad_context(self, node): op = self.convert_general_op(node) op.type = MaceOp.PadContext.name diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 53d57151..59bf9d88 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -1004,19 +1004,22 @@ class TensorflowConverter(base_converter.ConverterInterface): def convert_split(self, tf_op): op = self.convert_general_op(tf_op) - op.type = MaceOp.Split.name - axis = tf_op.inputs[0].eval().astype(np.int32) - axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis - del op.input[0] - - axis_arg = op.arg.add() - axis_arg.name = MaceKeyword.mace_axis_str - axis_arg.i = axis + num_or_size_splits = tf_op.get_attr('num_split') + if num_or_size_splits == 1: + op.type = MaceOp.Identity.name + else: + op.type = MaceOp.Split.name + axis = tf_op.inputs[0].eval().astype(np.int32) + axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis - num_split_arg = op.arg.add() - num_split_arg.name = MaceKeyword.mace_num_split_str - num_split_arg.i = tf_op.get_attr('num_split') + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = axis + num_split_arg = op.arg.add() + num_split_arg.name = MaceKeyword.mace_num_split_str + num_split_arg.i = num_or_size_splits + del op.input[0] self._skip_tensor.add(tf_op.inputs[0].name) def convert_fake_quantize(self, tf_op): -- GitLab