/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; template HOSTDEVICE inline T sigmoid(T x) { return 1.0 / (1.0 + std::exp(-x)); } template HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, int j, int an_idx, int grid_size, int input_size, int index, int stride, int img_height, int img_width) { box[0] = (i + sigmoid(x[index])) * img_width / grid_size; box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / input_size; box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height / input_size; } HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, int an_stride, int stride, int entry) { return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; } template HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx, const int img_height, const int img_width) { boxes[box_idx] = box[0] - box[2] / 2; boxes[box_idx + 1] = box[1] - box[3] / 2; boxes[box_idx + 2] = box[0] + box[2] / 2; boxes[box_idx + 3] = box[1] + box[3] / 2; boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : static_cast(img_width - 1); boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3] : static_cast(img_height - 1); } template HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input, const int label_idx, const int score_idx, const int class_num, const T conf, const int stride) { for (int i = 0; i < class_num; i++) { scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); } } template class YoloBoxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("X"); auto* imgsize = ctx.Input("ImgSize"); auto* boxes = ctx.Output("Boxes"); auto* scores = ctx.Output("Scores"); auto anchors = ctx.Attr>("anchors"); int class_num = ctx.Attr("class_num"); float conf_thresh = ctx.Attr("conf_thresh"); int downsample_ratio = ctx.Attr("downsample_ratio"); const int n = input->dims()[0]; const int h = input->dims()[2]; const int w = input->dims()[3]; const int box_num = boxes->dims()[1]; const int an_num = anchors.size() / 2; int input_size = downsample_ratio * h; const int stride = h * w; const int an_stride = (class_num + 5) * stride; Tensor anchors_; auto anchors_data = anchors_.mutable_data({an_num * 2}, ctx.GetPlace()); std::copy(anchors.begin(), anchors.end(), anchors_data); const T* input_data = input->data(); const int* imgsize_data = imgsize->data(); T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); memset(boxes_data, 0, boxes->numel() * sizeof(T)); T* scores_data = scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); memset(scores_data, 0, scores->numel() * sizeof(T)); T box[4]; for (int i = 0; i < n; i++) { int img_height = imgsize_data[2 * i]; int img_width = imgsize_data[2 * i + 1]; for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4); T conf = sigmoid(input_data[obj_idx]); if (conf < conf_thresh) { continue; } int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); GetYoloBox(box, input_data, anchors_data, l, k, j, h, input_size, box_idx, stride, img_height, img_width); box_idx = (i * box_num + j * stride + k * w + l) * 4; CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width); int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); int score_idx = (i * box_num + j * stride + k * w + l) * class_num; CalcLabelScore(scores_data, input_data, label_idx, score_idx, class_num, conf, stride); } } } } } }; } // namespace operators } // namespace paddle