Commit 19089a52 authored by Bin Li

Fix ONNX output shape bug

Parent 9b37c124
@@ -114,9 +114,12 @@ void MemoryOptimizer::Optimize(
     const mace::OperatorDef *op_def,
     const std::unordered_map<std::string, MemoryType> *mem_types) {
   MACE_LATENCY_LOGGER(2, "Optimize memory");
-  MACE_CHECK(op_def->output_size() == op_def->output_shape_size(),
-             op_def->name(), "The number of output shapes is"
-             " not equal to the number of outputs");
+  if (op_def->output_size() != op_def->output_shape_size()) {
+    VLOG(1) << op_def->name()
+            << ": the number of output shapes "
+            << "is not equal to the number of outputs";
+    return;
+  }
   auto device = static_cast<DeviceType>(op_def->device_type());
   DataType op_dtype = static_cast<DataType>(ProtoArgHelper::GetOptionalArg(
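This hunk relaxes what used to be a hard MACE_CHECK failure: an op whose output_shape count does not match its output count is now logged at VLOG(1) and skipped by the memory optimizer instead of aborting. For context, a minimal sketch of how such an op can arise on the converter side, using MACE's protobuf message (illustrative only; it assumes the mace.proto.mace_pb2 module and is not code from this commit):

    # Illustrative: an OperatorDef may carry fewer output_shape entries
    # than outputs when no shape could be inferred for a tensor.
    from mace.proto import mace_pb2

    op = mace_pb2.OperatorDef()
    op.name = "reshape_1"            # hypothetical op
    op.output.append("reshape_1:0")  # one output tensor ...
    # ... but no op.output_shape entry was added, because the source
    # model provided no shape information for it.
    print(len(op.output) == len(op.output_shape))  # False -> op is skipped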
@@ -38,7 +38,7 @@ import numpy as np
 import onnx
 import onnx.utils
-from onnx import mapping, numpy_helper, TensorProto
+from onnx import mapping, numpy_helper, shape_inference, TensorProto
 from numbers import Number

 IS_PYTHON3 = sys.version_info > (3,)
@@ -420,6 +420,8 @@ class OnnxConverter(base_converter.ConverterInterface):
         onnx.checker.check_model(onnx_model)
+        onnx_model = shape_inference.infer_shapes(onnx_model)
+
         self._isKaldi = False
         polish_available = True
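The two converter hunks above import ONNX's shape_inference module and run it on the model right after the checker, so intermediate tensors carry inferred shapes before conversion begins. A minimal standalone sketch of that API (it assumes a local model.onnx file and is not code from this commit):

    import onnx
    from onnx import shape_inference

    model = onnx.load("model.onnx")  # hypothetical input file
    onnx.checker.check_model(model)
    inferred = shape_inference.infer_shapes(model)

    # Inferred intermediate-tensor shapes land in graph.value_info;
    # dim_value == 0 means the dimension is still symbolic/unknown.
    for vi in inferred.graph.value_info:
        dims = [d.dim_value for d in vi.type.tensor_type.shape.dim]
        print(vi.name, dims)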
@@ -358,7 +358,7 @@ class Transformer(base_converter.ConverterInterface):
                 self.safe_remove_node(op,
                                       self._producer.get(op.input[0], None))
                 return True
-            elif op.type == 'Reshape' and \
+            elif op.type == 'Reshape' and len(op.output_shape) == 1 and \
                     op.output_shape[0].dims == \
                     self.get_tensor_shape(op.input[0]):
                 print("Remove useless reshape: %s(%s)" % (op.name, op.type))
@@ -1417,28 +1417,35 @@ class Transformer(base_converter.ConverterInterface):
             if op.type == MaceOp.Reshape:
                 input_op = self._producer[op.input[0]]
-                input_dims = input_op.output_shape[0].dims
-                output_dims = op.output_shape[0].dims
-                if len(input_op.output_shape) != 1 or \
-                        len(input_dims) != 4 or len(output_dims) != 4:
+                if len(input_op.output_shape) == 0 or len(op.output_shape) == 0:
                     transposable = False
                 else:
-                    in_b, in_h, in_w, in_c = self.sort_feature_map_shape(
-                        input_dims, ConverterUtil.data_format(input_op))
-                    ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape(
-                        output_dims, ConverterUtil.data_format(op))
-                    transposable = (in_b == ou_b and in_c == ou_c)
+                    input_dims = input_op.output_shape[0].dims
+                    output_dims = op.output_shape[0].dims
+                    if len(input_op.output_shape) != 1 or \
+                            len(input_dims) != 4 or len(output_dims) != 4:
+                        transposable = False
+                    else:
+                        in_b, in_h, in_w, in_c = self.sort_feature_map_shape(
+                            input_dims, ConverterUtil.data_format(input_op))
+                        ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape(
+                            output_dims, ConverterUtil.data_format(op))
+                        transposable = (in_b == ou_b and in_c == ou_c)
             elif op.type == MaceOp.Squeeze:
-                input_dims = self._producer[op.input[0]].output_shape[0].dims
-                output_dims = op.output_shape[0].dims
-                src_df = ConverterUtil.data_format(self._model)
-                arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
-                if len(input_dims) == 4 and len(output_dims) == 2 and \
-                        ((src_df == DataFormat.NCHW and arg.ints == [2, 3]) or
-                         (src_df == DataFormat.NHWC and arg.ints == [1, 2])):
-                    transposable = True
-                else:
-                    transposable = False
+                input_op = self._producer[op.input[0]]
+                if len(input_op.output_shape) == 0 or len(op.output_shape) == 0:
+                    transposable = False
+                else:
+                    input_dims = input_op.output_shape[0].dims
+                    output_dims = op.output_shape[0].dims
+                    src_df = ConverterUtil.data_format(self._model)
+                    arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
+                    if len(input_dims) == 4 and len(output_dims) == 2 and \
+                            ((src_df == DataFormat.NCHW and arg.ints == [2, 3]) or
+                             (src_df == DataFormat.NHWC and arg.ints == [1, 2])):
+                        transposable = True
+                    else:
+                        transposable = False
             if op.type in MaceTransposableDataFormatOps and not transposable:
                 print("%s(%s) is not a transposable op in this model."
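Both branches get the same defensive rewrite: look up the producer op first, force transposable = False when either the producer or the op itself has an empty output_shape list, and only then index output_shape[0]. Condensed into a hypothetical helper (illustrative; the real code also applies the 4-D/axis checks shown in the diff):

    def shapes_available(input_op, op):
        # Ops converted from models lacking shape info may have empty
        # output_shape lists; indexing [0] would raise IndexError,
        # so treat them as non-transposable instead.
        return len(input_op.output_shape) > 0 and len(op.output_shape) > 0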