提交 56a2be55 编写于 作者: 叶剑武

Merge branch 'update_onnx_converter' into 'master'

update converter for tf.split and onnx.pad

See merge request !1065
......@@ -120,9 +120,8 @@ OnnxSupportedOps = [
# 'OneHot',
# 'Or',
'PRelu',
# 'Pad',
'Pad',
'PadContext',
'Padding',
'PNorm',
'Pow',
# 'RNN',
......@@ -346,8 +345,8 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Neg.name: self.convert_eltwise,
OnnxOpType.Normalize: self.convert_normalize,
OnnxOpType.Offset.name: self.convert_identity,
OnnxOpType.Pad.name: self.convert_pad,
OnnxOpType.PadContext.name: self.convert_pad_context,
OnnxOpType.Padding.name: self.convert_identity,
OnnxOpType.PNorm.name: self.convert_pnorm,
OnnxOpType.Pow.name: self.convert_eltwise,
OnnxOpType.PRelu.name: self.convert_activation,
......@@ -982,6 +981,29 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node)
op.type = MaceOp.BatchNorm.name
def convert_pad(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Pad.name
if 'mode' in node.attrs:
mode = node.attrs['mode']
padding_type_arg = op.arg.add()
padding_type_arg.name = MaceKeyword.mace_padding_type_str
if mode == 'reflect':
padding_type_arg.i = PadType.REFLECT
elif mode == 'edge':
padding_type_arg.i = PadType.SYMMETRIC
else:
padding_type_arg.i = PadType.CONSTANT
if 'pads' in node.attrs:
paddings_arg = op.arg.add()
paddings_arg.name = MaceKeyword.mace_paddings_str
paddings_value = node.attrs['pads']
paddings_arg.ints.extend(paddings_value)
if 'value' in node.attrs:
constant_value_arg = op.arg.add()
constant_value_arg.name = MaceKeyword.mace_constant_value_str
constant_value_arg.f = node.attrs['value']
def convert_pad_context(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.PadContext.name
......
......@@ -1004,19 +1004,22 @@ class TensorflowConverter(base_converter.ConverterInterface):
def convert_split(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Split.name
axis = tf_op.inputs[0].eval().astype(np.int32)
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
del op.input[0]
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis
num_or_size_splits = tf_op.get_attr('num_split')
if num_or_size_splits == 1:
op.type = MaceOp.Identity.name
else:
op.type = MaceOp.Split.name
axis = tf_op.inputs[0].eval().astype(np.int32)
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
num_split_arg = op.arg.add()
num_split_arg.name = MaceKeyword.mace_num_split_str
num_split_arg.i = tf_op.get_attr('num_split')
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis
num_split_arg = op.arg.add()
num_split_arg.name = MaceKeyword.mace_num_split_str
num_split_arg.i = num_or_size_splits
del op.input[0]
self._skip_tensor.add(tf_op.inputs[0].name)
def convert_fake_quantize(self, tf_op):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册