未验证 提交 faece382 编写于 作者: H hong 提交者: GitHub

Move yolo box to phi (#40112)

* add yolo box kernel; test=develop

* fix comile error; test=develop
上级 12346cdc
......@@ -16,7 +16,6 @@
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h"
#include "paddle/fluid/operators/detection/yolo_box_op.h"
namespace paddle {
namespace inference {
......
......@@ -62,7 +62,7 @@ detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc)
detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
detection_library(yolo_box_op SRCS yolo_box_op.cc)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
......
......@@ -9,7 +9,6 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
......@@ -240,8 +239,6 @@ REGISTER_OPERATOR(
yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
ops::YoloBoxKernel<double>);
REGISTER_OP_VERSION(yolo_box)
.AddCheckpoint(
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh,
const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num,
const int box_num, int input_size_h,
int input_size_w, bool clip_bbox, const float scale,
const float bias, bool iou_aware,
const float iou_aware_factor) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
for (; tid < n * box_num; tid += stride) {
int grid_num = h * w;
int i = tid / box_num;
int j = (tid % box_num) / grid_num;
int k = (tid % grid_num) / w;
int l = tid % w;
int an_stride = (5 + class_num) * grid_num;
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];
int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4,
iou_aware);
T conf = sigmoid<T>(input[obj_idx]);
if (iou_aware) {
int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num);
T iou = sigmoid<T>(input[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) {
continue;
}
int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0,
iou_aware);
GetYoloBox<T>(box, input, anchors, l, k, j, h, w, input_size_h,
input_size_w, box_idx, grid_num, img_height, img_width, scale,
bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num,
5, iou_aware);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num);
}
}
template <typename T>
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* img_size = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
bool iou_aware = ctx.Attr<bool>("iou_aware");
float iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
auto& dev_ctx = ctx.cuda_device_context();
int bytes = sizeof(int) * anchors.size();
auto anchors_ptr = memory::Alloc(dev_ctx, sizeof(int) * anchors.size());
int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
const auto gplace = ctx.GetPlace();
const auto cplace = platform::CPUPlace();
memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes,
dev_ctx.stream());
const T* input_data = input->data<T>();
const int* imgsize_data = img_size->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0));
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), n * box_num);
dim3 thread_num = config.thread_per_block;
#ifdef WITH_NV_JETSON
if (config.compute_capability == 53 || config.compute_capability == 62) {
thread_num = 512;
}
#endif
KeYoloBoxFw<T><<<config.block_per_grid, thread_num, 0,
ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size_h,
input_size_w, clip_bbox, scale, bias, iou_aware, iou_aware_factor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel<float>,
ops::YoloBoxOpCUDAKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE inline T sigmoid(T x) {
return 1.0 / (1.0 + std::exp(-x));
}
template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size_h,
int grid_size_w, int input_size_h,
int input_size_w, int index, int stride,
int img_height, int img_width, float scale,
float bias) {
box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size_w;
box[1] = (j + sigmoid<T>(x[index + stride]) * scale + bias) * img_height /
grid_size_h;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size_w;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size_h;
}
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_num, int an_stride, int stride,
int entry, bool iou_aware) {
if (iou_aware) {
return (batch * an_num + an_idx) * an_stride +
(batch * an_num + an_num + entry) * stride + hw_idx;
} else {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
}
HOSTDEVICE inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num,
int an_stride, int stride) {
return batch * an_num * an_stride + (batch * an_num + an_idx) * stride +
hw_idx;
}
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
const int img_height,
const int img_width, bool clip_bbox) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
if (clip_bbox) {
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
}
template <typename T>
HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input,
const int label_idx, const int score_idx,
const int class_num, const T conf,
const int stride) {
for (int i = 0; i < class_num; i++) {
scores[score_idx + i] = conf * sigmoid<T>(input[label_idx + i * stride]);
}
}
template <typename T>
class YoloBoxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* imgsize = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
bool iou_aware = ctx.Attr<bool>("iou_aware");
float iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
Tensor anchors_;
auto anchors_data =
anchors_.mutable_data<int>({an_num * 2}, ctx.GetPlace());
std::copy(anchors.begin(), anchors.end(), anchors_data);
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
memset(boxes_data, 0, boxes->numel() * sizeof(T));
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
memset(scores_data, 0, scores->numel() * sizeof(T));
T box[4];
for (int i = 0; i < n; i++) {
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
stride, 4, iou_aware);
T conf = sigmoid<T>(input_data[obj_idx]);
if (iou_aware) {
int iou_idx =
GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride);
T iou = sigmoid<T>(input_data[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) {
continue;
}
int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
stride, 0, iou_aware);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, w,
input_size_h, input_size_w, box_idx, stride,
img_height, img_width, scale, bias);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox);
int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride,
stride, 5, iou_aware);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
CalcLabelScore<T>(scores_data, input_data, label_idx, score_idx,
class_num, conf, stride);
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/yolo_box_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/yolo_box_util.h"
namespace phi {
template <typename T, typename Context>
void YoloBoxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& img_size,
const std::vector<int>& anchors,
int class_num,
float conf_thresh,
int downsample_ratio,
bool clip_bbox,
float scale_x_y,
bool iou_aware,
float iou_aware_factor,
DenseTensor* boxes,
DenseTensor* scores) {
auto* input = &x;
auto* imgsize = &img_size;
float scale = scale_x_y;
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
DenseTensor anchors_;
auto anchors_data =
anchors_.mutable_data<int>({an_num * 2}, dev_ctx.GetPlace());
std::copy(anchors.begin(), anchors.end(), anchors_data);
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, dev_ctx.GetPlace());
memset(boxes_data, 0, boxes->numel() * sizeof(T));
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, dev_ctx.GetPlace());
memset(scores_data, 0, scores->numel() * sizeof(T));
T box[4];
for (int i = 0; i < n; i++) {
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx = funcs::GetEntryIndex(
i, j, k * w + l, an_num, an_stride, stride, 4, iou_aware);
T conf = funcs::sigmoid<T>(input_data[obj_idx]);
if (iou_aware) {
int iou_idx =
funcs::GetIoUIndex(i, j, k * w + l, an_num, an_stride, stride);
T iou = funcs::sigmoid<T>(input_data[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) {
continue;
}
int box_idx = funcs::GetEntryIndex(
i, j, k * w + l, an_num, an_stride, stride, 0, iou_aware);
funcs::GetYoloBox<T>(box,
input_data,
anchors_data,
l,
k,
j,
h,
w,
input_size_h,
input_size_w,
box_idx,
stride,
img_height,
img_width,
scale,
bias);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
funcs::CalcDetectionBox<T>(
boxes_data, box, box_idx, img_height, img_width, clip_bbox);
int label_idx = funcs::GetEntryIndex(
i, j, k * w + l, an_num, an_stride, stride, 5, iou_aware);
int score_idx = (i * box_num + j * stride + k * w + l) * class_num;
funcs::CalcLabelScore<T>(scores_data,
input_data,
label_idx,
score_idx,
class_num,
conf,
stride);
}
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
yolo_box, CPU, ALL_LAYOUT, phi::YoloBoxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
template <typename T>
HOSTDEVICE inline T sigmoid(T x) {
return 1.0 / (1.0 + std::exp(-x));
}
template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box,
const T* x,
const int* anchors,
int i,
int j,
int an_idx,
int grid_size_h,
int grid_size_w,
int input_size_h,
int input_size_w,
int index,
int stride,
int img_height,
int img_width,
float scale,
float bias) {
box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size_w;
box[1] = (j + sigmoid<T>(x[index + stride]) * scale + bias) * img_height /
grid_size_h;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size_w;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size_h;
}
HOSTDEVICE inline int GetEntryIndex(int batch,
int an_idx,
int hw_idx,
int an_num,
int an_stride,
int stride,
int entry,
bool iou_aware) {
if (iou_aware) {
return (batch * an_num + an_idx) * an_stride +
(batch * an_num + an_num + entry) * stride + hw_idx;
} else {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
}
HOSTDEVICE inline int GetIoUIndex(
int batch, int an_idx, int hw_idx, int an_num, int an_stride, int stride) {
return batch * an_num * an_stride + (batch * an_num + an_idx) * stride +
hw_idx;
}
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes,
T* box,
const int box_idx,
const int img_height,
const int img_width,
bool clip_bbox) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
if (clip_bbox) {
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
}
template <typename T>
HOSTDEVICE inline void CalcLabelScore(T* scores,
const T* input,
const int label_idx,
const int score_idx,
const int class_num,
const T conf,
const int stride) {
for (int i = 0; i < class_num; i++) {
scores[score_idx + i] = conf * sigmoid<T>(input[label_idx + i * stride]);
}
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/yolo_box_util.h"
#include "paddle/phi/kernels/yolo_box_kernel.h"
namespace phi {
template <typename T>
__global__ void KeYoloBoxFw(const T* input,
const int* imgsize,
T* boxes,
T* scores,
const float conf_thresh,
const int* anchors,
const int n,
const int h,
const int w,
const int an_num,
const int class_num,
const int box_num,
int input_size_h,
int input_size_w,
bool clip_bbox,
const float scale,
const float bias,
bool iou_aware,
const float iou_aware_factor) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
for (; tid < n * box_num; tid += stride) {
int grid_num = h * w;
int i = tid / box_num;
int j = (tid % box_num) / grid_num;
int k = (tid % grid_num) / w;
int l = tid % w;
int an_stride = (5 + class_num) * grid_num;
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];
int obj_idx = funcs::GetEntryIndex(
i, j, k * w + l, an_num, an_stride, grid_num, 4, iou_aware);
T conf = funcs::sigmoid<T>(input[obj_idx]);
if (iou_aware) {
int iou_idx =
funcs::GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num);
T iou = funcs::sigmoid<T>(input[iou_idx]);
conf = pow(conf, static_cast<T>(1. - iou_aware_factor)) *
pow(iou, static_cast<T>(iou_aware_factor));
}
if (conf < conf_thresh) {
continue;
}
int box_idx = funcs::GetEntryIndex(
i, j, k * w + l, an_num, an_stride, grid_num, 0, iou_aware);
funcs::GetYoloBox<T>(box,
input,
anchors,
l,
k,
j,
h,
w,
input_size_h,
input_size_w,
box_idx,
grid_num,
img_height,
img_width,
scale,
bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
funcs::CalcDetectionBox<T>(
boxes, box, box_idx, img_height, img_width, clip_bbox);
int label_idx = funcs::GetEntryIndex(
i, j, k * w + l, an_num, an_stride, grid_num, 5, iou_aware);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
funcs::CalcLabelScore<T>(
scores, input, label_idx, score_idx, class_num, conf, grid_num);
}
}
template <typename T, typename Context>
void YoloBoxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& img_size,
const std::vector<int>& anchors,
int class_num,
float conf_thresh,
int downsample_ratio,
bool clip_bbox,
float scale_x_y,
bool iou_aware,
float iou_aware_factor,
DenseTensor* boxes,
DenseTensor* scores) {
auto* input = &x;
float scale = scale_x_y;
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int box_num = boxes->dims()[1];
const int an_num = anchors.size() / 2;
int input_size_h = downsample_ratio * h;
int input_size_w = downsample_ratio * w;
int bytes = sizeof(int) * anchors.size();
auto anchors_ptr =
paddle::memory::Alloc(dev_ctx, sizeof(int) * anchors.size());
int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
const auto gplace = dev_ctx.GetPlace();
const auto cplace = phi::CPUPlace();
paddle::memory::Copy(
gplace, anchors_data, cplace, anchors.data(), bytes, dev_ctx.stream());
const T* input_data = input->data<T>();
const int* imgsize_data = img_size.data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, dev_ctx.GetPlace());
T* scores_data =
scores->mutable_data<T>({n, box_num, class_num}, dev_ctx.GetPlace());
phi::funcs::SetConstant<phi::GPUContext, T> set_zero;
set_zero(dev_ctx, boxes, static_cast<T>(0));
set_zero(dev_ctx, scores, static_cast<T>(0));
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * box_num);
dim3 thread_num = config.thread_per_block;
#ifdef WITH_NV_JETSON
if (config.compute_capability == 53 || config.compute_capability == 62) {
thread_num = 512;
}
#endif
KeYoloBoxFw<T><<<config.block_per_grid, thread_num, 0, dev_ctx.stream()>>>(
input_data,
imgsize_data,
boxes_data,
scores_data,
conf_thresh,
anchors_data,
n,
h,
w,
an_num,
class_num,
box_num,
input_size_h,
input_size_w,
clip_bbox,
scale,
bias,
iou_aware,
iou_aware_factor);
}
} // namespace phi
PD_REGISTER_KERNEL(
yolo_box, GPU, ALL_LAYOUT, phi::YoloBoxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void YoloBoxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& img_size,
const std::vector<int>& anchors,
int class_num,
float conf_thresh,
int downsample_ratio,
bool clip_bbox,
float scale_x_y,
bool iou_aware,
float iou_aware_factor,
DenseTensor* boxes,
DenseTensor* scores);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature YoloBoxOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("yolo_box",
{"X", "ImgSize"},
{"anchors",
"class_num",
"conf_thresh",
"downsample_ratio",
"clip_bbox",
"scale_x_y",
"iou_aware",
"iou_aware_factor"},
{"Boxes", "Scores"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(yolo_box, phi::YoloBoxOpArgumentMapping);
......@@ -260,5 +260,6 @@ class TestYoloBoxOpHW(TestYoloBoxOp):
self.iou_aware_factor = 0.5
if (__name__ == '__main__'):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册