提交 56a2be55 编写于 作者: 叶剑武

Merge branch 'update_onnx_converter' into 'master'

update converter for tf.split and onnx.pad

See merge request !1065
...@@ -120,9 +120,8 @@ OnnxSupportedOps = [ ...@@ -120,9 +120,8 @@ OnnxSupportedOps = [
# 'OneHot', # 'OneHot',
# 'Or', # 'Or',
'PRelu', 'PRelu',
# 'Pad', 'Pad',
'PadContext', 'PadContext',
'Padding',
'PNorm', 'PNorm',
'Pow', 'Pow',
# 'RNN', # 'RNN',
...@@ -346,8 +345,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -346,8 +345,8 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Neg.name: self.convert_eltwise, OnnxOpType.Neg.name: self.convert_eltwise,
OnnxOpType.Normalize: self.convert_normalize, OnnxOpType.Normalize: self.convert_normalize,
OnnxOpType.Offset.name: self.convert_identity, OnnxOpType.Offset.name: self.convert_identity,
OnnxOpType.Pad.name: self.convert_pad,
OnnxOpType.PadContext.name: self.convert_pad_context, OnnxOpType.PadContext.name: self.convert_pad_context,
OnnxOpType.Padding.name: self.convert_identity,
OnnxOpType.PNorm.name: self.convert_pnorm, OnnxOpType.PNorm.name: self.convert_pnorm,
OnnxOpType.Pow.name: self.convert_eltwise, OnnxOpType.Pow.name: self.convert_eltwise,
OnnxOpType.PRelu.name: self.convert_activation, OnnxOpType.PRelu.name: self.convert_activation,
...@@ -982,6 +981,29 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -982,6 +981,29 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node) op = self.convert_general_op(node)
op.type = MaceOp.BatchNorm.name op.type = MaceOp.BatchNorm.name
def convert_pad(self, node):
    """Convert an ONNX ``Pad`` node into a MACE ``Pad`` op.

    The optional ONNX attributes ``mode``, ``pads`` and ``value`` are
    translated into the corresponding MACE op arguments; attributes that
    are absent are simply not emitted.
    """
    op = self.convert_general_op(node)
    op.type = MaceOp.Pad.name

    attrs = node.attrs
    if 'mode' in attrs:
        # NOTE(review): ONNX 'edge' replicates the border value, whereas
        # SYMMETRIC mirroring includes the border row/column — confirm
        # this mapping matches the MACE Pad kernel's semantics.
        mode_to_pad_type = {
            'reflect': PadType.REFLECT,
            'edge': PadType.SYMMETRIC,
        }
        padding_type_arg = op.arg.add()
        padding_type_arg.name = MaceKeyword.mace_padding_type_str
        # Any unrecognized mode falls back to constant padding.
        padding_type_arg.i = mode_to_pad_type.get(attrs['mode'],
                                                  PadType.CONSTANT)
    if 'pads' in attrs:
        paddings_arg = op.arg.add()
        paddings_arg.name = MaceKeyword.mace_paddings_str
        paddings_arg.ints.extend(attrs['pads'])
    if 'value' in attrs:
        constant_value_arg = op.arg.add()
        constant_value_arg.name = MaceKeyword.mace_constant_value_str
        constant_value_arg.f = attrs['value']
def convert_pad_context(self, node): def convert_pad_context(self, node):
op = self.convert_general_op(node) op = self.convert_general_op(node)
op.type = MaceOp.PadContext.name op.type = MaceOp.PadContext.name
......
...@@ -1004,19 +1004,22 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -1004,19 +1004,22 @@ class TensorflowConverter(base_converter.ConverterInterface):
def convert_split(self, tf_op):
    """Convert a TensorFlow split op into a MACE op.

    When ``num_split`` is 1 the split is a no-op, so the node is lowered
    to an Identity op instead of a Split op. In both cases the axis
    input (input 0) is removed from the op and its backing tensor is
    marked as skipped so it is not emitted as a const tensor.
    """
    op = self.convert_general_op(tf_op)
    num_or_size_splits = tf_op.get_attr('num_split')
    if num_or_size_splits == 1:
        # Splitting into a single piece returns the input unchanged.
        op.type = MaceOp.Identity.name
    else:
        op.type = MaceOp.Split.name
        # Input 0 of tf.split is the constant axis tensor; evaluate it.
        axis = tf_op.inputs[0].eval().astype(np.int32)
        # Normalize a negative axis to its positive equivalent.
        axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis

        axis_arg = op.arg.add()
        axis_arg.name = MaceKeyword.mace_axis_str
        axis_arg.i = axis

        num_split_arg = op.arg.add()
        num_split_arg.name = MaceKeyword.mace_num_split_str
        num_split_arg.i = num_or_size_splits
    # The axis input was folded into an argument above; drop it.
    del op.input[0]
    self._skip_tensor.add(tf_op.inputs[0].name)
def convert_fake_quantize(self, tf_op): def convert_fake_quantize(self, tf_op):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册