提交 ad897304 编写于 作者: D dengkaipeng

fix pre-commit. test=develop

上级 72a18bb1
......@@ -22,9 +22,9 @@ using Tensor = framework::Tensor;
template <typename T>
__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh, const int* anchors,
const int n, const int h, const int w,
const int an_num, const int class_num,
T* scores, const float conf_thresh,
const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num,
const int box_num, int input_size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
......@@ -50,7 +50,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width);
grid_num, img_height, img_width);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width);
......@@ -84,7 +84,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
int input_size = downsample_ratio * h;
Tensor anchors_t, cpu_anchors_t;
auto cpu_anchors_data = cpu_anchors_t.mutable_data<int>({an_num*2}, platform::CPUPlace());
auto cpu_anchors_data =
cpu_anchors_t.mutable_data<int>({an_num * 2}, platform::CPUPlace());
std::copy(anchors.begin(), anchors.end(), cpu_anchors_data);
TensorCopySync(cpu_anchors_t, ctx.GetPlace(), &anchors_t);
auto anchors_data = anchors_t.data<int>();
......@@ -103,8 +104,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size);
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size);
}
};
......@@ -112,6 +113,5 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(yolo_box,
ops::YoloBoxOpCUDAKernel<float>,
REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel<float>,
ops::YoloBoxOpCUDAKernel<double>);
......@@ -20,7 +20,6 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE inline T sigmoid(T x) {
return 1.0 / (1.0 + std::exp(-x));
......@@ -28,15 +27,15 @@ HOSTDEVICE inline T sigmoid(T x) {
template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride,
int img_height, int img_width) {
int j, int an_idx, int grid_size,
int input_size, int index, int stride,
int img_height, int img_width) {
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height /
input_size;
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
}
HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
......@@ -47,16 +46,22 @@ HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
template <typename T>
HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx,
const int img_height, const int img_width) {
const int img_height,
const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<T>(0);
boxes[box_idx + 1] = boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 ? boxes[box_idx + 2] : static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 ? boxes[box_idx + 3] : static_cast<T>(img_height - 1);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<T>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<T>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<T>(img_height - 1);
}
template <typename T>
......@@ -92,8 +97,10 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
int anchors_[anchors.size()];
std::copy(anchors.begin(), anchors.end(), anchors_);
Tensor anchors_;
auto anchors_data =
anchors_.mutable_data<int>({an_num * 2}, ctx.GetPlace());
std::copy(anchors.begin(), anchors.end(), anchors_data);
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
......@@ -120,10 +127,11 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
GetYoloBox<T>(box, input_data, anchors_, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width);
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height,
img_width);
int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5);
......
......@@ -59,12 +59,19 @@ def YoloBox(x, img_size, attrs):
pred_box[:, :, :2], pred_box[:, :, 2:4] = \
pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \
pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0
# pred_box = pred_box * input_size
pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis]
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
for i in range(len(pred_box)):
pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf,
img_size[i, 1] - 1)
pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf,
img_size[i, 0] - 1)
return pred_box, pred_score.reshape((n, -1, class_num))
......@@ -93,8 +100,7 @@ class TestYoloBoxOp(OpTest):
}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
self.check_output()
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册