“97136fe97f15f408ca759a060210365512bb6f2a”上不存在“paddlespeech/s2t/exps/deepspeech2/bin/test.py”
提交 23d985f7 编写于 作者: 叶剑武

Merge branch 'update-onnx-converter' into 'master'

update onnx converter

See merge request !1114
......@@ -136,7 +136,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors
torchvision==0.2.2.post3
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
onnx==1.3.0 \
onnx==1.5.0 \
onnx-tf==1.2.0
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
......
......@@ -106,7 +106,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors
torchvision==0.2.2.post3
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
onnx==1.3.0 \
onnx==1.5.0 \
onnx-tf==1.2.0
RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \
......
......@@ -76,7 +76,7 @@ Optional dependencies
- pip install filelock==3.0.0
- Required by run on Android
* - ONNX
- pip install onnx==1.3.0
- pip install onnx==1.5.0
- Required by ONNX model
For python dependencies,
......
......@@ -15,6 +15,7 @@
#ifndef MACE_PORT_FILE_SYSTEM_H_
#define MACE_PORT_FILE_SYSTEM_H_
#include <cerrno>
#include <string>
#include <memory>
......
......@@ -337,7 +337,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Conv.name: self.convert_conv2d,
OnnxOpType.ConvTranspose.name: self.convert_deconv,
OnnxOpType.DepthToSpace.name: self.convert_depth_space,
OnnxOpType.Dropout.name: self.convert_identity,
OnnxOpType.Dropout.name: self.convert_dropout,
OnnxOpType.DimRange.name: self.convert_dim_range,
OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Equal.name: self.convert_eltwise,
......@@ -369,6 +369,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Relu.name: self.convert_activation,
OnnxOpType.Reshape.name: self.convert_reshape,
OnnxOpType.Reciprocal.name: self.convert_eltwise,
OnnxOpType.ReduceMean.name: self.convert_reduce,
OnnxOpType.Scale.name: self.convert_eltwise,
OnnxOpType.Sigmoid.name: self.convert_activation,
OnnxOpType.Slice.name: self.convert_slice,
......@@ -396,6 +397,8 @@ class OnnxConverter(base_converter.ConverterInterface):
ir_version = onnx_model.ir_version
opset_imp = onnx_model.opset_import
onnx.checker.check_model(onnx_model)
self._isKaldi = False
polish_available = True
......@@ -404,7 +407,7 @@ class OnnxConverter(base_converter.ConverterInterface):
domain = imp.domain
version = imp.version
print("constains ops domain: ", domain, "version:", version)
if 'kaldi2onnx' in domain:
if 'kaldi' in domain:
polish_available = False
self._data_format = DataFormat.NONE
self._isKaldi = True
......@@ -656,14 +659,13 @@ class OnnxConverter(base_converter.ConverterInterface):
def convert_concat(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Concat.name
axis_value = 1
if node.op_type == OnnxOpType.Concat.name:
if self._isKaldi is False:
mace_check('axis' in node.attrs,
'Concat op should have axis attribute.')
axis_value = node.attrs['axis']
mace_check(axis_value == 1 or axis_value == -3,
"only support concat at channel dimension")
elif node.op_type == OnnxOpType.Append.name:
else:
axis_value = -1
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
......@@ -789,6 +791,12 @@ class OnnxConverter(base_converter.ConverterInterface):
axes_arg.name = 'axes'
axes_arg.ints.extend([-1])
def convert_dropout(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Identity.name
del op.output[1:]
del op.output_shape[1:]
def convert_dynamic_lstm(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.DynamicLSTM.name
......@@ -1068,6 +1076,9 @@ class OnnxConverter(base_converter.ConverterInterface):
axis_arg.i = value
def convert_gemm(self, node):
if self._isKaldi:
self.convert_affine(node)
return
# only supports FullyConnected Style Gemm for now.
trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0
trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0
......
......@@ -131,9 +131,6 @@ int64_t StatMACs(const std::string &op_type,
output_shape.end(),
1,
std::multiplies<int64_t>());
} else if (op_type == "DynamicLSTM") {
macs = output_shape[0] * (filter_shape[0] * filter_shape[1]
+ output_shape[1] * filter_shape[0] / 4);
}
return macs;
}
......
tensorflow>=1.8.0
scipy>=1.0.0
filelock>=3.0.0
onnx>=1.3.0
\ No newline at end of file
onnx>=1.5.0
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册