diff --git a/mace/core/net_def_adapter.cc b/mace/core/net_def_adapter.cc index 45b66b56e49f347dd5ff123caa16f207e767c3e1..5d3915b4618c5030308ab25e82df2b3c1fc0e444 100644 --- a/mace/core/net_def_adapter.cc +++ b/mace/core/net_def_adapter.cc @@ -381,15 +381,14 @@ MaceStatus NetDefAdapter::AdaptDataFormat( int output_shape_size = op_def->output_shape_size(); for (int i = 0; i < output_shape_size; ++i) { auto output_shape = op_def->mutable_output_shape(i); - MACE_CHECK(output_shape->dims_size() == 4, - "Output shape should be 4D if the of has data format. ", - op_def->name()); - // transpose output shape format from NHWC to NCHW - int64_t height = output_shape->dims(1); - int64_t width = output_shape->dims(2); - output_shape->set_dims(1, output_shape->dims(3)); - output_shape->set_dims(2, height); - output_shape->set_dims(3, width); + if (output_shape->dims_size() == 4) { + // transpose output shape format from NHWC to NCHW + int64_t height = output_shape->dims(1); + int64_t width = output_shape->dims(2); + output_shape->set_dims(1, output_shape->dims(3)); + output_shape->set_dims(2, height); + output_shape->set_dims(3, width); + } } } } diff --git a/mace/python/tools/operator.jinja2 b/mace/python/tools/operator.jinja2 index e60057ed75be1da5edb7c5cc46fdc7c00f243c8c..b184b54a3d98f034147866d04a6b48c1af0703f9 100644 --- a/mace/python/tools/operator.jinja2 +++ b/mace/python/tools/operator.jinja2 @@ -96,9 +96,10 @@ void CreateOperator{{i}}(mace::OperatorDef *op) { {% if net.op[i].output_shape|length > 0 %} op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }}); + mace::OutputShape * output_shape = nullptr; {% for shape in net.op[i].output_shape %} {% if shape.dims|length > 0 %} - mace::OutputShape * output_shape = op->add_output_shape(); + output_shape = op->add_output_shape(); output_shape->mutable_dims()->Reserve({{ shape.dims|length }}); {% for dim in shape.dims %}