// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <bmcompiler_if.h>
#include <user_bmcpu_common.h>
#include "lite/core/subgraph_bridge_registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"

namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {

// Converts the Paddle multiclass_nms op into a BM user-defined CPU layer.
int MultiClassNMSConverter(void* ctx, OpLite* op, KernelBase* kernel) {
  CHECK(ctx != nullptr);
  CHECK(op != nullptr);
  auto graph = static_cast<Graph*>(ctx);
  auto scope = op->scope();
  auto op_info = op->op_info();
  auto op_type = op_info->Type();
  // Input: bounding boxes
  auto boxes_var_name = op_info->Input("BBoxes").front();
  auto boxes = scope->FindVar(boxes_var_name)->GetMutable<lite::Tensor>();
  auto boxes_dims = boxes->dims();
  std::vector<int32_t> i_boxes_shape_data(boxes_dims.size());
  for (size_t i = 0; i < boxes_dims.size(); i++) {
    i_boxes_shape_data[i] = static_cast<int32_t>(boxes_dims[i]);
  }
  // Input: per-class scores
  auto score_var_name = op_info->Input("Scores").front();
  auto score = scope->FindVar(score_var_name)->GetMutable<lite::Tensor>();
  auto score_dims = score->dims();
  std::vector<int32_t> i_score_shape_data(score_dims.size());
  for (size_t i = 0; i < score_dims.size(); i++) {
    i_score_shape_data[i] = static_cast<int32_t>(score_dims[i]);
  }
  // NMS attributes
  auto background_label = op_info->GetAttr<int>("background_label");
  auto keep_top_k = op_info->GetAttr<int>("keep_top_k");
  auto nms_top_k = op_info->GetAttr<int>("nms_top_k");
  auto score_threshold = op_info->GetAttr<float>("score_threshold");
  auto nms_threshold = op_info->GetAttr<float>("nms_threshold");
  auto nms_eta = op_info->GetAttr<float>("nms_eta");
  bool normalized = true;  // Paddle's default when the attribute is absent
  if (op_info->HasAttr("normalized")) {
    normalized = op_info->GetAttr<bool>("normalized");
  }
  // Output: [batch_size, keep_top_k, 6] or [keep_top_k, 6]
  auto out_var_name = op_info->Output("Out").front();
  auto out = scope->FindVar(out_var_name)->GetMutable<lite::Tensor>();
  std::vector<int64_t> vec_out_dim(score_dims.size());
  if (3 == score_dims.size()) {
    vec_out_dim[0] = score_dims[0];  // batch_size
    vec_out_dim[1] = keep_top_k;
    vec_out_dim[2] = 6;
  } else {
    vec_out_dim[0] = keep_top_k;
    vec_out_dim[1] = 6;
  }
  DDimLite out_dims(vec_out_dim);
  out->Resize(out_dims);
  out->mutable_data<float>();
  std::vector<int32_t> i_out_shape_data(out_dims.size());
  for (size_t i = 0; i < out_dims.size(); i++) {
    i_out_shape_data[i] = static_cast<int32_t>(out_dims[i]);
  }
  // Fill the BM user CPU layer parameter struct
  user_cpu_param_t bm_param;
  bm_param.op_type = USER_PADDLE_MULTICLASS_NMS;
  bm_param.u.multiclass_nms_param.background_label = background_label;
  bm_param.u.multiclass_nms_param.score_threshold = score_threshold;
  bm_param.u.multiclass_nms_param.keep_top_k = keep_top_k;
  bm_param.u.multiclass_nms_param.nms_top_k = nms_top_k;
  bm_param.u.multiclass_nms_param.nms_threshold = nms_threshold;
  bm_param.u.multiclass_nms_param.nms_eta = nms_eta;
  bm_param.u.multiclass_nms_param.normalized = normalized;
  int32_t input_num = 2;
  int32_t output_num = 1;
  int32_t* in_shape[2];
  int32_t in_dim[2];
  const char* in_name[2];
  in_shape[0] = &i_boxes_shape_data[0];
  in_shape[1] = &i_score_shape_data[0];
  in_dim[0] = boxes_dims.size();
  in_dim[1] = score_dims.size();
  in_name[0] = static_cast<const char*>(boxes_var_name.c_str());
  in_name[1] = static_cast<const char*>(score_var_name.c_str());
  int32_t* out_shape[1];
  int32_t out_dim[1];
  const char* out_name[1];
  out_shape[0] = &i_out_shape_data[0];
  out_dim[0] = out_dims.size();
  out_name[0] = static_cast<const char*>(out_var_name.c_str());
  // Register the op as a user-defined CPU layer with the BM compiler
  add_user_cpu_layer(graph->GetCompilerHandle(),
                     input_num,
                     in_shape,
                     in_dim,
                     in_name,
                     output_num,
                     out_shape,
                     out_dim,
                     out_name,
                     &bm_param,
                     static_cast<int>(sizeof(bm_param)));
  graph->AddNode(out_var_name);
  return SUCCESS;
}

}  // namespace bm
}  // namespace subgraph
}  // namespace lite
}  // namespace paddle

REGISTER_SUBGRAPH_BRIDGE(multiclass_nms,
                         kBM,
                         paddle::lite::subgraph::bm::MultiClassNMSConverter);