Commit 2eecaa15 authored by liuqi

Fix in/out name check bug.

Parent 5b12c75f
@@ -25,8 +25,9 @@ MaceStatus BufferToImageFunctor<DeviceType::GPU, T>::operator()(
     const BufferType type,
     Tensor *image,
     StatsFuture *future) {
+  auto formatted_buffer_shape = FormatBufferShape(buffer->shape(), type);
   std::vector<size_t> image_shape;
-  CalImage2DShape(buffer->shape(), type, &image_shape, wino_blk_size_);
+  CalImage2DShape(formatted_buffer_shape, type, &image_shape, wino_blk_size_);
   if (type == WINOGRAD_FILTER) {
     std::vector<index_t> new_shape =
         CalWinogradShape(buffer->shape(), type, wino_blk_size_);
@@ -136,30 +137,10 @@ MaceStatus BufferToImageFunctor<DeviceType::GPU, T>::operator()(
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
   } else if (type == ARGUMENT) {
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
-  } else if (type == IN_OUT_CHANNEL) {
-    if (buffer->dim_size() == 4) {  // NHWC
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
-    } else if (buffer->dim_size() == 2) {  // NC
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-    } else {
-      MACE_NOT_IMPLEMENTED;
-    }
-  } else if (type == IN_OUT_WIDTH || type == IN_OUT_HEIGHT) {
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-    if (buffer->dim_size() < 4) {
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
-    } else {
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
-    }
   } else {
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(formatted_buffer_shape[1]));
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(formatted_buffer_shape[2]));
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(formatted_buffer_shape[3]));
   }
   b2f_kernel.setArg(idx++, *(image->opencl_image()));
......
@@ -160,7 +160,7 @@ MaceStatus FCWXKernel(cl::Kernel *kernel,
     MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
     (*kernel_error)->UnMap();
   }
-  MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+  MACE_CHECK_CL_SUCCESS(error);
   if (future != nullptr) {
     future->wait_fn = [runtime, event](CallStats *stats) {
......
@@ -28,15 +28,10 @@ namespace {
 // [(C + 3) / 4 * W, N * H]
 void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
                            std::vector<size_t> *image_shape) {
-  MACE_CHECK(shape.size() == 4 || shape.size() == 2);
+  MACE_CHECK(shape.size() == 4);
   image_shape->resize(2);
-  if (shape.size() == 4) {
-    (*image_shape)[0] = RoundUpDiv4(shape[3]) * shape[2];
-    (*image_shape)[1] = shape[0] * shape[1];
-  } else if (shape.size() == 2) {
-    (*image_shape)[0] = RoundUpDiv4(shape[1]);
-    (*image_shape)[1] = shape[0];
-  }
+  (*image_shape)[0] = RoundUpDiv4(shape[3]) * shape[2];
+  (*image_shape)[1] = shape[0] * shape[1];
 }

 // [Ic, H * W * (Oc + 3) / 4]
@@ -83,27 +78,19 @@ void CalWinogradFilterImageShape(
 // [W * C, N * RoundUp<4>(H)]
 void CalInOutHeightImageShape(const std::vector<index_t> &shape, /* NHWC */
                               std::vector<size_t> *image_shape) {
-  std::vector<index_t> padded_shape = shape;
-  while (padded_shape.size() < 4) {
-    padded_shape.push_back(1);
-  }
-  MACE_CHECK(padded_shape.size() == 4);
+  MACE_CHECK(shape.size() == 4);
   image_shape->resize(2);
-  (*image_shape)[0] = padded_shape[2] * padded_shape[3];
-  (*image_shape)[1] = padded_shape[0] * RoundUpDiv4(padded_shape[1]);
+  (*image_shape)[0] = shape[2] * shape[3];
+  (*image_shape)[1] = shape[0] * RoundUpDiv4(shape[1]);
 }

 // [RoundUp<4>(W) * C, N * H]
 void CalInOutWidthImageShape(const std::vector<index_t> &shape, /* NHWC */
                              std::vector<size_t> *image_shape) {
-  std::vector<index_t> padded_shape = shape;
-  while (padded_shape.size() < 4) {
-    padded_shape.push_back(1);
-  }
-  MACE_CHECK(padded_shape.size() == 4);
+  MACE_CHECK(shape.size() == 4);
   image_shape->resize(2);
-  (*image_shape)[0] = RoundUpDiv4(padded_shape[2]) * padded_shape[3];
-  (*image_shape)[1] = padded_shape[0] * padded_shape[1];
+  (*image_shape)[0] = RoundUpDiv4(shape[2]) * shape[3];
+  (*image_shape)[1] = shape[0] * shape[1];
 }

 // [Ic * H * W, (Oc + 3) / 4]
@@ -163,6 +150,36 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
   }
 }

+std::vector<index_t> FormatBufferShape(
+    const std::vector<index_t> &buffer_shape,
+    const BufferType type) {
+  const size_t buffer_shape_size = buffer_shape.size();
+  switch (type) {
+    case IN_OUT_CHANNEL:
+      if (buffer_shape_size == 4) {  // NHWC
+        return buffer_shape;
+      } else if (buffer_shape_size == 2) {  // NC
+        return {buffer_shape[0], 1, 1, buffer_shape[1]};
+      } else {
+        LOG(FATAL) << "GPU only support 2D or 4D input and output";
+      }
+    case IN_OUT_HEIGHT:
+    case IN_OUT_WIDTH:
+      // only used for matmul test
+      if (buffer_shape_size == 3) {
+        return {buffer_shape[0], buffer_shape[1], buffer_shape[2], 1};
+      } else if (buffer_shape_size == 4) {
+        return buffer_shape;
+      } else {
+        LOG(FATAL) << "GPU only support 3D or 4D for IN_OUT_WIDTH "
+                      "and IN_OUT_HEIGHT";
+      }
+    default:
+      return buffer_shape;
+  }
+}
+
 std::vector<index_t> CalWinogradShape(const std::vector<index_t> &shape,
                                       const BufferType type,
                                       const int wino_blk_size) {
......
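For reference, a minimal sketch (not part of this commit) of how the new FormatBufferShape helper is meant to combine with CalImage2DShape: a 2-D NC buffer such as a fully-connected or softmax output is first expanded to NHWC {N, 1, 1, C}, and only then is the OpenCL image shape computed. Types are simplified stand-ins for index_t and the branch logic is reduced to the IN_OUT_CHANNEL case.

// Illustrative sketch only; simplified stand-ins for index_t/BufferType.
#include <cstdint>
#include <iostream>
#include <vector>

using index_t = int64_t;

// Mirrors the IN_OUT_CHANNEL branch of FormatBufferShape: NC -> {N, 1, 1, C}.
std::vector<index_t> FormatInOutChannelShape(const std::vector<index_t> &shape) {
  if (shape.size() == 4) return shape;                       // already NHWC
  if (shape.size() == 2) return {shape[0], 1, 1, shape[1]};  // NC
  return shape;  // other ranks are not handled in this sketch
}

// Mirrors CalInOutputImageShape: [(C + 3) / 4 * W, N * H].
std::vector<size_t> InOutChannelImageShape(const std::vector<index_t> &shape) {
  return {static_cast<size_t>((shape[3] + 3) / 4 * shape[2]),
          static_cast<size_t>(shape[0] * shape[1])};
}

int main() {
  std::vector<index_t> nc_buffer = {1, 1001};       // e.g. a softmax output
  auto nhwc = FormatInOutChannelShape(nc_buffer);   // {1, 1, 1, 1001}
  auto image = InOutChannelImageShape(nhwc);        // {251, 1}
  std::cout << image[0] << " x " << image[1] << "\n";
  return 0;
}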
@@ -49,6 +49,10 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
                      std::vector<size_t> *image_shape,
                      const int wino_blk_size = 2);

+std::vector<index_t> FormatBufferShape(
+    const std::vector<index_t> &buffer_shape,
+    const BufferType type);
+
 std::vector<index_t> CalWinogradShape(const std::vector<index_t> &shape,
                                       const BufferType type,
                                       const int wino_blk_size = 2);
......
@@ -25,8 +25,9 @@ MaceStatus ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
     const BufferType type,
     Tensor *buffer,
     StatsFuture *future) {
+  auto formatted_buffer_shape = FormatBufferShape(image->shape(), type);
   std::vector<size_t> image_shape;
-  CalImage2DShape(image->shape(), type, &image_shape, wino_blk_size_);
+  CalImage2DShape(formatted_buffer_shape, type, &image_shape, wino_blk_size_);
   MACE_RETURN_IF_ERROR(buffer->Resize(image->shape()));

   uint32_t gws[2] = {static_cast<uint32_t>(image_shape[0]),
@@ -123,30 +124,10 @@ MaceStatus ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
     b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
-  } else if (type == IN_OUT_CHANNEL) {
-    if (buffer->dim_size() == 4) {  // NHWC
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
-    } else if (buffer->dim_size() == 2) {  // NC
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-    } else {
-      MACE_NOT_IMPLEMENTED;
-    }
-  } else if (type == IN_OUT_WIDTH || type == IN_OUT_HEIGHT) {
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-    if (buffer->dim_size() < 4) {
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(1));
-    } else {
-      b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
-    }
   } else {
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
-    b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(formatted_buffer_shape[1]));
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(formatted_buffer_shape[2]));
+    b2f_kernel.setArg(idx++, static_cast<uint32_t>(formatted_buffer_shape[3]));
   }
   b2f_kernel.setArg(idx++, *(image->opencl_image()));
......
@@ -42,7 +42,8 @@ MaceStatus MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
   c_shape[rank - 2] = height;
   c_shape[rank - 1] = width;
   std::vector<size_t> c_image_shape;
-  CalImage2DShape(c_shape, BufferType::IN_OUT_HEIGHT, &c_image_shape);
+  std::vector<index_t> padded_c_shape = {batch, height, width, 1};
+  CalImage2DShape(padded_c_shape, BufferType::IN_OUT_HEIGHT, &c_image_shape);
  MACE_RETURN_IF_ERROR(C->ResizeImage(c_shape, c_image_shape));

  const index_t height_blocks = RoundUpDiv4(height);
......
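The padded_c_shape above feeds the IN_OUT_HEIGHT layout, which helper.cc documents as [W * C, N * RoundUp<4>(H)]. A small self-contained check of that arithmetic (the batch/height/width values are illustrative only, not from this commit):

// Illustrative arithmetic only; not code from this commit.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch = 1, height = 37, width = 64;
  // padded_c_shape = {batch, height, width, 1}, interpreted as NHWC.
  // IN_OUT_HEIGHT image shape: [W * C, N * RoundUpDiv4(H)].
  const int64_t image_width = width * 1;                     // 64
  const int64_t image_height = batch * ((height + 3) / 4);   // 1 * 10 = 10
  std::cout << image_width << " x " << image_height << "\n";
  return 0;
}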
@@ -94,8 +94,14 @@ MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
   };

   if (!IsVecEqual(input_shape_, input_tensor->shape())) {
     output_shape = {blk_sqr, input_tensor->dim(3), out_width};
+    std::vector<index_t> padded_output_shape = {
+        output_shape[0], output_shape[1], output_shape[2], 1
+    };
     std::vector<size_t> image_shape;
-    CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, &image_shape);
+    CalImage2DShape(padded_output_shape,
+                    BufferType::IN_OUT_HEIGHT,
+                    &image_shape);
+    // remove unused last dimension
     MACE_RETURN_IF_ERROR(output_tensor->ResizeImage(output_shape, image_shape));

     uint32_t idx = 0;
......
@@ -216,7 +216,6 @@ class ConverterOption(object):
         self._device = DeviceType.CPU.value
         self._winograd_enabled = False
         self._transformer_option = [
-            TransformerRule.ADD_IN_OUT_TENSOR_INFO,
             TransformerRule.REMOVE_IDENTITY_OP,
             TransformerRule.TRANSFORM_GLOBAL_POOLING,
             TransformerRule.FOLD_RESHAPE,
@@ -231,6 +230,7 @@ class ConverterOption(object):
             TransformerRule.FOLD_ACTIVATION,
             TransformerRule.TRANSPOSE_FILTERS,
             TransformerRule.TRANSPOSE_DATA_FORMAT,
+            TransformerRule.ADD_IN_OUT_TENSOR_INFO,
             TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
             TransformerRule.RESHAPE_FC_WEIGHT,
             TransformerRule.TRANSFORM_BUFFER_IMAGE,
......
@@ -55,7 +55,6 @@ class Transformer(base_converter.ConverterInterface):
     def __init__(self, option, model):
         # DO NOT reorder the following transformers' order
         self._registered_transformers_order = [
-            TransformerRule.ADD_IN_OUT_TENSOR_INFO,
             TransformerRule.REMOVE_IDENTITY_OP,
             TransformerRule.TRANSFORM_GLOBAL_POOLING,
             TransformerRule.FOLD_RESHAPE,
@@ -71,6 +70,7 @@ class Transformer(base_converter.ConverterInterface):
             TransformerRule.FOLD_ACTIVATION,
             TransformerRule.TRANSPOSE_FILTERS,
             TransformerRule.TRANSPOSE_DATA_FORMAT,
+            TransformerRule.ADD_IN_OUT_TENSOR_INFO,
             TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
             TransformerRule.RESHAPE_FC_WEIGHT,
             TransformerRule.TRANSFORM_BUFFER_IMAGE,
@@ -78,8 +78,6 @@ class Transformer(base_converter.ConverterInterface):
             TransformerRule.SORT_BY_EXECUTION,
         ]
         self._registered_transformers = {
-            TransformerRule.ADD_IN_OUT_TENSOR_INFO:
-                self.add_in_out_tensor_info,
             TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
             TransformerRule.TRANSFORM_GLOBAL_POOLING:
                 self.transform_global_pooling,
@@ -100,6 +98,8 @@ class Transformer(base_converter.ConverterInterface):
             TransformerRule.FOLD_ACTIVATION: self.fold_activation,
             TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
             TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
+            TransformerRule.ADD_IN_OUT_TENSOR_INFO:
+                self.add_in_out_tensor_info,
             TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC:
                 self.transform_global_conv_to_fc,
             TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight,
......
@@ -183,10 +183,12 @@ class GPUMemoryOptimizer(MemoryOptimizer):
             mem_block[0] = output_shape[2]
             mem_block[1] = output_shape[0] * int((output_shape[1] + 3) / 4)
         else:
-            padded_output_shape = ([1, 1, 1, 1] + list(output_shape))[-4:]
-            mem_block[0] = padded_output_shape[2] * int(
-                (padded_output_shape[3] + 3) / 4)
-            mem_block[1] = padded_output_shape[0] * padded_output_shape[1]
+            if len(output_shape) == 2:  # only support fc/softmax
+                mem_block[0] = int((output_shape[1] + 3) / 4)
+                mem_block[1] = output_shape[0]
+            else:
+                mem_block[0] = output_shape[2] * int((output_shape[3] + 3) / 4)
+                mem_block[1] = output_shape[0] * output_shape[1]
         return mem_block

     def mem_size(self, memory_block):
......
@@ -73,24 +73,35 @@ void CreateNetArg(NetDef *net_def) {
 }
 {% endif %}

+{% if net.input_info | length > 0 %}
+void CreateInputInfo(NetDef *net_def) {
+  net_def->mutable_input_info()->Reserve({{ net.input_info | length }});
+  InputInfo *input_info = nullptr;
+  {% for idx in range(net.input_info|length) %}
+  input_info = net_def->add_input_info();
+  input_info->set_name({{ net.input_info[idx].name|tojson }});
+  input_info->set_data_type(static_cast<DataType>({{ net.input_info[idx].data_type }}));
+  input_info->mutable_dims()->Reserve({{ net.input_info[idx].dims|length }});
+  {% for dim in net.input_info[idx].dims %}
+  input_info->add_dims({{ dim }});
+  {% endfor %}
+  {% endfor %}
+}
+{% endif %}
+
 {% if net.output_info | length > 0 %}
 void CreateOutputInfo(NetDef *net_def) {
-  std::vector<std::vector<int>> dims { {{net.output_info | map(attribute='dims') | join(', ') | replace('[', '{') | replace(']', '}') }} };
-
-  std::vector<int> data_types_int { {{ net.output_info | map(attribute='data_type') | join(', ') }} };
-  std::vector<mace::DataType> data_types({{ net.output_info | length }});
-  for (int k = 0; k < {{ net.output_info | length }}; ++k) {
-    data_types[k] = static_cast<mace::DataType>(data_types_int[k]);
-  }
   net_def->mutable_output_info()->Reserve({{ net.output_info | length }});
-  for (int i = 0; i < {{ net.output_info | length }}; ++i) {
-    auto output_info = net_def->add_output_info();
-    output_info->set_data_type(data_types[i]);
-    output_info->mutable_dims()->Reserve(dims[i].size());
-    for (size_t j = 0; j < dims[i].size(); ++j) {
-      output_info->add_dims(dims[i][j]);
-    }
-  }
+  OutputInfo *output_info = nullptr;
+  {% for idx in range(net.output_info|length) %}
+  output_info = net_def->add_output_info();
+  output_info->set_name({{ net.output_info[idx].name|tojson }});
+  output_info->set_data_type(static_cast<DataType>({{ net.output_info[idx].data_type }}));
+  output_info->mutable_dims()->Reserve({{ net.output_info[idx].dims|length }});
+  {% for dim in net.output_info[idx].dims %}
+  output_info->add_dims({{dim}});
+  {% endfor %}
+  {% endfor %}
 }
 {% endif %}
@@ -147,6 +158,9 @@ const std::shared_ptr<NetDef> CreateNet() {
   {% if net.mem_arena.mem_block|length != 0 %}
   CreateMemoryArena(net_def->mutable_mem_arena());
   {% endif %}
+  {% if net.input_info | length > 0 %}
+  CreateInputInfo(net_def.get());
+  {% endif %}
   {% if net.output_info | length > 0 %}
   CreateOutputInfo(net_def.get());
   {% endif %}
......
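For orientation, the new CreateInputInfo block in the template would render roughly as below for a hypothetical single-input model. The input name, data type value and dims are made up for illustration; only the NetDef/InputInfo calls come from the template itself.

// Hypothetical rendered output of the CreateInputInfo template block.
// The name "data", the data type value 1 and the dims are illustrative only.
void CreateInputInfo(NetDef *net_def) {
  net_def->mutable_input_info()->Reserve(1);
  InputInfo *input_info = nullptr;
  input_info = net_def->add_input_info();
  input_info->set_name("data");
  input_info->set_data_type(static_cast<DataType>(1));
  input_info->mutable_dims()->Reserve(4);
  input_info->add_dims(1);
  input_info->add_dims(224);
  input_info->add_dims(224);
  input_info->add_dims(3);
}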
@@ -154,9 +154,10 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
         for i in range(len(output_names)):
             value = net.blobs[net.top_names[output_names[i]][0]].data
             out_shape = output_shapes[i]
-            out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[
-                1], out_shape[2]
-            value = value.reshape(out_shape).transpose((0, 2, 3, 1))
+            if len(out_shape) == 4:
+                out_shape[1], out_shape[2], out_shape[3] = \
+                    out_shape[3], out_shape[1], out_shape[2]
+                value = value.reshape(out_shape).transpose((0, 2, 3, 1))
             output_file_name = common.formatted_file_name(
                 mace_out_file, output_names[i])
             mace_out_value = load_data(output_file_name)
......