From 0d924fc7b4b1f04f0b1b3fca7678aea1e47ce00a Mon Sep 17 00:00:00 2001 From: liutuo Date: Wed, 22 May 2019 12:55:00 +0800 Subject: [PATCH] update onnx converter --- docker/mace-dev-lite/Dockerfile | 2 +- docker/mace-dev/Dockerfile | 2 +- docs/installation/env_requirement.rst | 2 +- include/mace/port/file_system.h | 1 + .../tools/converter_tool/onnx_converter.py | 21 ++++++++++++++----- mace/utils/statistics.cc | 3 --- setup/optionals.txt | 2 +- 7 files changed, 21 insertions(+), 12 deletions(-) diff --git a/docker/mace-dev-lite/Dockerfile b/docker/mace-dev-lite/Dockerfile index 29c36b10..70f94ba8 100644 --- a/docker/mace-dev-lite/Dockerfile +++ b/docker/mace-dev-lite/Dockerfile @@ -136,7 +136,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors torchvision==0.2.2.post3 RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ - onnx==1.3.0 \ + onnx==1.5.0 \ onnx-tf==1.2.0 RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ diff --git a/docker/mace-dev/Dockerfile b/docker/mace-dev/Dockerfile index 23c98b39..8bb31a7d 100644 --- a/docker/mace-dev/Dockerfile +++ b/docker/mace-dev/Dockerfile @@ -106,7 +106,7 @@ RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors torchvision==0.2.2.post3 RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ - onnx==1.3.0 \ + onnx==1.5.0 \ onnx-tf==1.2.0 RUN pip install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com \ diff --git a/docs/installation/env_requirement.rst b/docs/installation/env_requirement.rst index 4a599ec5..465072a1 100644 --- a/docs/installation/env_requirement.rst +++ b/docs/installation/env_requirement.rst @@ -76,7 +76,7 @@ Optional dependencies - pip install filelock==3.0.0 - Required by run on Android * - ONNX - - pip install onnx==1.3.0 + - pip install onnx==1.5.0 - Required by ONNX model For python dependencies, diff --git a/include/mace/port/file_system.h b/include/mace/port/file_system.h index 4de18fbb..2117faea 100644 --- a/include/mace/port/file_system.h +++ b/include/mace/port/file_system.h @@ -15,6 +15,7 @@ #ifndef MACE_PORT_FILE_SYSTEM_H_ #define MACE_PORT_FILE_SYSTEM_H_ +#include #include #include diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py index a024aed8..b4a8e291 100644 --- a/mace/python/tools/converter_tool/onnx_converter.py +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -337,7 +337,7 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.Conv.name: self.convert_conv2d, OnnxOpType.ConvTranspose.name: self.convert_deconv, OnnxOpType.DepthToSpace.name: self.convert_depth_space, - OnnxOpType.Dropout.name: self.convert_identity, + OnnxOpType.Dropout.name: self.convert_dropout, OnnxOpType.DimRange.name: self.convert_dim_range, OnnxOpType.Div.name: self.convert_eltwise, OnnxOpType.Equal.name: self.convert_eltwise, @@ -369,6 +369,7 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.Relu.name: self.convert_activation, OnnxOpType.Reshape.name: self.convert_reshape, OnnxOpType.Reciprocal.name: self.convert_eltwise, + OnnxOpType.ReduceMean.name: self.convert_reduce, OnnxOpType.Scale.name: self.convert_eltwise, OnnxOpType.Sigmoid.name: self.convert_activation, OnnxOpType.Slice.name: self.convert_slice, @@ -396,6 +397,8 @@ class OnnxConverter(base_converter.ConverterInterface): ir_version = onnx_model.ir_version opset_imp = onnx_model.opset_import + onnx.checker.check_model(onnx_model) + self._isKaldi = False polish_available = True @@ -404,7 +407,7 @@ class OnnxConverter(base_converter.ConverterInterface): domain = imp.domain version = imp.version print("constains ops domain: ", domain, "version:", version) - if 'kaldi2onnx' in domain: + if 'kaldi' in domain: polish_available = False self._data_format = DataFormat.NONE self._isKaldi = True @@ -656,14 +659,13 @@ class OnnxConverter(base_converter.ConverterInterface): def convert_concat(self, node): op = self.convert_general_op(node) op.type = MaceOp.Concat.name - axis_value = 1 - if node.op_type == OnnxOpType.Concat.name: + if self._isKaldi is False: mace_check('axis' in node.attrs, 'Concat op should have axis attribute.') axis_value = node.attrs['axis'] mace_check(axis_value == 1 or axis_value == -3, "only support concat at channel dimension") - elif node.op_type == OnnxOpType.Append.name: + else: axis_value = -1 axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str @@ -789,6 +791,12 @@ class OnnxConverter(base_converter.ConverterInterface): axes_arg.name = 'axes' axes_arg.ints.extend([-1]) + def convert_dropout(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Identity.name + del op.output[1:] + del op.output_shape[1:] + def convert_dynamic_lstm(self, node): op = self.convert_general_op(node) op.type = MaceOp.DynamicLSTM.name @@ -1068,6 +1076,9 @@ class OnnxConverter(base_converter.ConverterInterface): axis_arg.i = value def convert_gemm(self, node): + if self._isKaldi: + self.convert_affine(node) + return # only supports FullyConnected Style Gemm for now. trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0 trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0 diff --git a/mace/utils/statistics.cc b/mace/utils/statistics.cc index 7ff43664..924b3a71 100644 --- a/mace/utils/statistics.cc +++ b/mace/utils/statistics.cc @@ -131,9 +131,6 @@ int64_t StatMACs(const std::string &op_type, output_shape.end(), 1, std::multiplies()); - } else if (op_type == "DynamicLSTM") { - macs = output_shape[0] * (filter_shape[0] * filter_shape[1] - + output_shape[1] * filter_shape[0] / 4); } return macs; } diff --git a/setup/optionals.txt b/setup/optionals.txt index 94187950..2c0c3427 100644 --- a/setup/optionals.txt +++ b/setup/optionals.txt @@ -1,4 +1,4 @@ tensorflow>=1.8.0 scipy>=1.0.0 filelock>=3.0.0 -onnx>=1.3.0 \ No newline at end of file +onnx>=1.5.0 \ No newline at end of file -- GitLab