提交 3896d955 编写于 作者: D dengkaipeng

add yolo_box_op CPU kernel

上级 4e8c03bd
......@@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
if(WITH_GPU)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
// Operator definition of yolo_box: validates the input tensor and the
// attributes, and infers the shapes of the Boxes and Scores outputs.
class YoloBoxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires X to be a 4-D tensor [N, C, H, W] with
  // C == anchor_num * (5 + class_num), where anchor_num is half the length of
  // Attr(anchors), and sets
  //   Boxes  -> [N, H * W * anchor_num, 4]
  //   Scores -> [N, H * W * anchor_num, class_num]
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of YoloBoxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Boxes"),
                   "Output(Boxes) of YoloBoxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Scores"),
                   "Output(Scores) of YoloBoxOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");
    auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
    auto class_num = ctx->Attrs().Get<int>("class_num");

    // Validate the attributes before deriving anything from them, so that a
    // malformed anchor list is reported as such instead of surfacing as a
    // channel-count mismatch.
    PADDLE_ENFORCE_GT(anchors.size(), 0,
                      "Attr(anchors) length should be greater then 0.");
    PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
                      "Attr(anchors) length should be even integer.");
    PADDLE_ENFORCE_GT(class_num, 0,
                      "Attr(class_num) should be an integer greater then 0.");

    // Anchors are stored as flat (width, height) pairs.
    int anchor_num = anchors.size() / 2;

    PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
    PADDLE_ENFORCE_EQ(
        dim_x[1], anchor_num * (5 + class_num),
        "Input(X) dim[1] should be equal to (anchor_num * (5 + class_num)).");

    // Keep the product in 64 bits; the dims themselves are int64_t.
    const int64_t box_num = dim_x[2] * dim_x[3] * anchor_num;

    std::vector<int64_t> dim_boxes({dim_x[0], box_num, 4});
    ctx->SetOutputDim("Boxes", framework::make_ddim(dim_boxes));

    std::vector<int64_t> dim_scores({dim_x[0], box_num, class_num});
    ctx->SetOutputDim("Scores", framework::make_ddim(dim_scores));
  }

 protected:
  // The kernel is selected by the data type of X and the execution place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};
// Declares the inputs, outputs, attributes and user-facing documentation of
// the yolo_box operator. The description strings below are emitted verbatim
// into the generated operator docs, so adjacent literals must carry their own
// separating spaces.
class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of YoloBox operator, "
             "This is a 4-D tensor with shape of [N, C, H, W]. "
             "H and W should be same, and the second dimension(C) stores "
             "box locations, confidence score and classification one-hot "
             "keys of each anchor box. Generally, X should be the output "
             "of YOLOv3 network.");
    AddOutput("Boxes",
              "The output tensor of detection boxes of YoloBox operator, "
              "This is a 3-D tensor with shape of [N, M, 4], N is the "
              "batch num, M is output box number, and the 3rd dimension "
              "stores [xmin, ymin, xmax, ymax] coordinates of boxes.");
    AddOutput("Scores",
              "The output tensor of detection boxes scores of YoloBox "
              "operator, This is a 3-D tensor with shape of [N, M, C], "
              "N is the batch num, M is output box number, C is the "
              "class number.");

    AddAttr<int>("class_num", "The number of classes to predict.");
    AddAttr<std::vector<int>>("anchors",
                              "The anchor width and height, "
                              "it will be parsed pair by pair.")
        .SetDefault(std::vector<int>{});
    AddAttr<int>("downsample_ratio",
                 "The downsample ratio from network input to YoloBox operator "
                 "input, so 32, 16, 8 should be set for the first, second, "
                 "and third YoloBox operators.")
        .SetDefault(32);
    AddAttr<float>("conf_thresh",
                   "The confidence scores threshold of detection boxes. "
                   "Boxes with confidence scores under threshold should "
                   "be ignored.")
        .SetDefault(0.01);
    AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network.
