Commit 19089a52 authored by Bin Li

Fix ONNX output shape bug

Parent 9b37c124
@@ -114,9 +114,12 @@ void MemoryOptimizer::Optimize(
     const mace::OperatorDef *op_def,
     const std::unordered_map<std::string, MemoryType> *mem_types) {
   MACE_LATENCY_LOGGER(2, "Optimize memory");
-  MACE_CHECK(op_def->output_size() == op_def->output_shape_size(),
-             op_def->name(), "The number of output shapes is"
-             " not equal to the number of outputs");
+  if (op_def->output_size() != op_def->output_shape_size()) {
+    VLOG(1) << op_def->name()
+            << ": the number of output shapes "
+            << "is not equal to the number of outputs";
+    return;
+  }
   auto device = static_cast<DeviceType>(op_def->device_type());
   DataType op_dtype = static_cast<DataType>(ProtoArgHelper::GetOptionalArg(
@@ -38,7 +38,7 @@ import numpy as np
 import onnx
 import onnx.utils
-from onnx import mapping, numpy_helper, TensorProto
+from onnx import mapping, numpy_helper, shape_inference, TensorProto
 from numbers import Number
 IS_PYTHON3 = sys.version_info > (3,)
@@ -420,6 +420,8 @@ class OnnxConverter(base_converter.ConverterInterface):
         onnx.checker.check_model(onnx_model)
+        onnx_model = shape_inference.infer_shapes(onnx_model)
         self._isKaldi = False
         polish_available = True
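The added call is the stock ONNX shape-inference entry point. Below is a minimal, self-contained sketch (the toy model is illustrative, not part of MACE) of what it does: infer_shapes fills graph.value_info with shapes for intermediate tensors, which is presumably what lets the converter populate per-op output shapes before conversion.

import onnx
from onnx import helper, shape_inference, TensorProto

# Toy graph X -> Relu -> Z -> Relu -> Y; only the graph input and
# output declare shapes, the intermediate tensor Z does not.
x = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 3, 224, 224])
y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 3, 224, 224])
relu1 = helper.make_node('Relu', ['X'], ['Z'])
relu2 = helper.make_node('Relu', ['Z'], ['Y'])
model = helper.make_model(helper.make_graph([relu1, relu2], 'toy', [x], [y]))

onnx.checker.check_model(model)
inferred = shape_inference.infer_shapes(model)
# graph.value_info now describes Z with its deduced shape.
for vi in inferred.graph.value_info:
    print(vi.name, [d.dim_value for d in vi.type.tensor_type.shape.dim])
# -> Z [1, 3, 224, 224]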
@@ -358,7 +358,7 @@ class Transformer(base_converter.ConverterInterface):
                 self.safe_remove_node(op,
                                       self._producer.get(op.input[0], None))
                 return True
-            elif op.type == 'Reshape' and \
+            elif op.type == 'Reshape' and len(op.output_shape) == 1 and \
                     op.output_shape[0].dims == \
                     self.get_tensor_shape(op.input[0]):
                 print("Remove useless reshape: %s(%s)" % (op.name, op.type))
@@ -1417,6 +1417,9 @@ class Transformer(base_converter.ConverterInterface):
         if op.type == MaceOp.Reshape:
             input_op = self._producer[op.input[0]]
-            input_dims = input_op.output_shape[0].dims
-            output_dims = op.output_shape[0].dims
-            if len(input_op.output_shape) != 1 or \
+            if len(input_op.output_shape) == 0 or len(op.output_shape) == 0:
+                transposable = False
+            else:
+                input_dims = input_op.output_shape[0].dims
+                output_dims = op.output_shape[0].dims
+                if len(input_op.output_shape) != 1 or \
@@ -1429,7 +1432,11 @@ class Transformer(base_converter.ConverterInterface):
-                    output_dims, ConverterUtil.data_format(op))
-                transposable = (in_b == ou_b and in_c == ou_c)
+                        output_dims, ConverterUtil.data_format(op))
+                    transposable = (in_b == ou_b and in_c == ou_c)
         elif op.type == MaceOp.Squeeze:
-            input_dims = self._producer[op.input[0]].output_shape[0].dims
-            output_dims = op.output_shape[0].dims
-            src_df = ConverterUtil.data_format(self._model)
-            arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
+            input_op = self._producer[op.input[0]]
+            if len(input_op.output_shape) == 0 or len(op.output_shape) == 0:
+                transposable = False
+            else:
+                input_dims = input_op.output_shape[0].dims
+                output_dims = op.output_shape[0].dims
+                src_df = ConverterUtil.data_format(self._model)
+                arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
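Both transformer branches now share the same defensive pattern. A condensed, hypothetical sketch of it, reusing the FakeShape/FakeOp stand-ins from the previous sketch:

def is_transposable(input_op, op):
    # Bail out when either op lacks an inferred output shape instead
    # of crashing on output_shape[0].
    if len(input_op.output_shape) == 0 or len(op.output_shape) == 0:
        return False
    input_dims = input_op.output_shape[0].dims
    output_dims = op.output_shape[0].dims
    # The real check also compares batch/channel positions per data
    # format; this sketch keeps only the rank test.
    return len(input_dims) == 4 and len(output_dims) == 4

print(is_transposable(FakeOp([]), FakeOp([FakeShape([1, 3, 2, 2])])))  # False
print(is_transposable(FakeOp([FakeShape([1, 2, 2, 3])]),
                      FakeOp([FakeShape([3, 4, 1, 2])])))              # True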