diff --git a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu index 57177cfa8b421e1d79004bb1a7f738d98dc00f97..336005d883b0f523213060645e540c35a14e4e9c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu @@ -16,7 +16,6 @@ #include #include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h" -#include "paddle/fluid/operators/detection/yolo_box_op.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 1ebafa54598574ae9027a4887639a2a1d27448ea..568c7982cfc7c07b9c7f840ccaa32e4025225122 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -62,7 +62,7 @@ detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc) detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) -detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu) +detection_library(yolo_box_op SRCS yolo_box_op.cc) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu) detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc) diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index 511d8e0eed1065ae0cd2cec3a7bcf534cd3043ab..0d9fbf612f73c428fb8050fcfcc319ddafabe482 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -9,7 +9,6 @@ See the License for the specific language governing permissions and limitations under the 
License. */ -#include "paddle/fluid/operators/detection/yolo_box_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -240,8 +239,6 @@ REGISTER_OPERATOR( yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel, - ops::YoloBoxKernel); REGISTER_OP_VERSION(yolo_box) .AddCheckpoint( diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu deleted file mode 100644 index fb5c214a59e1274ffc30226bf49a068df960f414..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/detection/yolo_box_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/phi/kernels/funcs/math_function.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, - T* scores, const float conf_thresh, - const int* anchors, const int n, const int h, - const int w, const int an_num, const int class_num, - const int box_num, int input_size_h, - int input_size_w, bool clip_bbox, const float scale, - const float bias, bool iou_aware, - const float iou_aware_factor) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - T box[4]; - for (; tid < n * box_num; tid += stride) { - int grid_num = h * w; - int i = tid / box_num; - int j = (tid % box_num) / grid_num; - int k = (tid % grid_num) / w; - int l = tid % w; - - int an_stride = (5 + class_num) * grid_num; - int img_height = imgsize[2 * i]; - int img_width = imgsize[2 * i + 1]; - - int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4, - iou_aware); - T conf = sigmoid(input[obj_idx]); - if (iou_aware) { - int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num); - T iou = sigmoid(input[iou_idx]); - conf = pow(conf, static_cast(1. 
- iou_aware_factor)) * - pow(iou, static_cast(iou_aware_factor)); - } - if (conf < conf_thresh) { - continue; - } - - int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0, - iou_aware); - GetYoloBox(box, input, anchors, l, k, j, h, w, input_size_h, - input_size_w, box_idx, grid_num, img_height, img_width, scale, - bias); - box_idx = (i * box_num + j * grid_num + k * w + l) * 4; - CalcDetectionBox(boxes, box, box_idx, img_height, img_width, clip_bbox); - - int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, - 5, iou_aware); - int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; - CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, - grid_num); - } -} - -template -class YoloBoxOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* img_size = ctx.Input("ImgSize"); - auto* boxes = ctx.Output("Boxes"); - auto* scores = ctx.Output("Scores"); - - auto anchors = ctx.Attr>("anchors"); - int class_num = ctx.Attr("class_num"); - float conf_thresh = ctx.Attr("conf_thresh"); - int downsample_ratio = ctx.Attr("downsample_ratio"); - bool clip_bbox = ctx.Attr("clip_bbox"); - bool iou_aware = ctx.Attr("iou_aware"); - float iou_aware_factor = ctx.Attr("iou_aware_factor"); - float scale = ctx.Attr("scale_x_y"); - float bias = -0.5 * (scale - 1.); - - const int n = input->dims()[0]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - const int box_num = boxes->dims()[1]; - const int an_num = anchors.size() / 2; - int input_size_h = downsample_ratio * h; - int input_size_w = downsample_ratio * w; - - auto& dev_ctx = ctx.cuda_device_context(); - int bytes = sizeof(int) * anchors.size(); - auto anchors_ptr = memory::Alloc(dev_ctx, sizeof(int) * anchors.size()); - int* anchors_data = reinterpret_cast(anchors_ptr->ptr()); - const auto gplace = ctx.GetPlace(); - const auto cplace = 
platform::CPUPlace(); - memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes, - dev_ctx.stream()); - - const T* input_data = input->data(); - const int* imgsize_data = img_size->data(); - T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); - T* scores_data = - scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, boxes, static_cast(0)); - set_zero(dev_ctx, scores, static_cast(0)); - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), n * box_num); - - dim3 thread_num = config.thread_per_block; -#ifdef WITH_NV_JETSON - if (config.compute_capability == 53 || config.compute_capability == 62) { - thread_num = 512; - } -#endif - - KeYoloBoxFw<<>>( - input_data, imgsize_data, boxes_data, scores_data, conf_thresh, - anchors_data, n, h, w, an_num, class_num, box_num, input_size_h, - input_size_w, clip_bbox, scale, bias, iou_aware, iou_aware_factor); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel, - ops::YoloBoxOpCUDAKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h deleted file mode 100644 index 2cd69c60b7c44d0557c23b8d1bd933650e8402c3..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ /dev/null @@ -1,180 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/hostdevice.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -HOSTDEVICE inline T sigmoid(T x) { - return 1.0 / (1.0 + std::exp(-x)); -} - -template -HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, - int j, int an_idx, int grid_size_h, - int grid_size_w, int input_size_h, - int input_size_w, int index, int stride, - int img_height, int img_width, float scale, - float bias) { - box[0] = (i + sigmoid(x[index]) * scale + bias) * img_width / grid_size_w; - box[1] = (j + sigmoid(x[index + stride]) * scale + bias) * img_height / - grid_size_h; - box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / - input_size_w; - box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * - img_height / input_size_h; -} - -HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, - int an_num, int an_stride, int stride, - int entry, bool iou_aware) { - if (iou_aware) { - return (batch * an_num + an_idx) * an_stride + - (batch * an_num + an_num + entry) * stride + hw_idx; - } else { - return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; - } -} - -HOSTDEVICE inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num, - int an_stride, int stride) { - return batch * an_num * an_stride + (batch * an_num + an_idx) * stride + - hw_idx; -} - -template -HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx, - const int img_height, - const int img_width, bool clip_bbox) { - boxes[box_idx] = box[0] - box[2] / 2; - boxes[box_idx + 1] = box[1] - box[3] / 2; - boxes[box_idx + 2] = box[0] + box[2] / 2; - boxes[box_idx + 3] = box[1] + box[3] / 2; - - if 
(clip_bbox) { - boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); - boxes[box_idx + 1] = - boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); - boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 - ? boxes[box_idx + 2] - : static_cast(img_width - 1); - boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 - ? boxes[box_idx + 3] - : static_cast(img_height - 1); - } -} - -template -HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input, - const int label_idx, const int score_idx, - const int class_num, const T conf, - const int stride) { - for (int i = 0; i < class_num; i++) { - scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); - } -} - -template -class YoloBoxKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* imgsize = ctx.Input("ImgSize"); - auto* boxes = ctx.Output("Boxes"); - auto* scores = ctx.Output("Scores"); - auto anchors = ctx.Attr>("anchors"); - int class_num = ctx.Attr("class_num"); - float conf_thresh = ctx.Attr("conf_thresh"); - int downsample_ratio = ctx.Attr("downsample_ratio"); - bool clip_bbox = ctx.Attr("clip_bbox"); - bool iou_aware = ctx.Attr("iou_aware"); - float iou_aware_factor = ctx.Attr("iou_aware_factor"); - float scale = ctx.Attr("scale_x_y"); - float bias = -0.5 * (scale - 1.); - - const int n = input->dims()[0]; - const int h = input->dims()[2]; - const int w = input->dims()[3]; - const int box_num = boxes->dims()[1]; - const int an_num = anchors.size() / 2; - int input_size_h = downsample_ratio * h; - int input_size_w = downsample_ratio * w; - - const int stride = h * w; - const int an_stride = (class_num + 5) * stride; - - Tensor anchors_; - auto anchors_data = - anchors_.mutable_data({an_num * 2}, ctx.GetPlace()); - std::copy(anchors.begin(), anchors.end(), anchors_data); - - const T* input_data = input->data(); - const int* imgsize_data = imgsize->data(); - 
T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); - memset(boxes_data, 0, boxes->numel() * sizeof(T)); - T* scores_data = - scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); - memset(scores_data, 0, scores->numel() * sizeof(T)); - - T box[4]; - for (int i = 0; i < n; i++) { - int img_height = imgsize_data[2 * i]; - int img_width = imgsize_data[2 * i + 1]; - - for (int j = 0; j < an_num; j++) { - for (int k = 0; k < h; k++) { - for (int l = 0; l < w; l++) { - int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, - stride, 4, iou_aware); - T conf = sigmoid(input_data[obj_idx]); - if (iou_aware) { - int iou_idx = - GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride); - T iou = sigmoid(input_data[iou_idx]); - conf = pow(conf, static_cast(1. - iou_aware_factor)) * - pow(iou, static_cast(iou_aware_factor)); - } - if (conf < conf_thresh) { - continue; - } - - int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, - stride, 0, iou_aware); - GetYoloBox(box, input_data, anchors_data, l, k, j, h, w, - input_size_h, input_size_w, box_idx, stride, - img_height, img_width, scale, bias); - box_idx = (i * box_num + j * stride + k * w + l) * 4; - CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width, - clip_bbox); - - int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, - stride, 5, iou_aware); - int score_idx = (i * box_num + j * stride + k * w + l) * class_num; - CalcLabelScore(scores_data, input_data, label_idx, score_idx, - class_num, conf, stride); - } - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/yolo_box_kernel.cc b/paddle/phi/kernels/cpu/yolo_box_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a83bc019fc3af395cedc20edd548b70149a915d5 --- /dev/null +++ b/paddle/phi/kernels/cpu/yolo_box_kernel.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/yolo_box_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/yolo_box_util.h"
+
+namespace phi {
+
+// CPU yolo_box: for every (image, anchor, cell) whose confidence reaches
+// conf_thresh, writes one box to `boxes` and class_num scores to `scores`.
+template
+void YoloBoxKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& img_size,
+                   const std::vector& anchors,
+                   int class_num,
+                   float conf_thresh,
+                   int downsample_ratio,
+                   bool clip_bbox,
+                   float scale_x_y,
+                   bool iou_aware,
+                   float iou_aware_factor,
+                   DenseTensor* boxes,
+                   DenseTensor* scores) {
+  auto* input = &x;
+  auto* imgsize = &img_size;
+  float scale = scale_x_y;
+  float bias = -0.5 * (scale - 1.);
+
+  const int n = input->dims()[0];
+  const int h = input->dims()[2];
+  const int w = input->dims()[3];
+  const int box_num = boxes->dims()[1];
+  const int an_num = anchors.size() / 2;  // two ints per anchor
+  int input_size_h = downsample_ratio * h;
+  int input_size_w = downsample_ratio * w;
+
+  const int stride = h * w;
+  const int an_stride = (class_num + 5) * stride;
+
+  // `anchors` already lives in host memory, so read it in place.
+  const int* anchors_data = anchors.data();
+
+  const T* input_data = input->data();
+  const int* imgsize_data = imgsize->data();
+  T* boxes_data = boxes->mutable_data({n, box_num, 4}, dev_ctx.GetPlace());
+  memset(boxes_data, 0, boxes->numel() * sizeof(T));  // default: 0
+  T* scores_data =
+      scores->mutable_data({n, box_num, class_num}, dev_ctx.GetPlace());
+  memset(scores_data, 0, scores->numel() * sizeof(T));  // default: 0
+
+  T box[4];
+  for (int i = 0; i < n; i++) {  // images
+    int img_height = imgsize_data[2 * i];
+    int img_width = imgsize_data[2 * i + 1];
+
+    for (int j = 0; j < an_num; j++) {  // anchors
+      for (int k = 0; k < h; k++) {     // grid rows
+        for (int l = 0; l < w; l++) {   // grid columns
+          int obj_idx = funcs::GetEntryIndex(
+              i, j, k * w + l, an_num, an_stride, stride, 4, iou_aware);
+          T conf = funcs::sigmoid(input_data[obj_idx]);
+          if (iou_aware) {
+            int iou_idx =
+                funcs::GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride);
+            T iou = funcs::sigmoid(input_data[iou_idx]);
+            conf = pow(conf, static_cast(1. - iou_aware_factor)) *
+                   pow(iou, static_cast(iou_aware_factor));
+          }
+          if (conf < conf_thresh) {
+            continue;  // outputs for this box keep their zero fill
+          }
+
+          int box_idx = funcs::GetEntryIndex(
+              i, j, k * w + l, an_num, an_stride, stride, 0, iou_aware);
+          funcs::GetYoloBox(box,
+                            input_data,
+                            anchors_data,
+                            l,
+                            k,
+                            j,
+                            h,
+                            w,
+                            input_size_h,
+                            input_size_w,
+                            box_idx,
+                            stride,
+                            img_height,
+                            img_width,
+                            scale,
+                            bias);
+          box_idx = (i * box_num + j * stride + k * w + l) * 4;
+          funcs::CalcDetectionBox(
+              boxes_data, box, box_idx, img_height, img_width, clip_bbox);
+
+          int label_idx = funcs::GetEntryIndex(
+              i, j, k * w + l, an_num, an_stride, stride, 5, iou_aware);
+          int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
+          funcs::CalcLabelScore(scores_data,
+                                input_data,
+                                label_idx,
+                                score_idx,
+                                class_num,
+                                conf,
+                                stride);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    yolo_box, CPU, ALL_LAYOUT, phi::YoloBoxKernel, float, double) {}
diff --git a/paddle/phi/kernels/funcs/yolo_box_util.h b/paddle/phi/kernels/funcs/yolo_box_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..337af2d7a236e9ea93c4eecf835ad9ca446b5276
--- /dev/null
+++ b/paddle/phi/kernels/funcs/yolo_box_util.h
@@ -0,0 +1,134 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cmath>
+
+#include "paddle/phi/core/hostdevice.h"
+
+namespace phi {
+namespace funcs {
+
+// Logistic function 1 / (1 + exp(-x)). The 1.0 literals make the arithmetic
+// run in double before the result is converted back to T.
+template
+HOSTDEVICE inline T sigmoid(T x) {
+  return 1.0 / (1.0 + std::exp(-x));
+}
+
+// Decodes one prediction into box = {center_x, center_y, width, height},
+// scaled by img_width / img_height. `i` / `j` are the cell's column / row,
+// `index` is the offset of the first of four consecutive entries in `x`
+// (each `stride` apart) and `an_idx` selects the anchor's (w, h) pair.
+template
+HOSTDEVICE inline void GetYoloBox(T* box,
+                                  const T* x,
+                                  const int* anchors,
+                                  int i,
+                                  int j,
+                                  int an_idx,
+                                  int grid_size_h,
+                                  int grid_size_w,
+                                  int input_size_h,
+                                  int input_size_w,
+                                  int index,
+                                  int stride,
+                                  int img_height,
+                                  int img_width,
+                                  float scale,
+                                  float bias) {
+  box[0] = (i + sigmoid(x[index]) * scale + bias) * img_width / grid_size_w;
+  box[1] = (j + sigmoid(x[index + stride]) * scale + bias) * img_height /
+           grid_size_h;
+  box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
+           input_size_w;
+  box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
+           img_height / input_size_h;
+}
+
+// Flat offset into the input of field `entry` for (batch, an_idx, hw_idx).
+// With iou_aware the offset is additionally shifted by
+// (batch * an_num + an_num) * stride.
+// NOTE(review): that shift presumably skips IoU planes stored ahead of the
+// box data -- confirm against the op's input layout.
+HOSTDEVICE inline int GetEntryIndex(int batch,
+                                    int an_idx,
+                                    int hw_idx,
+                                    int an_num,
+                                    int an_stride,
+                                    int stride,
+                                    int entry,
+                                    bool iou_aware) {
+  if (iou_aware) {
+    return (batch * an_num + an_idx) * an_stride +
+           (batch * an_num + an_num + entry) * stride + hw_idx;
+  } else {
+    return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
+  }
+}
+
+// Flat offset into the input of the IoU entry for (batch, an_idx, hw_idx).
+HOSTDEVICE inline int GetIoUIndex(
+    int batch, int an_idx, int hw_idx, int an_num, int an_stride, int stride) {
+  return batch * an_num * an_stride + (batch * an_num + an_idx) * stride +
+         hw_idx;
+}
+
+// Converts `box` (center_x, center_y, width, height) to corner form and
+// stores {x1, y1, x2, y2} at boxes[box_idx .. box_idx + 3]. With clip_bbox
+// the mins are raised to 0 and the maxes lowered to img_width - 1 /
+// img_height - 1.
+template
+HOSTDEVICE inline void CalcDetectionBox(T* boxes,
+                                        T* box,
+                                        const int box_idx,
+                                        const int img_height,
+                                        const int img_width,
+                                        bool clip_bbox) {
+  boxes[box_idx] = box[0] - box[2] / 2;
+  boxes[box_idx + 1] = box[1] - box[3] / 2;
+  boxes[box_idx + 2] = box[0] + box[2] / 2;
+  boxes[box_idx + 3] = box[1] + box[3] / 2;
+
+  if (clip_bbox) {
+    boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0);
+    boxes[box_idx + 1] =
+        boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0);
+    boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
+                             ? boxes[box_idx + 2]
+                             : static_cast(img_width - 1);
+    boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
+                             ? boxes[box_idx + 3]
+                             : static_cast(img_height - 1);
+  }
+}
+
+// Writes conf * sigmoid(logit) for each of the class_num classes to
+// scores[score_idx + c]; class logits sit `stride` apart from label_idx.
+template
+HOSTDEVICE inline void CalcLabelScore(T* scores,
+                                      const T* input,
+                                      const int label_idx,
+                                      const int score_idx,
+                                      const int class_num,
+                                      const T conf,
+                                      const int stride) {
+  for (int i = 0; i < class_num; i++) {
+    scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]);
+  }
+}
+
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/yolo_box_kernel.cu b/paddle/phi/kernels/gpu/yolo_box_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2719dcd9e54094f2b6af0ad18eabb445081d60a6
--- /dev/null
+++ b/paddle/phi/kernels/gpu/yolo_box_kernel.cu
@@ -0,0 +1,182 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/yolo_box_util.h"
+#include "paddle/phi/kernels/yolo_box_kernel.h"
+
+namespace phi {
+
+template
+__global__ void KeYoloBoxFw(const T* input,  // decodes YOLO boxes and scores
+                            const int* imgsize,
+                            T* boxes,
+                            T* scores,
+                            const float conf_thresh,
+                            const int* anchors,
+                            const int n,
+                            const int h,
+                            const int w,
+                            const int an_num,
+                            const int class_num,
+                            const int box_num,
+                            int input_size_h,
+                            int input_size_w,
+                            bool clip_bbox,
+                            const float scale,
+                            const float bias,
+                            bool iou_aware,
+                            const float iou_aware_factor) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;  // global thread index
+  int stride = blockDim.x * gridDim.x;              // threads in the grid
+  T box[4];  // decoded box filled in by GetYoloBox
+  for (; tid < n * box_num; tid += stride) {  // one (image, box) per pass
+    int grid_num = h * w;
+    int i = tid / box_num;               // image index
+    int j = (tid % box_num) / grid_num;  // anchor index
+    int k = (tid % grid_num) / w;        // grid row
+    int l = tid % w;                     // grid column
+
+    int an_stride = (5 + class_num) * grid_num;
+    int img_height = imgsize[2 * i];
+    int img_width = imgsize[2 * i + 1];
+
+    int obj_idx = funcs::GetEntryIndex(
+        i, j, k * w + l, an_num, an_stride, grid_num, 4, iou_aware);
+    T conf = funcs::sigmoid(input[obj_idx]);
+    if (iou_aware) {
+      int iou_idx =
+          funcs::GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num);
+      T iou = funcs::sigmoid(input[iou_idx]);
+      conf = pow(conf, static_cast(1. - iou_aware_factor)) *
+             pow(iou, static_cast(iou_aware_factor));
+    }
+    if (conf < conf_thresh) {
+      continue;  // leave this box's outputs as the host pre-filled them
+    }
+
+    int box_idx = funcs::GetEntryIndex(
+        i, j, k * w + l, an_num, an_stride, grid_num, 0, iou_aware);
+    funcs::GetYoloBox(box,
+                      input,
+                      anchors,
+                      l,
+                      k,
+                      j,
+                      h,
+                      w,
+                      input_size_h,
+                      input_size_w,
+                      box_idx,
+                      grid_num,
+                      img_height,
+                      img_width,
+                      scale,
+                      bias);
+    box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
+    funcs::CalcDetectionBox(
+        boxes, box, box_idx, img_height, img_width, clip_bbox);
+
+    int label_idx = funcs::GetEntryIndex(
+        i, j, k * w + l, an_num, an_stride, grid_num, 5, iou_aware);
+    int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
+    funcs::CalcLabelScore(
+        scores, input, label_idx, score_idx, class_num, conf, grid_num);
+  }
+}
+
+template
+void YoloBoxKernel(const Context& dev_ctx,  // GPU entry point for yolo_box
+                   const DenseTensor& x,
+                   const DenseTensor& img_size,
+                   const std::vector& anchors,
+                   int class_num,
+                   float conf_thresh,
+                   int downsample_ratio,
+                   bool clip_bbox,
+                   float scale_x_y,
+                   bool iou_aware,
+                   float iou_aware_factor,
+                   DenseTensor* boxes,
+                   DenseTensor* scores) {
+  auto* input = &x;
+  float scale = scale_x_y;
+  float bias = -0.5 * (scale - 1.);
+
+  const int n = input->dims()[0];
+  const int h = input->dims()[2];
+  const int w = input->dims()[3];
+  const int box_num = boxes->dims()[1];
+  const int an_num = anchors.size() / 2;  // anchors holds two ints per anchor
+  int input_size_h = downsample_ratio * h;
+  int input_size_w = downsample_ratio * w;
+
+  int bytes = sizeof(int) * anchors.size();
+  auto anchors_ptr =
+      paddle::memory::Alloc(dev_ctx, sizeof(int) * anchors.size());
+  int* anchors_data = reinterpret_cast(anchors_ptr->ptr());
+  const auto gplace = dev_ctx.GetPlace();
+  const auto cplace = phi::CPUPlace();
+  paddle::memory::Copy(  // anchors: cplace (host) -> gplace, on dev_ctx stream
+      gplace, anchors_data, cplace, anchors.data(), bytes, dev_ctx.stream());
+
+  const T* input_data = input->data();
+  const int* imgsize_data = img_size.data();
+  T* boxes_data = boxes->mutable_data({n, box_num, 4}, dev_ctx.GetPlace());
+  T* scores_data =
+      scores->mutable_data({n, box_num, class_num}, dev_ctx.GetPlace());
+  phi::funcs::SetConstant set_zero;
+  set_zero(dev_ctx, boxes, static_cast(0));  // boxes the kernel skips stay 0
+  set_zero(dev_ctx, scores, static_cast(0));
+  backends::gpu::GpuLaunchConfig config =
+      backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * box_num);
+
+  dim3 thread_num = config.thread_per_block;
+#ifdef WITH_NV_JETSON
+  if (config.compute_capability == 53 || config.compute_capability == 62) {
+    thread_num = 512;  // NOTE(review): reason for 512 is not visible here
+  }
+#endif
+
+  KeYoloBoxFw<<>>(
+      input_data,
+      imgsize_data,
+      boxes_data,
+      scores_data,
+      conf_thresh,
+      anchors_data,
+      n,
+      h,
+      w,
+      an_num,
+      class_num,
+      box_num,
+      input_size_h,
+      input_size_w,
+      clip_bbox,
+      scale,
+      bias,
+      iou_aware,
+      iou_aware_factor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    yolo_box, GPU, ALL_LAYOUT, phi::YoloBoxKernel, float, double) {}
diff --git a/paddle/phi/kernels/yolo_box_kernel.h b/paddle/phi/kernels/yolo_box_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..9553d300cad2b78719958077aea297a500c9359a
--- /dev/null
+++ b/paddle/phi/kernels/yolo_box_kernel.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template
+void YoloBoxKernel(const Context& dev_ctx,  // yolo_box kernel declaration
+                   const DenseTensor& x,
+                   const DenseTensor& img_size,
+                   const std::vector& anchors,
+                   int class_num,
+                   float conf_thresh,
+                   int downsample_ratio,
+                   bool clip_bbox,
+                   float scale_x_y,
+                   bool iou_aware,
+                   float iou_aware_factor,
+                   DenseTensor* boxes,    // output
+                   DenseTensor* scores);  // output
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/yolo_box_sig.cc b/paddle/phi/ops/compat/yolo_box_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bb39e72a64f5075908ca9b28d5f685fb0d6b6c9f
--- /dev/null
+++ b/paddle/phi/ops/compat/yolo_box_sig.cc
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature YoloBoxOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("yolo_box",        // kernel name
+                         {"X", "ImgSize"},  // inputs
+                         {"anchors",        // attributes
+                          "class_num",
+                          "conf_thresh",
+                          "downsample_ratio",
+                          "clip_bbox",
+                          "scale_x_y",
+                          "iou_aware",
+                          "iou_aware_factor"},
+                         {"Boxes", "Scores"});  // outputs
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(yolo_box, phi::YoloBoxOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py
index 043c5c1651a09ac022d8a694b2e916b613c77f6b..f210d97362cf062260594dce1112059919f179c4 100644
--- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py
+++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py
@@ -260,5 +260,6 @@ class TestYoloBoxOpHW(TestYoloBoxOp):
         self.iou_aware_factor = 0.5
 
 
-if (__name__ == '__main__'):
+if __name__ == '__main__':
+    paddle.enable_static()  # NOTE(review): assumes `paddle` is imported above -- confirm
     unittest.main()