The output of previous network is in shape [N, C, H, W], while H and W
should be the same, specify the grid size, each grid point predict given
number boxes, this given number is specified by anchors, it should be
half anchors length, which following will be represented as S. In the
second dimension(the channel dimension), C should be S * (class_num + 5),
class_num is the box category number of source dataset(such as coco),
so in the second dimension, stores 4 box location coordinates x, y, w, h
and confidence score of the box and class one-hot key of each anchor box.
While the 4 location coordinates are $$tx, ty, tw, th$$, the box predictions
correspond to:
$$
b_x = \sigma(t_x) + c_x
b_y = \sigma(t_y) + c_y
b_w = p_w e^{t_w}
b_h = p_h e^{t_h}
$$
While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$
is specified by anchors.
The logistic scores of the 5th channel of each anchor prediction boxes
represent the confidence score of each prediction scores, and the logistic
scores of the last class_num channels of each anchor prediction boxes
represent the classification scores. Boxes with confidence scores less than
conf_thresh should be ignored, and boxes final scores is the product
of confidence scores and classification scores.
)DOC");
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Register the yolo_box operator together with its proto maker. An
// EmptyGradOpMaker is supplied as the grad-op maker (presumably because the op
// is meant for inference only and has no backward pass -- confirm).
REGISTER_OPERATOR(yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
paddle::framework::EmptyGradOpMaker);
// CPU kernels are instantiated for float and double.
REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
ops::YoloBoxKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/yolo_box_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Placeholder CUDA kernel: the body only derives the thread's global x/y
// indices and the x/y extent of the launch grid; none of the parameters are
// read and nothing is written to `out` or `var`.
// NOTE(review): the name and parameter list look like they were carried over
// from a density-prior-box kernel rather than written for yolo_box, and no
// launch of this kernel is visible in this file -- confirm whether it should
// be replaced by a real box-decoding kernel or removed.
template <typename T>
static __global__ void GenDensityPriorBox(
const int height, const int width, const int im_height, const int im_width,
const T offset, const T step_width, const T step_height,
const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin,
const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) {
// Global thread coordinates within the launch grid.
int gidx = blockIdx.x * blockDim.x + threadIdx.x;
int gidy = blockIdx.y * blockDim.y + threadIdx.y;
// Total number of threads along each grid axis (currently unused).
int step_x = blockDim.x * gridDim.x;
int step_y = blockDim.y * gridDim.y;
}
// CUDA kernel class of the yolo_box operator.
//
// NOTE: box decoding is not implemented on the GPU yet. Compute() only
// allocates Boxes as [N, box_num, 4] and Scores as [N, box_num, class_num]
// (the same shapes the CPU kernel YoloBoxKernel uses) and zero-fills them.
template <typename T>
class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // The operator declares its input under the name "X"; asking for any
    // other name yields a null tensor.
    auto* input = ctx.Input<Tensor>("X");
    auto* boxes = ctx.Output<Tensor>("Boxes");
    auto* scores = ctx.Output<Tensor>("Scores");

    const int class_num = ctx.Attr<int>("class_num");
    const int n = input->dims()[0];
    const int box_num = boxes->dims()[1];

    // The buffers are allocated on the kernel's place (a CUDA device for this
    // kernel), so they must be cleared through the CUDA runtime instead of
    // the host-side memset.
    T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
    PADDLE_ENFORCE(cudaMemset(boxes_data, 0, boxes->numel() * sizeof(T)),
                   "Failed to zero-fill Output(Boxes) of YoloBoxOp.");

    T* scores_data =
        scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
    PADDLE_ENFORCE(cudaMemset(scores_data, 0, scores->numel() * sizeof(T)),
                   "Failed to zero-fill Output(Scores) of YoloBoxOp.");

    // TODO: launch a device kernel that decodes the boxes and class scores
    // from X using Attr(anchors), Attr(conf_thresh) and
    // Attr(downsample_ratio), mirroring YoloBoxKernel::Compute.
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Register the CUDA kernels of the yolo_box operator. This translation unit
// defines YoloBoxOpCUDAKernel only, so that is the class that must be
// registered, and under the yolo_box op type.
REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel<float>,
                        ops::YoloBoxOpCUDAKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// A box in centre format: (x, y) is the box centre and (w, h) its size.
// CalcDetectionBox converts it to corner format.
template <typename T>
struct Box {
  T x;  // centre x
  T y;  // centre y
  T w;  // width
  T h;  // height
};
// Logistic function 1 / (1 + exp(-x)). The sum and the division are carried
// out in double (the literals are double) and the result is converted to T on
// return.
template <typename T>
static inline T sigmoid(T x) {
  const auto denominator = 1.0 + std::exp(-x);
  return 1.0 / denominator;
}
// Decodes one predicted box from the raw network output.
//
//   x          raw input data
//   anchors    flat list of (width, height) anchor pairs; taken by const
//              reference because this function is called once per grid cell
//              and a by-value vector would be copied on every call
//   i, j       grid cell offsets added to the sigmoid of the x / y predictions
//   an_idx     index of the anchor pair to scale the size with
//   grid_size  number of grid cells along one axis
//   input_size size the centre coordinates are scaled to
//   index      offset of the box's first entry (t_x) in `x`
//   stride     distance in `x` between consecutive entries of the same box
//
// Returns the box in centre format: centre scaled by input_size / grid_size,
// size equal to exp(prediction) times the anchor size.
template <typename T>
static inline Box<T> GetYoloBox(const T* x, const std::vector<int>& anchors,
                                int i, int j, int an_idx, int grid_size,
                                int input_size, int index, int stride) {
  Box<T> b;
  b.x = (i + sigmoid<T>(x[index])) * input_size / grid_size;
  b.y = (j + sigmoid<T>(x[index + stride])) * input_size / grid_size;
  b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx];
  b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1];
  return b;
}
// Flat offset of entry `entry` for anchor `an_idx` of sample `batch` at
// spatial position `hw_idx`: each (batch, anchor) pair owns `an_stride`
// elements and consecutive entries are `stride` elements apart.
static inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num,
                                int an_stride, int stride, int entry) {
  const int anchor_base = (batch * an_num + an_idx) * an_stride;
  const int entry_offset = entry * stride;
  return anchor_base + entry_offset + hw_idx;
}
// Writes `pred` (centre format) as [xmin, ymin, xmax, ymax] into the four
// consecutive elements of `boxes` starting at `box_idx`.
template <typename T>
static inline void CalcDetectionBox(T* boxes, Box<T> pred, const int box_idx) {
  const auto half_w = pred.w / 2;
  const auto half_h = pred.h / 2;
  T* corners = boxes + box_idx;
  corners[0] = pred.x - half_w;
  corners[1] = pred.y - half_h;
  corners[2] = pred.x + half_w;
  corners[3] = pred.y + half_h;
}
// Fills `class_num` consecutive scores starting at `score_idx` with
// conf * sigmoid(class logit). The class logits are read from `input`
// starting at `label_idx`, `stride` elements apart.
template <typename T>
static inline void CalcLabelScore(T* scores, const T* input,
                                  const int label_idx, const int score_idx,
                                  const int class_num, const T conf,
                                  const int stride) {
  T* out = scores + score_idx;
  const T* logits = input + label_idx;
  for (int cls = 0; cls < class_num; ++cls) {
    out[cls] = conf * sigmoid<T>(logits[cls * stride]);
  }
}
// CPU kernel of the yolo_box operator.
//
// For every (sample, anchor, row, column) cell whose objectness confidence
// sigmoid(entry 4) is not below conf_thresh, the predicted box is decoded to
// [xmin, ymin, xmax, ymax] and the per-class scores are written as
// confidence * sigmoid(class logit). Cells below the threshold keep the zeros
// the outputs are initialised with.
template <typename T>
class YoloBoxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* boxes = ctx.Output<Tensor>("Boxes");
    auto* scores = ctx.Output<Tensor>("Scores");

    auto anchors = ctx.Attr<std::vector<int>>("anchors");
    int class_num = ctx.Attr<int>("class_num");
    float conf_thresh = ctx.Attr<float>("conf_thresh");
    int downsample_ratio = ctx.Attr<int>("downsample_ratio");

    const int n = input->dims()[0];
    const int h = input->dims()[2];
    const int w = input->dims()[3];
    const int box_num = boxes->dims()[1];
    const int an_num = anchors.size() / 2;
    int input_size = downsample_ratio * h;

    // One entry plane holds h * w values; one anchor owns (5 + class_num)
    // such planes: 4 box coordinates, 1 confidence, class_num class logits.
    const int stride = h * w;
    const int an_stride = (class_num + 5) * stride;

    const T* input_data = input->data<T>();

    // Both outputs start zeroed so that skipped cells read as zeros.
    T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
    memset(boxes_data, 0, boxes->numel() * sizeof(T));
    T* scores_data =
        scores->mutable_data<T>({n, box_num, class_num}, ctx.GetPlace());
    memset(scores_data, 0, scores->numel() * sizeof(T));

    for (int batch = 0; batch < n; ++batch) {
      for (int an = 0; an < an_num; ++an) {
        for (int row = 0; row < h; ++row) {
          for (int col = 0; col < w; ++col) {
            const int cell = row * w + col;

            // Entry 4 is the objectness logit.
            const int obj_idx = GetEntryIndex(batch, an, cell, an_num,
                                              an_stride, stride, 4);
            const T conf = sigmoid<T>(input_data[obj_idx]);
            if (conf < conf_thresh) {
              continue;
            }

            // Entries 0..3 are the box coordinates.
            const int coord_idx = GetEntryIndex(batch, an, cell, an_num,
                                                an_stride, stride, 0);
            Box<T> pred = GetYoloBox(input_data, anchors, col, row, an, h,
                                     input_size, coord_idx, stride);

            // Position of this prediction in the flattened [n, box_num] grid
            // shared by both outputs.
            const int out_pos = batch * box_num + an * stride + row * w + col;
            CalcDetectionBox<T>(boxes_data, pred, out_pos * 4);

            // Entries 5.. are the class logits.
            const int label_idx = GetEntryIndex(batch, an, cell, an_num,
                                                an_stride, stride, 5);
            CalcLabelScore<T>(scores_data, input_data, label_idx,
                              out_pos * class_num, class_num, conf, stride);
          }
        }
      }
    }
  }
};
} // namespace operators
} // namespace paddle
......@@ -49,6 +49,7 @@ __all__ = [
'box_coder',
'polygon_box_transform',
'yolov3_loss',
'yolo_box',
'box_clip',
'multiclass_nms',
'distribute_fpn_proposals',
......@@ -609,6 +610,71 @@ def yolov3_loss(x,
return loss
@templatedoc(op_type="yolo_box")
def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
    """
    ${comment}

    Args:
        x (Variable): ${x_comment}
        anchors (list|tuple): ${anchors_comment}
        class_num (int): ${class_num_comment}
        conf_thresh (float): ${conf_thresh_comment}
        downsample_ratio (int): ${downsample_ratio_comment}
        name (string): the name of the yolo box layer. Default None.

    Returns:
        tuple: (boxes, scores). boxes is a 3-D tensor with shape [N, M, 4]
        holding the [xmin, ymin, xmax, ymax] coordinates of the boxes, and
        scores is a 3-D tensor with shape [N, M, class_num] holding the
        scores of the boxes.

    Raises:
        TypeError: Input x of yolo_box must be Variable
        TypeError: Attr anchors of yolo_box must be list or tuple
        TypeError: Attr class_num of yolo_box must be an integer
        TypeError: Attr conf_thresh of yolo_box must be a float number

    Examples:

    .. code-block:: python

        x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
        anchors = [10, 13, 16, 30, 33, 23]
        boxes, scores = fluid.layers.yolo_box(x=x, class_num=80,
                                              anchors=anchors,
                                              conf_thresh=0.01,
                                              downsample_ratio=32)
    """
    helper = LayerHelper('yolo_box', **locals())

    if not isinstance(x, Variable):
        raise TypeError("Input x of yolo_box must be Variable")
    if not isinstance(anchors, list) and not isinstance(anchors, tuple):
        raise TypeError("Attr anchors of yolo_box must be list or tuple")
    if not isinstance(class_num, int):
        raise TypeError("Attr class_num of yolo_box must be an integer")
    if not isinstance(conf_thresh, float):
        raise TypeError("Attr conf_thresh of yolo_box must be a float number")

    boxes = helper.create_variable_for_type_inference(dtype=x.dtype)
    scores = helper.create_variable_for_type_inference(dtype=x.dtype)

    # Attribute names must match the ones declared by the yolo_box operator.
    attrs = {
        "anchors": anchors,
        "class_num": class_num,
        "conf_thresh": conf_thresh,
        "downsample_ratio": downsample_ratio,
    }

    helper.append_op(
        type='yolo_box',
        inputs={"X": x, },
        outputs={
            'Boxes': boxes,
            'Scores': scores,
        },
        attrs=attrs)
    return boxes, scores
@templatedoc()
def detection_map(detect_res,
label,
......
......@@ -478,9 +478,16 @@ class TestYoloDetection(unittest.TestCase):
gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13],
[0, 1], 10, 0.7, 32)
self.assertIsNotNone(loss)
def test_yolo_box(self):
    # Building the yolo_box layer inside a fresh program must yield both
    # output variables.
    main_program = Program()
    with program_guard(main_program):
        feature_map = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
        outputs = layers.yolo_box(feature_map, [10, 13, 30, 13], 10, 0.01, 32)
        boxes, scores = outputs
        self.assertIsNotNone(boxes)
        self.assertIsNotNone(scores)
class TestBoxClip(unittest.TestCase):
def test_box_clip(self):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise by NumPy."""
    negated = -1.0 * x
    return 1.0 / (1.0 + np.exp(negated))
def YoloBox(x, attrs):
    """NumPy reference implementation of the yolo_box operator.

    Args:
        x (ndarray): input of shape [N, an_num * (5 + class_num), H, W].
        attrs (dict): operator attributes; reads 'anchors', 'class_num',
            'conf_thresh' and 'downsample_ratio'.

    Returns:
        tuple: (boxes, scores) with shapes [N, an_num * H * W, 4] and
        [N, an_num * H * W, class_num]. Boxes are [xmin, ymin, xmax, ymax]
        scaled by input_size = downsample_ratio * H. Predictions whose
        confidence is below conf_thresh are zeroed in both outputs.
    """
    n, c, h, w = x.shape
    anchors = attrs['anchors']
    an_num = int(len(anchors) // 2)
    class_num = attrs['class_num']
    conf_thresh = attrs['conf_thresh']
    downsample_ratio = attrs['downsample_ratio']
    input_size = downsample_ratio * h

    # [N, C, H, W] -> [N, an_num, H, W, 5 + class_num]
    x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))

    # Box centres, normalised to [0, 1] by the grid size.
    pred_box = x[:, :, :, :, :4].copy()
    grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
    grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
    pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
    pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h

    # Box sizes, normalised by the input size.
    anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
    anchors_s = np.array(
        [(an_w / input_size, an_h / input_size) for an_w, an_h in anchors])
    anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
    anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
    pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
    pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h

    # Confidences below the threshold are zeroed, which also zeroes the
    # matching scores and boxes.
    pred_conf = sigmoid(x[:, :, :, :, 4:5])
    pred_conf[pred_conf < conf_thresh] = 0.
    pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
    pred_box = pred_box * (pred_conf > 0.).astype('float32')

    # Centre format -> corner format. Both corners are computed before either
    # slice is overwritten.
    pred_box = pred_box.reshape((n, -1, 4))
    xy_min = pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2.
    xy_max = pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0
    pred_box[:, :, :2] = xy_min
    pred_box[:, :, 2:4] = xy_max
    pred_box = pred_box * input_size

    return pred_box, pred_score.reshape((n, -1, class_num))


class TestYoloBoxOp(OpTest):
    """Checks the yolo_box operator against the NumPy reference YoloBox."""

    def setUp(self):
        self.initTestCase()
        self.op_type = 'yolo_box'
        x = np.random.random(self.x_shape).astype('float32')

        # The keys must match the operator's attribute names, otherwise the
        # operator would not receive the values configured in initTestCase.
        self.attrs = {
            "anchors": self.anchors,
            "class_num": self.class_num,
            "conf_thresh": self.conf_thresh,
            "downsample_ratio": self.downsample_ratio,
        }

        self.inputs = {'X': x, }
        boxes, scores = YoloBox(x, self.attrs)
        self.outputs = {
            "Boxes": boxes,
            "Scores": scores,
        }

    def test_check_output(self):
        self.check_output()

    def initTestCase(self):
        self.anchors = [10, 13, 16, 30, 33, 23]
        an_num = int(len(self.anchors) // 2)
        self.class_num = 2
        self.conf_thresh = 0.5
        self.downsample_ratio = 32
        self.x_shape = (3, an_num * (5 + self.class_num), 5, 5)
if __name__ == "__main__":
unittest.main()
......@@ -75,8 +75,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
mask_num = len(anchor_mask)
class_num = attrs["class_num"]
ignore_thresh = attrs['ignore_thresh']
downsample = attrs['downsample']
input_size = downsample * h
downsample_ratio = attrs['downsample_ratio']
input_size = downsample_ratio * h
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float32')
......@@ -86,10 +86,6 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs):
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:],
np.ones_like(x[:, :, :, :, 5:]) * 1.0 /
class_num)
mask_anchors = []
for m in anchor_mask:
mask_anchors.append((anchors[2 * m], anchors[2 * m + 1]))
......@@ -176,7 +172,7 @@ class TestYolov3LossOp(OpTest):
"anchor_mask": self.anchor_mask,
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
"downsample": self.downsample,
"downsample_ratio": self.downsample_ratio,
}
self.inputs = {
......@@ -208,7 +204,7 @@ class TestYolov3LossOp(OpTest):
self.anchor_mask = [1, 2]
self.class_num = 5
self.ignore_thresh = 0.5
self.downsample = 32
self.downsample_ratio = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册