提交 23d985f7 编写于 作者: 叶剑武

Merge branch 'update-onnx-converter' into 'master'

update onnx converter

See merge request !1114
...@@ -136,7 +136,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors ...@@ -136,7 +136,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors
torchvision==0.2.2.post3 torchvision==0.2.2.post3
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
onnx==1.3.0 \ onnx==1.5.0 \
onnx-tf==1.2.0 onnx-tf==1.2.0
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
......
...@@ -106,7 +106,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors ...@@ -106,7 +106,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors
torchvision==0.2.2.post3 torchvision==0.2.2.post3
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
onnx==1.3.0 \ onnx==1.5.0 \
onnx-tf==1.2.0 onnx-tf==1.2.0
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
......
...@@ -76,7 +76,7 @@ Optional dependencies ...@@ -76,7 +76,7 @@ Optional dependencies
- pip install filelock==3.0.0 - pip install filelock==3.0.0
- Required by run on Android - Required by run on Android
* - ONNX * - ONNX
- pip install onnx==1.3.0 - pip install onnx==1.5.0
- Required by ONNX model - Required by ONNX model
For python dependencies, For python dependencies,
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#ifndef MACE_PORT_FILE_SYSTEM_H_ #ifndef MACE_PORT_FILE_SYSTEM_H_
#define MACE_PORT_FILE_SYSTEM_H_ #define MACE_PORT_FILE_SYSTEM_H_
#include <cerrno>
#include <string> #include <string>
#include <memory> #include <memory>
......
...@@ -337,7 +337,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -337,7 +337,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Conv.name: self.convert_conv2d, OnnxOpType.Conv.name: self.convert_conv2d,
OnnxOpType.ConvTranspose.name: self.convert_deconv, OnnxOpType.ConvTranspose.name: self.convert_deconv,
OnnxOpType.DepthToSpace.name: self.convert_depth_space, OnnxOpType.DepthToSpace.name: self.convert_depth_space,
OnnxOpType.Dropout.name: self.convert_identity, OnnxOpType.Dropout.name: self.convert_dropout,
OnnxOpType.DimRange.name: self.convert_dim_range, OnnxOpType.DimRange.name: self.convert_dim_range,
OnnxOpType.Div.name: self.convert_eltwise, OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Equal.name: self.convert_eltwise, OnnxOpType.Equal.name: self.convert_eltwise,
...@@ -369,6 +369,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -369,6 +369,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Relu.name: self.convert_activation, OnnxOpType.Relu.name: self.convert_activation,
OnnxOpType.Reshape.name: self.convert_reshape, OnnxOpType.Reshape.name: self.convert_reshape,
OnnxOpType.Reciprocal.name: self.convert_eltwise, OnnxOpType.Reciprocal.name: self.convert_eltwise,
OnnxOpType.ReduceMean.name: self.convert_reduce,
OnnxOpType.Scale.name: self.convert_eltwise, OnnxOpType.Scale.name: self.convert_eltwise,
OnnxOpType.Sigmoid.name: self.convert_activation, OnnxOpType.Sigmoid.name: self.convert_activation,
OnnxOpType.Slice.name: self.convert_slice, OnnxOpType.Slice.name: self.convert_slice,
...@@ -396,6 +397,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -396,6 +397,8 @@ class OnnxConverter(base_converter.ConverterInterface):
ir_version = onnx_model.ir_version ir_version = onnx_model.ir_version
opset_imp = onnx_model.opset_import opset_imp = onnx_model.opset_import
onnx.checker.check_model(onnx_model)
self._isKaldi = False self._isKaldi = False
polish_available = True polish_available = True
...@@ -404,7 +407,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -404,7 +407,7 @@ class OnnxConverter(base_converter.ConverterInterface):
domain = imp.domain domain = imp.domain
version = imp.version version = imp.version
print("constains ops domain: ", domain, "version:", version) print("constains ops domain: ", domain, "version:", version)
if 'kaldi2onnx' in domain: if 'kaldi' in domain:
polish_available = False polish_available = False
self._data_format = DataFormat.NONE self._data_format = DataFormat.NONE
self._isKaldi = True self._isKaldi = True
...@@ -656,14 +659,13 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -656,14 +659,13 @@ class OnnxConverter(base_converter.ConverterInterface):
def convert_concat(self, node): def convert_concat(self, node):
op = self.convert_general_op(node) op = self.convert_general_op(node)
op.type = MaceOp.Concat.name op.type = MaceOp.Concat.name
axis_value = 1 if self._isKaldi is False:
if node.op_type == OnnxOpType.Concat.name:
mace_check('axis' in node.attrs, mace_check('axis' in node.attrs,
'Concat op should have axis attribute.') 'Concat op should have axis attribute.')
axis_value = node.attrs['axis'] axis_value = node.attrs['axis']
mace_check(axis_value == 1 or axis_value == -3, mace_check(axis_value == 1 or axis_value == -3,
"only support concat at channel dimension") "only support concat at channel dimension")
elif node.op_type == OnnxOpType.Append.name: else:
axis_value = -1 axis_value = -1
axis_arg = op.arg.add() axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str axis_arg.name = MaceKeyword.mace_axis_str
...@@ -789,6 +791,12 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -789,6 +791,12 @@ class OnnxConverter(base_converter.ConverterInterface):
axes_arg.name = 'axes' axes_arg.name = 'axes'
axes_arg.ints.extend([-1]) axes_arg.ints.extend([-1])
def convert_dropout(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Identity.name
del op.output[1:]
del op.output_shape[1:]
def convert_dynamic_lstm(self, node): def convert_dynamic_lstm(self, node):
op = self.convert_general_op(node) op = self.convert_general_op(node)
op.type = MaceOp.DynamicLSTM.name op.type = MaceOp.DynamicLSTM.name
...@@ -1068,6 +1076,9 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -1068,6 +1076,9 @@ class OnnxConverter(base_converter.ConverterInterface):
axis_arg.i = value axis_arg.i = value
def convert_gemm(self, node): def convert_gemm(self, node):
if self._isKaldi:
self.convert_affine(node)
return
# only supports FullyConnected Style Gemm for now. # only supports FullyConnected Style Gemm for now.
trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0 trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0
trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0 trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0
......
...@@ -131,9 +131,6 @@ int64_t StatMACs(const std::string &op_type, ...@@ -131,9 +131,6 @@ int64_t StatMACs(const std::string &op_type,
output_shape.end(), output_shape.end(),
1, 1,
std::multiplies<int64_t>()); std::multiplies<int64_t>());
} else if (op_type == "DynamicLSTM") {
macs = output_shape[0] * (filter_shape[0] * filter_shape[1]
+ output_shape[1] * filter_shape[0] / 4);
} }
return macs; return macs;
} }
......
tensorflow>=1.8.0 tensorflow>=1.8.0
scipy>=1.0.0 scipy>=1.0.0
filelock>=3.0.0 filelock>=3.0.0
onnx>=1.3.0 onnx>=1.5.0
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册