From c803658dfe0ff9138edba5f1a75700a590efab48 Mon Sep 17 00:00:00 2001 From: Santa An <49897975+AnBaolei1984@users.noreply.github.com> Date: Thu, 27 Feb 2020 14:44:33 +0800 Subject: [PATCH] [lite][bm] support ssd test=develop (#2994) --- lite/api/test_resnet50_lite_bm.cc | 1 + lite/kernels/bm/bridges/CMakeLists.txt | 5 + lite/kernels/bm/bridges/box_coder_op.cc | 126 +++++++++ lite/kernels/bm/bridges/concat_op.cc | 2 + lite/kernels/bm/bridges/multiclass_nms_op.cc | 119 +++++++++ lite/kernels/bm/bridges/paddle_use_bridges.h | 2 + lite/kernels/bm/bridges/prior_box_op.cc | 256 ++++++++++--------- 7 files changed, 384 insertions(+), 127 deletions(-) create mode 100644 lite/kernels/bm/bridges/box_coder_op.cc create mode 100644 lite/kernels/bm/bridges/multiclass_nms_op.cc diff --git a/lite/api/test_resnet50_lite_bm.cc b/lite/api/test_resnet50_lite_bm.cc index cb9cb304b8..73ad405f16 100644 --- a/lite/api/test_resnet50_lite_bm.cc +++ b/lite/api/test_resnet50_lite_bm.cc @@ -32,6 +32,7 @@ namespace lite { void TestModel(const std::vector& valid_places) { lite::Predictor predictor; + std::vector passes; predictor.Build(FLAGS_model_dir, "", "", valid_places, passes); auto* input_tensor = predictor.GetInput(0); diff --git a/lite/kernels/bm/bridges/CMakeLists.txt b/lite/kernels/bm/bridges/CMakeLists.txt index 688e307a64..bd422de76c 100644 --- a/lite/kernels/bm/bridges/CMakeLists.txt +++ b/lite/kernels/bm/bridges/CMakeLists.txt @@ -21,6 +21,9 @@ lite_cc_library(subgraph_bridge_transpose_op_bm SRCS transpose_op.cc DEPS ${bm_s lite_cc_library(subgraph_bridge_reshape_op_bm SRCS reshape_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_norm_op_bm SRCS norm_op.cc DEPS ${bm_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_prior_box_op_bm SRCS prior_box_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_box_coder_op_bm SRCS box_coder_op.cc DEPS ${bm_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_multiclass_nms_op_bm SRCS multiclass_nms_op.cc DEPS ${bm_subgraph_bridge_deps}) + set(bm_subgraph_bridges subgraph_bridge_registry subgraph_bridge_engine @@ -39,4 +42,6 @@ set(bm_subgraph_bridges subgraph_bridge_reshape_op_bm subgraph_bridge_norm_op_bm subgraph_bridge_prior_box_op_bm + subgraph_bridge_box_coder_op_bm + subgraph_bridge_multiclass_nms_op_bm CACHE INTERNAL "bm_subgraph_bridges") diff --git a/lite/kernels/bm/bridges/box_coder_op.cc b/lite/kernels/bm/bridges/box_coder_op.cc new file mode 100644 index 0000000000..67f5104c8b --- /dev/null +++ b/lite/kernels/bm/bridges/box_coder_op.cc @@ -0,0 +1,126 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include +#include +#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/bm/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +int BoxCoderConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto box_var_name = op_info->Input("PriorBox").front(); + auto box = scope->FindVar(box_var_name)->GetMutable(); + auto box_dims = box->dims(); + auto box_var_var_name = op_info->Input("PriorBoxVar").front(); + auto box_var = scope->FindVar(box_var_var_name)->GetMutable(); + auto box_var_dims = box_var->dims(); + auto target_box_var_name = op_info->Input("TargetBox").front(); + auto target_box = + scope->FindVar(target_box_var_name)->GetMutable(); + auto target_box_dims = target_box->dims(); + auto output_var_name = op_info->Output("OutputBox").front(); + auto output = scope->FindVar(output_var_name)->GetMutable(); + auto output_dims = output->dims(); + + std::vector i_box_shape_data(box_dims.size()); + for (size_t i = 0; i < box_dims.size(); i++) { + i_box_shape_data[i] = static_cast(box_dims[i]); + } + std::vector i_box_var_shape_data(box_var_dims.size()); + for (size_t i = 0; i < box_var_dims.size(); i++) { + i_box_var_shape_data[i] = static_cast(box_var_dims[i]); + } + std::vector i_target_box_shape_data(target_box_dims.size()); + for (size_t i = 0; i < target_box_dims.size(); i++) { + i_target_box_shape_data[i] = static_cast(target_box_dims[i]); + } + std::vector i_output_shape_data(output_dims.size()); + for (size_t i = 0; i < output_dims.size(); i++) { + i_output_shape_data[i] = static_cast(output_dims[i]); + } + auto code_type = op_info->GetAttr("code_type"); + auto box_normalized = op_info->GetAttr("box_normalized"); + int32_t axis = 0; + if (op_info->HasAttr("axis")) { + axis = op_info->GetAttr("axis"); + } + std::vector variance; + if (op_info->HasAttr("variance")) { + variance = op_info->GetAttr>("variance"); + } + user_cpu_param_t bm_param; + bm_param.op_type = USER_PADDLE_BOX_CODER; + bm_param.u.box_coder_param.axis = axis; + bm_param.u.box_coder_param.variance = &variance[0]; + bm_param.u.box_coder_param.code_type = + (code_type == "encode_center_size") ? 0 : 1; + bm_param.u.box_coder_param.normalized = box_normalized; + int32_t input_num = 3; + int32_t output_num = 1; + int32_t* in_shape[3]; + int32_t in_dim[3]; + const char* in_name[3]; + in_shape[0] = &i_box_shape_data[0]; + in_shape[1] = &i_target_box_shape_data[0]; + in_shape[2] = &i_box_var_shape_data[0]; + in_dim[0] = box_dims.size(); + in_dim[1] = target_box_dims.size(); + in_dim[2] = box_var_dims.size(); + in_name[0] = static_cast(box_var_name.c_str()); + in_name[1] = static_cast(target_box_var_name.c_str()); + in_name[2] = static_cast(box_var_var_name.c_str()); + int32_t* out_shape[1]; + int32_t out_dim[1]; + const char* out_name[1]; + out_shape[0] = &i_output_shape_data[0]; + out_dim[0] = output_dims.size(); + out_name[0] = static_cast(output_var_name.c_str()); + + add_user_cpu_layer(graph->GetCompilerHandle(), + input_num, + in_shape, + in_dim, + in_name, + output_num, + out_shape, + out_dim, + out_name, + &bm_param, + static_cast(sizeof(bm_param))); + graph->AddNode(output_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(box_coder, + kBM, + paddle::lite::subgraph::bm::BoxCoderConverter); diff --git a/lite/kernels/bm/bridges/concat_op.cc b/lite/kernels/bm/bridges/concat_op.cc index 9a8729aa8d..0b568aa4d1 100644 --- a/lite/kernels/bm/bridges/concat_op.cc +++ b/lite/kernels/bm/bridges/concat_op.cc @@ -30,6 +30,8 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto op_type = op_info->Type(); // input auto x_names = op_info->Input("X"); + auto x_type = kernel->GetInputDeclType("X"); + CHECK(x_type->layout() == DATALAYOUT(kNCHW)); // output auto output_var_name = op_info->Output("Out").front(); auto output = scope->FindVar(output_var_name)->GetMutable(); diff --git a/lite/kernels/bm/bridges/multiclass_nms_op.cc b/lite/kernels/bm/bridges/multiclass_nms_op.cc new file mode 100644 index 0000000000..6e7520f272 --- /dev/null +++ b/lite/kernels/bm/bridges/multiclass_nms_op.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/kernels/bm/bridges/graph.h" +#include "lite/kernels/bm/bridges/utility.h" +#include "lite/kernels/npu/bridges/registry.h" + +namespace paddle { +namespace lite { +namespace subgraph { +namespace bm { + +int MultiClassNMSConverter(void* ctx, OpLite* op, KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto scope = op->scope(); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto boxes_var_name = op_info->Input("BBoxes").front(); + auto boxes = scope->FindVar(boxes_var_name)->GetMutable(); + auto boxes_dims = boxes->dims(); + std::vector i_boxes_shape_data(boxes_dims.size()); + for (size_t i = 0; i < boxes_dims.size(); i++) { + i_boxes_shape_data[i] = static_cast(boxes_dims[i]); + } + auto score_var_name = op_info->Input("Scores").front(); + auto score = scope->FindVar(score_var_name)->GetMutable(); + auto score_dims = score->dims(); + std::vector i_score_shape_data(score_dims.size()); + for (size_t i = 0; i < score_dims.size(); i++) { + i_score_shape_data[i] = static_cast(score_dims[i]); + } + + auto out_var_name = op_info->Output("Out").front(); + auto out = scope->FindVar(out_var_name)->GetMutable(); + auto out_dims = out->dims(); + std::vector i_out_shape_data(out_dims.size()); + for (size_t i = 0; i < out_dims.size(); i++) { + i_out_shape_data[i] = static_cast(out_dims[i]); + } + + auto background_label = op_info->GetAttr("background_label"); + auto keep_top_k = op_info->GetAttr("keep_top_k"); + auto nms_top_k = op_info->GetAttr("nms_top_k"); + auto score_threshold = op_info->GetAttr("score_threshold"); + auto nms_threshold = op_info->GetAttr("nms_threshold"); + auto nms_eta = op_info->GetAttr("nms_eta"); + bool normalized; + if (op_info->HasAttr("normalized")) { + normalized = op_info->GetAttr("normalized"); + } + + user_cpu_param_t bm_param; + bm_param.op_type = USER_PADDLE_MULTICLASS_NMS; + bm_param.u.multiclass_nms_param.background_label = background_label; + bm_param.u.multiclass_nms_param.score_threshold = score_threshold; + bm_param.u.multiclass_nms_param.keep_top_k = keep_top_k; + bm_param.u.multiclass_nms_param.nms_top_k = nms_top_k; + bm_param.u.multiclass_nms_param.nms_threshold = nms_threshold; + bm_param.u.multiclass_nms_param.nms_eta = nms_eta; + bm_param.u.multiclass_nms_param.normalized = normalized; + + int32_t input_num = 2; + int32_t output_num = 1; + int32_t* in_shape[2]; + int32_t in_dim[2]; + const char* in_name[2]; + in_shape[0] = &i_boxes_shape_data[0]; + in_shape[1] = &i_score_shape_data[0]; + in_dim[0] = boxes_dims.size(); + in_dim[1] = score_dims.size(); + in_name[0] = static_cast(boxes_var_name.c_str()); + in_name[1] = static_cast(score_var_name.c_str()); + int32_t* out_shape[1]; + int32_t out_dim[1]; + const char* out_name[1]; + i_out_shape_data[0] = keep_top_k; + i_out_shape_data[1] = 6; + out_shape[0] = &i_out_shape_data[0]; + out_dim[0] = 2; + out_name[0] = static_cast(out_var_name.c_str()); + + add_user_cpu_layer(graph->GetCompilerHandle(), + input_num, + in_shape, + in_dim, + in_name, + output_num, + out_shape, + out_dim, + out_name, + &bm_param, + static_cast(sizeof(bm_param))); + graph->AddNode(out_var_name); + return SUCCESS; +} + +} // namespace bm +} // namespace subgraph +} // namespace lite +} // namespace paddle + +REGISTER_SUBGRAPH_BRIDGE(multiclass_nms, + kBM, + paddle::lite::subgraph::bm::MultiClassNMSConverter); diff --git a/lite/kernels/bm/bridges/paddle_use_bridges.h b/lite/kernels/bm/bridges/paddle_use_bridges.h index fdaf70de6a..72820e965f 100644 --- a/lite/kernels/bm/bridges/paddle_use_bridges.h +++ b/lite/kernels/bm/bridges/paddle_use_bridges.h @@ -36,3 +36,5 @@ USE_SUBGRAPH_BRIDGE(flatten, kBM); USE_SUBGRAPH_BRIDGE(flatten2, kBM); USE_SUBGRAPH_BRIDGE(norm, kBM); USE_SUBGRAPH_BRIDGE(prior_box, kBM); +USE_SUBGRAPH_BRIDGE(box_coder, kBM); +USE_SUBGRAPH_BRIDGE(multiclass_nms, kBM); diff --git a/lite/kernels/bm/bridges/prior_box_op.cc b/lite/kernels/bm/bridges/prior_box_op.cc index 17c3fbf034..de30d0e318 100644 --- a/lite/kernels/bm/bridges/prior_box_op.cc +++ b/lite/kernels/bm/bridges/prior_box_op.cc @@ -83,127 +83,106 @@ float* compute_priorbox_kernel(OpLite* op, st_priorbox_param* param) { for (size_t i = 0; i < expand_aspect_ratios.size(); i++) { param->aspect_ratios.push_back(expand_aspect_ratios[i]); } - param->prior_num = param->aspect_ratios.size() * param->min_sizes.size(); + + auto img_width = img_dims[3]; + auto img_height = img_dims[2]; + auto feature_width = in_dims[3]; + auto feature_height = in_dims[2]; + float step_width, step_height; + if (param->step_w == 0.f || param->step_h == 0.f) { + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; + } else { + step_width = param->step_w; + step_height = param->step_h; + } + int num_priors = param->aspect_ratios.size() * param->min_sizes.size(); if (param->max_sizes.size() > 0) { - param->prior_num += param->max_sizes.size(); + num_priors += param->max_sizes.size(); } - int32_t win1 = in_dims[3]; - int32_t hin1 = in_dims[2]; - DDim shape_out({hin1, win1, param->prior_num, 4}); + param->prior_num = num_priors; + DDim shape_out({feature_height, feature_width, num_priors, 4}); + int32_t channel_size = feature_height * feature_width * num_priors * 4; boxes->Resize(shape_out); var->Resize(shape_out); - // boxes->mutable_data(); - // var->mutable_data(); float* cpu_data = static_cast(malloc(sizeof(float) * boxes->data_size() * 2)); CHECK(cpu_data != nullptr); - const int32_t width = in_dims[3]; - const int32_t height = in_dims[2]; - int32_t img_width = param->img_w; - int32_t img_height = param->img_h; - if (img_width == 0 || img_height == 0) { - img_width = img_dims[3]; - img_height = img_dims[2]; - } - float step_w = param->step_w; - float step_h = param->step_h; - if (step_w == 0.f || step_h == 0.f) { - step_w = static_cast(img_width) / width; - step_h = static_cast(img_height) / height; - } - float offset = param->offset; - int32_t channel_size = height * width * param->prior_num * 4; - int32_t idx = 0; - /////////////////////////////////////////////////////////////////////// - for (int32_t h = 0; h < height; ++h) { - for (int32_t w = 0; w < width; ++w) { - float center_x = (w + offset) * step_w; - float center_y = (h + offset) * step_h; - float box_width = 0.f; - float box_height = 0.f; - float* min_buf = reinterpret_cast(malloc(sizeof(float) * 4)); - float* max_buf = reinterpret_cast(malloc(sizeof(float) * 4)); - float* com_buf = reinterpret_cast( - malloc(sizeof(float) * expand_aspect_ratios.size() * 4)); - CHECK(min_buf != nullptr); - CHECK(max_buf != nullptr); - CHECK(com_buf != nullptr); - // LOG(INFO) << "the number of min_size is " << min_sizes_.size(); + float* b_t = cpu_data; + for (int h = 0; h < feature_height; ++h) { + for (int w = 0; w < feature_width; ++w) { + float center_x = (w + param->offset) * step_width; + float center_y = (h + param->offset) * step_height; + float box_width, box_height; for (size_t s = 0; s < param->min_sizes.size(); ++s) { - int32_t min_idx = 0; - int32_t max_idx = 0; - int32_t com_idx = 0; - int32_t min_size = param->min_sizes[s]; - //! first prior: aspect_ratio = 1, size = min_size - box_width = box_height = min_size; - //! xmin - min_buf[min_idx++] = (center_x - box_width / 2.f) / img_width; - //! ymin - min_buf[min_idx++] = (center_y - box_height / 2.f) / img_height; - //! xmax - min_buf[min_idx++] = (center_x + box_width / 2.f) / img_width; - //! ymax - min_buf[min_idx++] = (center_y + box_height / 2.f) / img_height; - if (param->max_sizes.size() > 0) { - int max_size = param->max_sizes[s]; - //! second prior: aspect_ratio = 1, size = sqrt(min_size * max_size) - box_width = box_height = sqrtf(min_size * max_size); - //! xmin - max_buf[max_idx++] = (center_x - box_width / 2.f) / img_width; - //! ymin - max_buf[max_idx++] = (center_y - box_height / 2.f) / img_height; - //! xmax - max_buf[max_idx++] = (center_x + box_width / 2.f) / img_width; - //! ymax - max_buf[max_idx++] = (center_y + box_height / 2.f) / img_height; - } - //! rest of priors - for (size_t r = 0; r < expand_aspect_ratios.size(); ++r) { - float ar = expand_aspect_ratios[r]; - if (fabs(ar - 1.) < 1e-6) { - continue; - } - box_width = min_size * sqrt(ar); - box_height = min_size / sqrt(ar); - //! xmin - com_buf[com_idx++] = (center_x - box_width / 2.f) / img_width; - //! ymin - com_buf[com_idx++] = (center_y - box_height / 2.f) / img_height; - //! xmax - com_buf[com_idx++] = (center_x + box_width / 2.f) / img_width; - //! ymax - com_buf[com_idx++] = (center_y + box_height / 2.f) / img_height; - } + auto min_size = param->min_sizes[s]; if (param->min_max_aspect_ratios_order) { - memcpy(cpu_data + idx, min_buf, sizeof(float) * min_idx); - idx += min_idx; - memcpy(cpu_data + idx, max_buf, sizeof(float) * max_idx); - idx += max_idx; - memcpy(cpu_data + idx, com_buf, sizeof(float) * com_idx); - idx += com_idx; + box_width = box_height = min_size / 2.; + b_t[0] = (center_x - box_width) / img_width; + b_t[1] = (center_y - box_height) / img_height; + b_t[2] = (center_x + box_width) / img_width; + b_t[3] = (center_y + box_height) / img_height; + b_t += 4; + if (param->max_sizes.size() > 0) { + auto max_size = param->max_sizes[s]; + // square prior with size sqrt(minSize * maxSize) + box_width = box_height = sqrt(min_size * max_size) / 2.; + b_t[0] = (center_x - box_width) / img_width; + b_t[1] = (center_y - box_height) / img_height; + b_t[2] = (center_x + box_width) / img_width; + b_t[3] = (center_y + box_height) / img_height; + b_t += 4; + } + // priors with different aspect ratios + for (size_t r = 0; r < param->aspect_ratios.size(); ++r) { + float ar = param->aspect_ratios[r]; + if (fabs(ar - 1.) < 1e-6) { + continue; + } + box_width = min_size * sqrt(ar) / 2.; + box_height = min_size / sqrt(ar) / 2.; + b_t[0] = (center_x - box_width) / img_width; + b_t[1] = (center_y - box_height) / img_height; + b_t[2] = (center_x + box_width) / img_width; + b_t[3] = (center_y + box_height) / img_height; + b_t += 4; + } } else { - memcpy(cpu_data + idx, com_buf, sizeof(float) * com_idx); - idx += com_idx; - memcpy(cpu_data + idx, max_buf, sizeof(float) * max_idx); - idx += max_idx; + // priors with different aspect ratios + for (size_t r = 0; r < param->aspect_ratios.size(); ++r) { + float ar = param->aspect_ratios[r]; + box_width = min_size * sqrt(ar) / 2.; + box_height = min_size / sqrt(ar) / 2.; + b_t[0] = (center_x - box_width) / img_width; + b_t[1] = (center_y - box_height) / img_height; + b_t[2] = (center_x + box_width) / img_width; + b_t[3] = (center_y + box_height) / img_height; + b_t += 4; + } + if (param->max_sizes.size() > 0) { + auto max_size = param->max_sizes[s]; + // square prior with size sqrt(minSize * maxSize) + box_width = box_height = sqrt(min_size * max_size) / 2.; + b_t[0] = (center_x - box_width) / img_width; + b_t[1] = (center_y - box_height) / img_height; + b_t[2] = (center_x + box_width) / img_width; + b_t[3] = (center_y + box_height) / img_height; + b_t += 4; + } } } - free(min_buf); - free(max_buf); - free(com_buf); } } - //! clip the prior's coordidate such that it is within [0, 1] + if (param->clip) { for (int32_t d = 0; d < channel_size; ++d) { cpu_data[d] = std::min(std::max(cpu_data[d], 0.f), 1.f); } } - //! set the variance. float* ptr = cpu_data + channel_size; int count = 0; - for (int32_t h = 0; h < height; ++h) { - for (int32_t w = 0; w < width; ++w) { + for (int32_t h = 0; h < feature_height; ++h) { + for (int32_t w = 0; w < feature_width; ++w) { for (int32_t i = 0; i < param->prior_num; ++i) { for (int j = 0; j < 4; ++j) { ptr[count] = param->variances[j]; @@ -237,7 +216,6 @@ int PriorBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto boxes_var_name = op_info->Output("Boxes").front(); auto boxes = scope->FindVar(boxes_var_name)->GetMutable(); auto var_var_name = op_info->Output("Variances").front(); - auto unique_op_name = lite::subgraph::bm::UniqueName(op_type); // param st_priorbox_param param; param.clip = op_info->GetAttr("clip"); @@ -269,20 +247,19 @@ int PriorBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) { op_info->GetAttr("min_max_aspect_ratios_order"); } float* cpu_data = compute_priorbox_kernel(op, ¶m); - compute_priorbox_kernel(op, param); auto boxes_dims = boxes->dims(); - std::vector i_pri_out_shape_data(boxes_dims.size()); - for (size_t i = 0; i < boxes_dims.size(); i++) { - i_pri_out_shape_data[i] = static_cast(boxes_dims[i]); - } - i_pri_out_shape_data[0] *= 2; + std::vector i_pri_out_shape_data(3); + i_pri_out_shape_data[0] = 1; + i_pri_out_shape_data[1] = 2; + i_pri_out_shape_data[2] = boxes->data_size(); + auto bm_priorbox_name = lite::subgraph::bm::UniqueName("bm_priorbox"); add_priorbox_layer(graph->GetCompilerHandle(), const_cast(&i_input_shape_data[0]), in_dims.size(), static_cast(in_var_name.c_str()), const_cast(&i_pri_out_shape_data[0]), - boxes_dims.size(), - static_cast(unique_op_name.c_str()), + 3, + static_cast(bm_priorbox_name.c_str()), static_cast(cpu_data), param.min_sizes.size(), const_cast(¶m.min_sizes[0]), @@ -299,32 +276,57 @@ int PriorBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) { param.step_h, param.step_w, param.offset); - std::vector i_output_shape_data(boxes_dims.size()); - for (size_t i = 0; i < boxes_dims.size(); i++) { - i_output_shape_data[i] = static_cast(boxes_dims[i]); - } int32_t* shape[2]; - int dim[2]; + int32_t dim[2]; const char* name[2]; - dim[0] = boxes_dims.size(); - dim[1] = boxes_dims.size(); - name[0] = static_cast(boxes_var_name.c_str()); - name[1] = static_cast(var_var_name.c_str()); - shape[0] = &i_output_shape_data[0]; - shape[1] = &i_output_shape_data[0]; - int split_size = 2; + int32_t dim_size = 3; + dim[0] = dim_size; + dim[1] = dim_size; + std::vector i_split_shape_data(dim_size); + for (size_t i = 0; i < dim_size; i++) { + i_split_shape_data[i] = i_pri_out_shape_data[i]; + } + i_split_shape_data[1] /= 2; + shape[0] = &i_split_shape_data[0]; + shape[1] = &i_split_shape_data[0]; + name[0] = static_cast( + lite::subgraph::bm::UniqueName("bm_boxes").c_str()); + name[1] = static_cast( + lite::subgraph::bm::UniqueName("bm_boxes_var").c_str()); + int split_size[2]; + split_size[0] = shape[0][1]; + split_size[1] = shape[1][1]; add_tf_split_layer(graph->GetCompilerHandle(), const_cast(&i_pri_out_shape_data[0]), - boxes_dims.size(), - static_cast(unique_op_name.c_str()), + 3, + static_cast(bm_priorbox_name.c_str()), 2, shape, dim, name, - boxes_dims.size(), - 0, - &split_size, - 0); + 3, + 1, + split_size, + 2); + // final output + std::vector i_output_shape_data(boxes_dims.size()); + for (size_t i = 0; i < boxes_dims.size(); i++) { + i_output_shape_data[i] = static_cast(boxes_dims[i]); + } + add_reshape_layer_v2(graph->GetCompilerHandle(), + name[0], + shape[0], + 3, + static_cast(boxes_var_name.c_str()), + const_cast(&i_output_shape_data[0]), + boxes_dims.size()); + add_reshape_layer_v2(graph->GetCompilerHandle(), + name[1], + shape[1], + 3, + static_cast(var_var_name.c_str()), + const_cast(&i_output_shape_data[0]), + boxes_dims.size()); graph->AddNode(boxes_var_name); graph->AddNode(var_var_name); return SUCCESS; -- GitLab