提交 73861aeb 编写于 作者: 叶剑武

Merge branch 'add-onnx-flatten' into 'master'

add onnx flatten

See merge request !1148
...@@ -89,7 +89,7 @@ OnnxSupportedOps = [ ...@@ -89,7 +89,7 @@ OnnxSupportedOps = [
# 'Expand', # 'Expand',
'ExtractPooling', 'ExtractPooling',
# 'EyeLike', # 'EyeLike',
# 'Flatten', 'Flatten',
# 'Floor', # 'Floor',
# 'GRU', # 'GRU',
'Gather', 'Gather',
...@@ -343,6 +343,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -343,6 +343,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Div.name: self.convert_eltwise, OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Equal.name: self.convert_eltwise, OnnxOpType.Equal.name: self.convert_eltwise,
OnnxOpType.ExtractPooling.name: self.convert_extract_pooling, OnnxOpType.ExtractPooling.name: self.convert_extract_pooling,
OnnxOpType.Flatten.name: self.convert_flatten,
OnnxOpType.Gather.name: self.convert_gather, OnnxOpType.Gather.name: self.convert_gather,
OnnxOpType.Gemm.name: self.convert_gemm, OnnxOpType.Gemm.name: self.convert_gemm,
OnnxOpType.GlobalAveragePool.name: self.convert_reduce, OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
...@@ -883,6 +884,9 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -883,6 +884,9 @@ class OnnxConverter(base_converter.ConverterInterface):
scale_arg.f = scale scale_arg.f = scale
def convert_clip(self, node): def convert_clip(self, node):
# If clip's min value is zero,
# convert clip to activation(ReLU or ReLUX)
# so it can be fused into convolution.
is_relux = False is_relux = False
if 'min' in node.attrs: if 'min' in node.attrs:
min_value = node.attrs['min'] min_value = node.attrs['min']
...@@ -1038,6 +1042,16 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -1038,6 +1042,16 @@ class OnnxConverter(base_converter.ConverterInterface):
def convert_flatten(self, node): def convert_flatten(self, node):
op = self.convert_general_op(node) op = self.convert_general_op(node)
op.type = MaceOp.Reshape.name op.type = MaceOp.Reshape.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 1
if 'axis' in node.attrs:
axis_arg.i = node.attrs['axis']
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
end_axis_arg = op.arg.add()
end_axis_arg.name = MaceKeyword.mace_end_axis_str
end_axis_arg.i = -1
def convert_kaldi_batchnorm(self, node): def convert_kaldi_batchnorm(self, node):
op = self.convert_general_op(node) op = self.convert_general_op(node)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册