diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt
index b2a57cb01f951b1ee980ce99e717de09c4c6934a..eda3849edde7ed0774efe00456acdc2c6d5f2a9f 100644
--- a/lite/kernels/host/CMakeLists.txt
+++ b/lite/kernels/host/CMakeLists.txt
@@ -16,3 +16,4 @@ add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_ker
 add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
 add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
+add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps})
diff --git a/lite/kernels/host/retinanet_detection_output_compute.cc b/lite/kernels/host/retinanet_detection_output_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..95a4bf708e7f03aee9d9ac99323b173287260b13
--- /dev/null
+++ b/lite/kernels/host/retinanet_detection_output_compute.cc
@@ -0,0 +1,435 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/host/retinanet_detection_output_compute.h"
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <utility>
+#include <vector>
+#include "lite/operators/retinanet_detection_output_op.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+// Comparator for (score, index) pairs: descending score.
+template <typename T>
+bool SortScorePairDescend(const std::pair<T, int>& pair1,
+                          const std::pair<T, int>& pair2) {
+  return pair1.first > pair2.first;
+}
+
+// Comparator for (score, (label, index)) pairs: descending score.
+template <typename T>
+bool SortScoreTwoPairDescend(const std::pair<T, std::pair<int, int>>& pair1,
+                             const std::pair<T, std::pair<int, int>>& pair2) {
+  return pair1.first > pair2.first;
+}
+
+// Collects (score, index) pairs whose score exceeds `threshold`, sorted by
+// descending score; keeps at most `top_k` pairs when top_k > -1.
+template <typename T>
+static inline void GetMaxScoreIndex(
+    const std::vector<T>& scores,
+    const T threshold,
+    int top_k,
+    std::vector<std::pair<T, int>>* sorted_indices) {
+  for (size_t i = 0; i < scores.size(); ++i) {
+    if (scores[i] > threshold) {
+      sorted_indices->push_back(std::make_pair(scores[i], i));
+    }
+  }
+  // Sort the score pair according to the scores in descending order
+  std::stable_sort(sorted_indices->begin(),
+                   sorted_indices->end(),
+                   SortScorePairDescend<T>);
+  // Keep top_k scores if needed.
+  if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
+    sorted_indices->resize(top_k);
+  }
+}
+
+// Area of box [xmin, ymin, xmax, ymax]; zero for degenerate boxes. When not
+// normalized, coordinates are inclusive pixel indices, hence the +1.
+template <typename T>
+static inline T BBoxArea(const std::vector<T>& box, const bool normalized) {
+  if (box[2] < box[0] || box[3] < box[1]) {
+    // If coordinate values are invalid
+    // (e.g. xmax < xmin or ymax < ymin), return 0.
+    return static_cast<T>(0.);
+  } else {
+    const T w = box[2] - box[0];
+    const T h = box[3] - box[1];
+    if (normalized) {
+      return w * h;
+    } else {
+      // If coordinate values are not within range [0, 1].
+      return (w + 1) * (h + 1);
+    }
+  }
+}
+
+// Intersection-over-union of two boxes; zero when they do not overlap.
+template <typename T>
+static inline T JaccardOverlap(const std::vector<T>& box1,
+                               const std::vector<T>& box2,
+                               const bool normalized) {
+  if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
+      box2[3] < box1[1]) {
+    return static_cast<T>(0.);
+  } else {
+    const T inter_xmin = std::max(box1[0], box2[0]);
+    const T inter_ymin = std::max(box1[1], box2[1]);
+    const T inter_xmax = std::min(box1[2], box2[2]);
+    const T inter_ymax = std::min(box1[3], box2[3]);
+    T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
+    T inter_w = inter_xmax - inter_xmin + norm;
+    T inter_h = inter_ymax - inter_ymin + norm;
+    const T inter_area = inter_w * inter_h;
+    const T bbox1_area = BBoxArea<T>(box1, normalized);
+    const T bbox2_area = BBoxArea<T>(box2, normalized);
+    return inter_area / (bbox1_area + bbox2_area - inter_area);
+  }
+}
+
+// Greedy non-maximum suppression over `cls_dets` rows of
+// [xmin, ymin, xmax, ymax, score]; eta < 1 adaptively lowers the
+// suppression threshold as boxes are kept.
+template <typename T>
+void NMSFast(const std::vector<std::vector<T>>& cls_dets,
+             const T nms_threshold,
+             const T eta,
+             std::vector<int>* selected_indices) {
+  int64_t num_boxes = cls_dets.size();
+  std::vector<std::pair<T, int>> sorted_indices;
+  for (int64_t i = 0; i < num_boxes; ++i) {
+    sorted_indices.push_back(std::make_pair(cls_dets[i][4], i));
+  }
+  // Sort the score pair according to the scores in descending order
+  std::stable_sort(
+      sorted_indices.begin(), sorted_indices.end(), SortScorePairDescend<T>);
+  selected_indices->clear();
+  T adaptive_threshold = nms_threshold;
+
+  while (sorted_indices.size() != 0) {
+    const int idx = sorted_indices.front().second;
+    bool keep = true;
+    for (size_t k = 0; k < selected_indices->size(); ++k) {
+      if (keep) {
+        const int kept_idx = (*selected_indices)[k];
+        T overlap = T(0.);
+
+        overlap = JaccardOverlap<T>(cls_dets[idx], cls_dets[kept_idx], false);
+        keep = overlap <= adaptive_threshold;
+      } else {
+        break;
+      }
+    }
+    if (keep) {
+      selected_indices->push_back(idx);
+    }
+    sorted_indices.erase(sorted_indices.begin());
+    if (keep && eta < 1 && adaptive_threshold > 0.5) {
+      adaptive_threshold *= eta;
+    }
+  }
+}
+
+// Decodes per-anchor regression deltas into image-space boxes (clipped to
+// the original image extent) and buckets each (box, score) prediction by
+// class id into `preds`.
+template <typename T>
+void DeltaScoreToPrediction(
+    const std::vector<T>& bboxes_data,
+    const std::vector<T>& anchors_data,
+    T im_height,
+    T im_width,
+    T im_scale,
+    int class_num,
+    const std::vector<std::pair<T, int>>& sorted_indices,
+    std::map<int, std::vector<std::vector<T>>>* preds) {
+  // Recover the original (pre-resize) image extent for clipping.
+  im_height = static_cast<T>(std::round(im_height / im_scale));
+  im_width = static_cast<T>(std::round(im_width / im_scale));
+  T zero(0);
+  for (const auto& it : sorted_indices) {
+    T score = it.first;
+    int idx = it.second;
+    int a = idx / class_num;  // anchor index
+    int c = idx % class_num;  // class index
+
+    int box_offset = a * 4;
+    T anchor_box_width =
+        anchors_data[box_offset + 2] - anchors_data[box_offset] + 1;
+    T anchor_box_height =
+        anchors_data[box_offset + 3] - anchors_data[box_offset + 1] + 1;
+    T anchor_box_center_x = anchors_data[box_offset] + anchor_box_width / 2;
+    T anchor_box_center_y =
+        anchors_data[box_offset + 1] + anchor_box_height / 2;
+    T target_box_center_x = 0, target_box_center_y = 0;
+    T target_box_width = 0, target_box_height = 0;
+    target_box_center_x =
+        bboxes_data[box_offset] * anchor_box_width + anchor_box_center_x;
+    target_box_center_y =
+        bboxes_data[box_offset + 1] * anchor_box_height + anchor_box_center_y;
+    target_box_width = std::exp(bboxes_data[box_offset + 2]) * anchor_box_width;
+    target_box_height =
+        std::exp(bboxes_data[box_offset + 3]) * anchor_box_height;
+    T pred_box_xmin = target_box_center_x - target_box_width / 2;
+    T pred_box_ymin = target_box_center_y - target_box_height / 2;
+    T pred_box_xmax = target_box_center_x + target_box_width / 2 - 1;
+    T pred_box_ymax = target_box_center_y + target_box_height / 2 - 1;
+    // Map back to the original image scale, then clip to its bounds.
+    pred_box_xmin = pred_box_xmin / im_scale;
+    pred_box_ymin = pred_box_ymin / im_scale;
+    pred_box_xmax = pred_box_xmax / im_scale;
+    pred_box_ymax = pred_box_ymax / im_scale;
+
+    pred_box_xmin = std::max(std::min(pred_box_xmin, im_width - 1), zero);
+    pred_box_ymin = std::max(std::min(pred_box_ymin, im_height - 1), zero);
+    pred_box_xmax = std::max(std::min(pred_box_xmax, im_width - 1), zero);
+    pred_box_ymax = std::max(std::min(pred_box_ymax, im_height - 1), zero);
+
+    std::vector<T> one_pred;
+    one_pred.push_back(pred_box_xmin);
+    one_pred.push_back(pred_box_ymin);
+    one_pred.push_back(pred_box_xmax);
+    one_pred.push_back(pred_box_ymax);
+    one_pred.push_back(score);
+    (*preds)[c].push_back(one_pred);
+  }
+}
+
+// Runs per-class NMS over `preds`, then keeps the overall top `keep_top_k`
+// detections. Each output row is [label, score, xmin, ymin, xmax, ymax].
+template <typename T>
+void MultiClassNMS(const std::map<int, std::vector<std::vector<T>>>& preds,
+                   int class_num,
+                   const int keep_top_k,
+                   const T nms_threshold,
+                   const T nms_eta,
+                   std::vector<std::vector<T>>* nmsed_out,
+                   int* num_nmsed_out) {
+  std::map<int, std::vector<int>> indices;
+  int num_det = 0;
+  for (int c = 0; c < class_num; ++c) {
+    if (static_cast<bool>(preds.count(c))) {
+      const std::vector<std::vector<T>>& cls_dets = preds.at(c);
+      NMSFast(cls_dets, nms_threshold, nms_eta, &(indices[c]));
+      num_det += indices[c].size();
+    }
+  }
+
+  std::vector<std::pair<T, std::pair<int, int>>> score_index_pairs;
+  for (const auto& it : indices) {
+    int label = it.first;
+    const std::vector<int>& label_indices = it.second;
+    for (size_t j = 0; j < label_indices.size(); ++j) {
+      int idx = label_indices[j];
+      score_index_pairs.push_back(
+          std::make_pair(preds.at(label)[idx][4], std::make_pair(label, idx)));
+    }
+  }
+  // Keep top k results per image.
+  std::stable_sort(score_index_pairs.begin(),
+                   score_index_pairs.end(),
+                   SortScoreTwoPairDescend<T>);
+  if (num_det > keep_top_k) {
+    score_index_pairs.resize(keep_top_k);
+  }
+
+  // Emit the kept detections in descending-score order.
+  for (const auto& it : score_index_pairs) {
+    int label = it.second.first;
+    int idx = it.second.second;
+    std::vector<T> one_pred;
+    one_pred.push_back(label);
+    one_pred.push_back(preds.at(label)[idx][4]);
+    one_pred.push_back(preds.at(label)[idx][0]);
+    one_pred.push_back(preds.at(label)[idx][1]);
+    one_pred.push_back(preds.at(label)[idx][2]);
+    one_pred.push_back(preds.at(label)[idx][3]);
+    nmsed_out->push_back(one_pred);
+  }
+
+  *num_nmsed_out = (num_det > keep_top_k ? keep_top_k : num_det);
+}
+
+// Per-image pipeline: threshold + top-k filter the scores of every FPN
+// level, decode the surviving deltas, then multi-class NMS across levels.
+template <typename T>
+void RetinanetDetectionOutput(
+    const operators::RetinanetDetectionOutputParam& param,
+    const std::vector<Tensor>& scores,
+    const std::vector<Tensor>& bboxes,
+    const std::vector<Tensor>& anchors,
+    const Tensor& im_info,
+    std::vector<std::vector<T>>* nmsed_out,
+    int* num_nmsed_out) {
+  int64_t nms_top_k = param.nms_top_k;
+  int64_t keep_top_k = param.keep_top_k;
+  T nms_threshold = static_cast<T>(param.nms_threshold);
+  T nms_eta = static_cast<T>(param.nms_eta);
+  T score_threshold = static_cast<T>(param.score_threshold);
+
+  int64_t class_num = scores[0].dims()[1];
+  std::map<int, std::vector<std::vector<T>>> preds;
+  for (size_t l = 0; l < scores.size(); ++l) {
+    // Fetch per level score
+    Tensor scores_per_level = scores[l];
+    // Fetch per level bbox
+    Tensor bboxes_per_level = bboxes[l];
+    // Fetch per level anchor
+    Tensor anchors_per_level = anchors[l];
+
+    int64_t scores_num = scores_per_level.numel();
+    int64_t bboxes_num = bboxes_per_level.numel();
+    std::vector<T> scores_data(scores_num);
+    std::vector<T> bboxes_data(bboxes_num);
+    std::vector<T> anchors_data(bboxes_num);
+    std::copy_n(scores_per_level.data<T>(), scores_num, scores_data.begin());
+    std::copy_n(bboxes_per_level.data<T>(), bboxes_num, bboxes_data.begin());
+    std::copy_n(anchors_per_level.data<T>(), bboxes_num, anchors_data.begin());
+    std::vector<std::pair<T, int>> sorted_indices;
+
+    // For the highest level, we take the threshold 0.0
+    T threshold = (l < (scores.size() - 1) ? score_threshold : 0.0);
+    GetMaxScoreIndex(scores_data, threshold, nms_top_k, &sorted_indices);
+    auto* im_info_data = im_info.data<T>();
+    auto im_height = im_info_data[0];
+    auto im_width = im_info_data[1];
+    auto im_scale = im_info_data[2];
+    DeltaScoreToPrediction(bboxes_data,
+                           anchors_data,
+                           im_height,
+                           im_width,
+                           im_scale,
+                           class_num,
+                           sorted_indices,
+                           &preds);
+  }
+
+  MultiClassNMS(preds,
+                class_num,
+                keep_top_k,
+                nms_threshold,
+                nms_eta,
+                nmsed_out,
+                num_nmsed_out);
+}
+
+// Copies nmsed rows into the output tensor as
+// [label, score, xmin, ymin, xmax, ymax]; labels are shifted to 1-based.
+template <typename T>
+void MultiClassOutput(const std::vector<std::vector<T>>& nmsed_out,
+                      Tensor* outs) {
+  auto* odata = outs->mutable_data<T>();
+  int count = 0;
+  int64_t out_dim = 6;
+  for (size_t i = 0; i < nmsed_out.size(); ++i) {
+    odata[count * out_dim] = nmsed_out[i][0] + 1;  // label
+    odata[count * out_dim + 1] = nmsed_out[i][1];  // score
+    odata[count * out_dim + 2] = nmsed_out[i][2];  // xmin
+    odata[count * out_dim + 3] = nmsed_out[i][3];  // ymin
+    odata[count * out_dim + 4] = nmsed_out[i][4];  // xmax
+    odata[count * out_dim + 5] = nmsed_out[i][5];  // ymax
+    count++;
+  }
+}
+
+// Kernel entry: slices each batch image out of the per-level tensors, runs
+// the per-image pipeline, and packs all detections (with a per-image LoD)
+// into the output tensor.
+void RetinanetDetectionOutputCompute::Run() {
+  auto& param = Param<operators::RetinanetDetectionOutputParam>();
+  auto& boxes = param.bboxes;
+  auto& scores = param.scores;
+  auto& anchors = param.anchors;
+  auto* im_info = param.im_info;
+  auto* outs = param.out;
+
+  std::vector<Tensor> boxes_list(boxes.size());
+  std::vector<Tensor> scores_list(scores.size());
+  std::vector<Tensor> anchors_list(anchors.size());
+  for (size_t j = 0; j < boxes_list.size(); ++j) {
+    boxes_list[j] = *boxes[j];
+    scores_list[j] = *scores[j];
+    anchors_list[j] = *anchors[j];
+  }
+  auto score_dims = scores_list[0].dims();
+  int64_t batch_size = score_dims[0];
+  auto box_dims = boxes_list[0].dims();
+  int64_t box_dim = box_dims[2];
+  int64_t out_dim = box_dim + 2;
+
+  std::vector<std::vector<std::vector<float>>> all_nmsed_out;
+  std::vector<uint64_t> batch_starts = {0};
+  for (int i = 0; i < batch_size; ++i) {
+    int num_nmsed_out = 0;
+    std::vector<Tensor> box_per_batch_list(boxes_list.size());
+    std::vector<Tensor> score_per_batch_list(scores_list.size());
+    for (size_t j = 0; j < boxes_list.size(); ++j) {
+      auto score_dims = scores_list[j].dims();
+      score_per_batch_list[j] = scores_list[j].Slice<float>(i, i + 1);
+      score_per_batch_list[j].Resize({score_dims[1], score_dims[2]});
+      box_per_batch_list[j] = boxes_list[j].Slice<float>(i, i + 1);
+      box_per_batch_list[j].Resize({score_dims[1], box_dim});
+    }
+    Tensor im_info_slice = im_info->Slice<float>(i, i + 1);
+
+    std::vector<std::vector<float>> nmsed_out;
+    RetinanetDetectionOutput<float>(param,
+                                    score_per_batch_list,
+                                    box_per_batch_list,
+                                    anchors_list,
+                                    im_info_slice,
+                                    &nmsed_out,
+                                    &num_nmsed_out);
+    all_nmsed_out.push_back(nmsed_out);
+    batch_starts.push_back(batch_starts.back() + num_nmsed_out);
+  }
+
+  uint64_t num_kept = batch_starts.back();
+  if (num_kept == 0) {
+    outs->Resize({0, out_dim});
+  } else {
+    outs->Resize({static_cast<int64_t>(num_kept), out_dim});
+    for (int i = 0; i < batch_size; ++i) {
+      int64_t s = static_cast<int64_t>(batch_starts[i]);
+      int64_t e = static_cast<int64_t>(batch_starts[i + 1]);
+      if (e > s) {
+        Tensor out = outs->Slice<float>(s, e);
+        MultiClassOutput(all_nmsed_out[i], &out);
+      }
+    }
+  }
+
+  LoD lod;
+  lod.emplace_back(batch_starts);
+  outs->set_lod(lod);
+}
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    retinanet_detection_output,
+    kHost,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::host::RetinanetDetectionOutputCompute,
+    def)
+    .BindInput("BBoxes",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Scores",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("Anchors",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindInput("ImInfo",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kFloat),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kFloat),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/host/retinanet_detection_output_compute.h b/lite/kernels/host/retinanet_detection_output_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..612ea7105e2728b856f02d71e9fcfaea2a1ef680
--- /dev/null
+++ b/lite/kernels/host/retinanet_detection_output_compute.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+// Host (CPU, fp32) kernel for retinanet_detection_output: decodes per-level
+// box deltas against anchors and runs multi-class NMS. See the .cc for the
+// full pipeline.
+class RetinanetDetectionOutputCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~RetinanetDetectionOutputCompute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index 586c75c60f6caef1354cf9a167bace798944c354..050fd6296509f6ef56f06066a813c5dd4b70246f 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -137,6 +137,7 @@ add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
 add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
 add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
 add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
+add_operator(retinanet_detection_output_op extra SRCS retinanet_detection_output_op.cc DEPS ${op_DEPS})
 
 # for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS}) add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index d1b3102c5e0bda214d4b48e3820a2ca35f6065de..a17cdc9579f791a104d4921e646bfb05e1981337 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1537,6 +1537,19 @@ struct PrintParam : ParamBase { bool is_forward{true}; }; +struct RetinanetDetectionOutputParam : ParamBase { + std::vector bboxes{}; + std::vector scores{}; + std::vector anchors{}; + Tensor* im_info{}; + Tensor* out{}; + float score_threshold{}; + int nms_top_k{}; + float nms_threshold{}; + float nms_eta{}; + int keep_top_k{}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/lite/operators/retinanet_detection_output_op.cc b/lite/operators/retinanet_detection_output_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e27f2bfca0ab25b8f73d4c6a68d539a7c22389e0 --- /dev/null +++ b/lite/operators/retinanet_detection_output_op.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/operators/retinanet_detection_output_op.h" +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool RetinanetDetectionOutputOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.bboxes.size() > 0); + CHECK_OR_FALSE(param_.scores.size() > 0); + CHECK_OR_FALSE(param_.anchors.size() > 0); + CHECK_OR_FALSE(param_.bboxes.size() == param_.scores.size()); + CHECK_OR_FALSE(param_.bboxes.size() == param_.anchors.size()); + CHECK_OR_FALSE(param_.im_info); + CHECK_OR_FALSE(param_.out); + + DDim bbox_dims = param_.bboxes.front()->dims(); + DDim score_dims = param_.scores.front()->dims(); + DDim anchor_dims = param_.anchors.front()->dims(); + DDim im_info_dims = param_.im_info->dims(); + + CHECK_OR_FALSE(bbox_dims.size() == 3); + CHECK_OR_FALSE(score_dims.size() == 3); + CHECK_OR_FALSE(anchor_dims.size() == 2); + CHECK_OR_FALSE(bbox_dims[2] == 4); + CHECK_OR_FALSE(bbox_dims[1] == score_dims[1]); + CHECK_OR_FALSE(anchor_dims[0] == bbox_dims[1]); + CHECK_OR_FALSE(im_info_dims.size() == 2); + + return true; +} + +bool RetinanetDetectionOutputOpLite::InferShapeImpl() const { + DDim bbox_dims = param_.bboxes.front()->dims(); + param_.out->Resize({bbox_dims[1], bbox_dims[2] + 2}); + return true; +} + +bool RetinanetDetectionOutputOpLite::AttachImpl(const cpp::OpDesc &op_desc, + lite::Scope *scope) { + for (auto arg_name : op_desc.Input("BBoxes")) { + param_.bboxes.push_back( + scope->FindVar(arg_name)->GetMutable()); + } + for (auto arg_name : op_desc.Input("Scores")) { + param_.scores.push_back( + scope->FindVar(arg_name)->GetMutable()); + } + for (auto arg_name : op_desc.Input("Anchors")) { + param_.anchors.push_back( + scope->FindVar(arg_name)->GetMutable()); + } + AttachInput(op_desc, scope, "ImInfo", false, ¶m_.im_info); + AttachOutput(op_desc, scope, "Out", false, ¶m_.out); + + param_.score_threshold = op_desc.GetAttr("score_threshold"); + param_.nms_top_k = 
op_desc.GetAttr("nms_top_k"); + param_.nms_threshold = op_desc.GetAttr("nms_threshold"); + param_.nms_eta = op_desc.GetAttr("nms_eta"); + param_.keep_top_k = op_desc.GetAttr("keep_top_k"); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(retinanet_detection_output, + paddle::lite::operators::RetinanetDetectionOutputOpLite); diff --git a/lite/operators/retinanet_detection_output_op.h b/lite/operators/retinanet_detection_output_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9969227e15941644249b46ba7372f9afc705672c --- /dev/null +++ b/lite/operators/retinanet_detection_output_op.h @@ -0,0 +1,55 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include <string>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+// Op wrapper for retinanet_detection_output: shape validation, attribute
+// attachment, and param plumbing for the host kernel.
+class RetinanetDetectionOutputOpLite : public OpLite {
+ public:
+  RetinanetDetectionOutputOpLite() {}
+
+  explicit RetinanetDetectionOutputOpLite(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "retinanet_detection_output";
+  }
+
+#ifdef LITE_WITH_PROFILE
+  void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {}
+#endif
+
+ private:
+  mutable RetinanetDetectionOutputParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle