提交 a1d3bd26 编写于 作者: L liuqi

Fix some bugs.

1. Bug in operator source template.
2. Output shapes of Op with data format may be not 4D.
上级 3e9bb73e
...@@ -381,9 +381,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -381,9 +381,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
int output_shape_size = op_def->output_shape_size(); int output_shape_size = op_def->output_shape_size();
for (int i = 0; i < output_shape_size; ++i) { for (int i = 0; i < output_shape_size; ++i) {
auto output_shape = op_def->mutable_output_shape(i); auto output_shape = op_def->mutable_output_shape(i);
MACE_CHECK(output_shape->dims_size() == 4, if (output_shape->dims_size() == 4) {
"Output shape should be 4D if the of has data format. ",
op_def->name());
// transpose output shape format from NHWC to NCHW // transpose output shape format from NHWC to NCHW
int64_t height = output_shape->dims(1); int64_t height = output_shape->dims(1);
int64_t width = output_shape->dims(2); int64_t width = output_shape->dims(2);
...@@ -393,6 +391,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -393,6 +391,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
} }
} }
} }
}
*op_output_df = op_data_format; *op_output_df = op_data_format;
// the output memory type of transpose op is based on the consumer op's device // the output memory type of transpose op is based on the consumer op's device
......
...@@ -96,9 +96,10 @@ void CreateOperator{{i}}(mace::OperatorDef *op) { ...@@ -96,9 +96,10 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
{% if net.op[i].output_shape|length > 0 %} {% if net.op[i].output_shape|length > 0 %}
op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }}); op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }});
mace::OutputShape * output_shape = nullptr;
{% for shape in net.op[i].output_shape %} {% for shape in net.op[i].output_shape %}
{% if shape.dims|length > 0 %} {% if shape.dims|length > 0 %}
mace::OutputShape * output_shape = op->add_output_shape(); output_shape = op->add_output_shape();
output_shape->mutable_dims()->Reserve({{ shape.dims|length }}); output_shape->mutable_dims()->Reserve({{ shape.dims|length }});
{% for dim in shape.dims %} {% for dim in shape.dims %}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册