From a1d3bd26918e1074c8fe5395c4dcee6695a7cafd Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 6 May 2019 14:19:24 +0800 Subject: [PATCH] Fix some bugs. 1. Bug in operator source template. 2. Output shapes of Op with data format may be not 4D. --- mace/core/net_def_adapter.cc | 17 ++++++++--------- mace/python/tools/operator.jinja2 | 3 ++- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mace/core/net_def_adapter.cc b/mace/core/net_def_adapter.cc index 45b66b56..5d3915b4 100644 --- a/mace/core/net_def_adapter.cc +++ b/mace/core/net_def_adapter.cc @@ -381,15 +381,14 @@ MaceStatus NetDefAdapter::AdaptDataFormat( int output_shape_size = op_def->output_shape_size(); for (int i = 0; i < output_shape_size; ++i) { auto output_shape = op_def->mutable_output_shape(i); - MACE_CHECK(output_shape->dims_size() == 4, - "Output shape should be 4D if the of has data format. ", - op_def->name()); - // transpose output shape format from NHWC to NCHW - int64_t height = output_shape->dims(1); - int64_t width = output_shape->dims(2); - output_shape->set_dims(1, output_shape->dims(3)); - output_shape->set_dims(2, height); - output_shape->set_dims(3, width); + if (output_shape->dims_size() == 4) { + // transpose output shape format from NHWC to NCHW + int64_t height = output_shape->dims(1); + int64_t width = output_shape->dims(2); + output_shape->set_dims(1, output_shape->dims(3)); + output_shape->set_dims(2, height); + output_shape->set_dims(3, width); + } } } } diff --git a/mace/python/tools/operator.jinja2 b/mace/python/tools/operator.jinja2 index e60057ed..b184b54a 100644 --- a/mace/python/tools/operator.jinja2 +++ b/mace/python/tools/operator.jinja2 @@ -96,9 +96,10 @@ void CreateOperator{{i}}(mace::OperatorDef *op) { {% if net.op[i].output_shape|length > 0 %} op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }}); + mace::OutputShape * output_shape = nullptr; {% for shape in net.op[i].output_shape %} {% if shape.dims|length > 0 %} - mace::OutputShape * output_shape = op->add_output_shape(); + output_shape = op->add_output_shape(); output_shape->mutable_dims()->Reserve({{ shape.dims|length }}); {% for dim in shape.dims %} -- GitLab