未验证 提交 abb49df2 编写于 作者: W wangguanzhong 提交者: GitHub

Enhance yolo_box & yolov3_loss (#24370)

* add scale_x_y for yolo_box, test=develop

* refine eps in iou_similarity, test=develop
上级 d1bb76a2
......@@ -18,7 +18,8 @@ limitations under the License. */
template <typename T>
inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2,
T ymin2, T xmax2, T ymax2, bool normalized) {
T ymin2, T xmax2, T ymax2, bool normalized,
T eps) {
constexpr T zero = static_cast<T>(0);
T area1;
T area2;
......@@ -43,19 +44,21 @@ inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2,
inter_height = inter_height > zero ? inter_height : zero;
inter_width = inter_width > zero ? inter_width : zero;
T inter_area = inter_width * inter_height;
T union_area = area1 + area2 - inter_area;
T union_area = area1 + area2 - inter_area + eps;
T sim_score = inter_area / union_area;
return sim_score;
}
template <typename T>
struct IOUSimilarityFunctor {
IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols, bool normalized)
IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols, bool normalized,
T eps)
: x_(x),
y_(y),
z_(z),
cols_(static_cast<size_t>(cols)),
normalized_(normalized) {}
normalized_(normalized),
eps_(eps) {}
inline HOSTDEVICE void operator()(size_t tid) const {
size_t row_id = tid / cols_;
......@@ -72,7 +75,7 @@ struct IOUSimilarityFunctor {
T y_max2 = y_[col_id * 4 + 3];
T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2,
x_max2, y_max2, normalized_);
x_max2, y_max2, normalized_, eps_);
z_[row_id * cols_ + col_id] = sim;
}
......@@ -81,6 +84,7 @@ struct IOUSimilarityFunctor {
T* z_;
const size_t cols_;
bool normalized_;
T eps_;
};
namespace paddle {
......@@ -97,9 +101,10 @@ class IOUSimilarityKernel : public framework::OpKernel<T> {
int x_n = in_x->dims()[0];
int y_n = in_y->dims()[0];
T eps = static_cast<T>(1e-10);
IOUSimilarityFunctor<T> functor(in_x->data<T>(), in_y->data<T>(),
out->mutable_data<T>(ctx.GetPlace()), y_n,
normalized);
normalized, eps);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), x_n * y_n);
......
......@@ -136,6 +136,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"Whether clip output bonding box in Input(ImgSize) "
"boundary. Default true.")
.SetDefault(true);
AddAttr<float>("scale_x_y",
"Scale the center point of decoded bounding "
"box. Default 1.0")
.SetDefault(1.);
AddComment(R"DOC(
This operator generates YOLO detection boxes from output of YOLOv3 network.
......
......@@ -26,7 +26,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
T* scores, const float conf_thresh,
const int* anchors, const int n, const int h,
const int w, const int an_num, const int class_num,
const int box_num, int input_size, bool clip_bbox) {
const int box_num, int input_size, bool clip_bbox,
const float scale, const float bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T box[4];
......@@ -51,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes,
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
GetYoloBox<T>(box, input, anchors, l, k, j, h, input_size, box_idx,
grid_num, img_height, img_width);
grid_num, img_height, img_width, scale, bias);
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
......@@ -77,6 +78,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -109,7 +112,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
KeYoloBoxFw<T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, imgsize_data, boxes_data, scores_data, conf_thresh,
anchors_data, n, h, w, an_num, class_num, box_num, input_size,
clip_bbox);
clip_bbox, scale, bias);
}
};
......
......@@ -29,9 +29,11 @@ template <typename T>
HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride,
int img_height, int img_width) {
box[0] = (i + sigmoid<T>(x[index])) * img_width / grid_size;
box[1] = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
int img_height, int img_width, float scale,
float bias) {
box[0] = (i + sigmoid<T>(x[index]) * scale + bias) * img_width / grid_size;
box[1] = (j + sigmoid<T>(x[index + stride]) * scale + bias) * img_height /
grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
......@@ -89,6 +91,8 @@ class YoloBoxKernel : public framework::OpKernel<T> {
float conf_thresh = ctx.Attr<float>("conf_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool clip_bbox = ctx.Attr<bool>("clip_bbox");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -131,7 +135,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
GetYoloBox<T>(box, input_data, anchors_data, l, k, j, h, input_size,
box_idx, stride, img_height, img_width);
box_idx, stride, img_height, img_width, scale, bias);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, box, box_idx, img_height, img_width,
clip_bbox);
......
......@@ -215,6 +215,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_label_smooth",
"Whether to use label smooth. Default True.")
.SetDefault(true);
AddAttr<float>("scale_x_y",
"Scale the center point of decoded bounding "
"box. Default 1.0")
.SetDefault(1.);
AddComment(R"DOC(
This operator generates yolov3 loss based on given predict result and ground
truth boxes.
......
......@@ -73,10 +73,11 @@ static inline T sigmoid(T x) {
template <typename T>
static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i,
int j, int an_idx, int grid_size,
int input_size, int index, int stride) {
int input_size, int index, int stride,
float scale, float bias) {
Box<T> b;
b.x = (i + sigmoid<T>(x[index])) / grid_size;
b.y = (j + sigmoid<T>(x[index + stride])) / grid_size;
b.x = (i + sigmoid<T>(x[index]) * scale + bias) / grid_size;
b.y = (j + sigmoid<T>(x[index + stride]) * scale + bias) / grid_size;
b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] / input_size;
b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] / input_size;
return b;
......@@ -267,6 +268,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
int downsample_ratio = ctx.Attr<int>("downsample_ratio");
bool use_label_smooth = ctx.Attr<bool>("use_label_smooth");
float scale = ctx.Attr<float>("scale_x_y");
float bias = -0.5 * (scale - 1.);
const int n = input->dims()[0];
const int h = input->dims()[2];
......@@ -325,8 +328,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
// then ignore_thresh, ignore the objectness loss.
int box_idx =
GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0);
Box<T> pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j],
h, input_size, box_idx, stride);
Box<T> pred =
GetYoloBox(input_data, anchors, l, k, anchor_mask[j], h,
input_size, box_idx, stride, scale, bias);
T best_iou = 0;
for (int t = 0; t < b; t++) {
if (!gt_valid_mask_data[i * b + t]) {
......
......@@ -919,7 +919,8 @@ def yolov3_loss(x,
downsample_ratio,
gt_score=None,
use_label_smooth=True,
name=None):
name=None,
scale_x_y=1.):
"""
${comment}
......@@ -945,6 +946,7 @@ def yolov3_loss(x,
gt_score (Variable): mixup score of ground truth boxes, should be in shape
of [N, B]. Default None.
use_label_smooth (bool): ${use_label_smooth_comment}
scale_x_y (float): ${scale_x_y_comment}
Returns:
Variable: A 1-D tensor with shape [N], the value of yolov3 loss
......@@ -1017,6 +1019,7 @@ def yolov3_loss(x,
"ignore_thresh": ignore_thresh,
"downsample_ratio": downsample_ratio,
"use_label_smooth": use_label_smooth,
"scale_x_y": scale_x_y,
}
helper.append_op(
......@@ -1039,7 +1042,8 @@ def yolo_box(x,
conf_thresh,
downsample_ratio,
clip_bbox=True,
name=None):
name=None,
scale_x_y=1.):
"""
${comment}
......@@ -1051,6 +1055,7 @@ def yolo_box(x,
conf_thresh (float): ${conf_thresh_comment}
downsample_ratio (int): ${downsample_ratio_comment}
clip_bbox (bool): ${clip_bbox_comment}
scale_x_y (float): ${scale_x_y_comment}
name (string): The default value is None. Normally there is no need
for user to set this property. For more information,
please refer to :ref:`api_guide_Name`
......@@ -1099,6 +1104,7 @@ def yolo_box(x,
"conf_thresh": conf_thresh,
"downsample_ratio": downsample_ratio,
"clip_bbox": clip_bbox,
"scale_x_y": scale_x_y,
}
helper.append_op(
......
......@@ -552,6 +552,36 @@ class TestYoloDetection(unittest.TestCase):
self.assertIsNotNone(boxes)
self.assertIsNotNone(scores)
def test_yolov3_loss_with_scale(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gt_box = layers.data(name='gt_box', shape=[10, 4], dtype='float32')
gt_label = layers.data(name='gt_label', shape=[10], dtype='int32')
gt_score = layers.data(name='gt_score', shape=[10], dtype='float32')
loss = layers.yolov3_loss(
x,
gt_box,
gt_label, [10, 13, 30, 13], [0, 1],
10,
0.7,
32,
gt_score=gt_score,
use_label_smooth=False,
scale_x_y=1.2)
self.assertIsNotNone(loss)
def test_yolo_box_with_scale(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
img_size = layers.data(name='img_size', shape=[2], dtype='int32')
boxes, scores = layers.yolo_box(
x, img_size, [10, 13, 30, 13], 10, 0.01, 32, scale_x_y=1.2)
self.assertIsNotNone(boxes)
self.assertIsNotNone(scores)
class TestBoxClip(unittest.TestCase):
def test_box_clip(self):
......
......@@ -33,6 +33,8 @@ def YoloBox(x, img_size, attrs):
conf_thresh = attrs['conf_thresh']
downsample = attrs['downsample']
clip_bbox = attrs['clip_bbox']
scale_x_y = attrs['scale_x_y']
bias_x_y = -0.5 * (scale_x_y - 1.)
input_size = downsample * h
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
......@@ -40,8 +42,10 @@ def YoloBox(x, img_size, attrs):
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
pred_box[:, :, :, :, 0] = (
grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w
pred_box[:, :, :, :, 1] = (
grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
anchors_s = np.array(
......@@ -90,6 +94,7 @@ class TestYoloBoxOp(OpTest):
"conf_thresh": self.conf_thresh,
"downsample": self.downsample,
"clip_bbox": self.clip_bbox,
"scale_x_y": self.scale_x_y,
}
self.inputs = {
......@@ -115,6 +120,7 @@ class TestYoloBoxOp(OpTest):
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
......@@ -128,6 +134,21 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
self.clip_bbox = False
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
class TestYoloBoxOpScaleXY(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = True
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.2
if __name__ == "__main__":
......
......@@ -77,6 +77,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
ignore_thresh = attrs['ignore_thresh']
downsample_ratio = attrs['downsample_ratio']
use_label_smooth = attrs['use_label_smooth']
scale_x_y = attrs['scale_x_y']
bias_x_y = -0.5 * (scale_x_y - 1.)
input_size = downsample_ratio * h
x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
loss = np.zeros((n)).astype('float64')
......@@ -88,8 +90,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs):
pred_box = x[:, :, :, :, :4].copy()
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w
pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h
pred_box[:, :, :, :, 0] = (
grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w
pred_box[:, :, :, :, 1] = (
grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h
mask_anchors = []
for m in anchor_mask:
......@@ -180,6 +184,7 @@ class TestYolov3LossOp(OpTest):
"ignore_thresh": self.ignore_thresh,
"downsample_ratio": self.downsample_ratio,
"use_label_smooth": self.use_label_smooth,
"scale_x_y": self.scale_x_y,
}
self.inputs = {
......@@ -222,6 +227,7 @@ class TestYolov3LossOp(OpTest):
self.gtbox_shape = (3, 5, 4)
self.gtscore = True
self.use_label_smooth = True
self.scale_x_y = 1.
class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp):
......@@ -238,6 +244,7 @@ class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp):
self.gtbox_shape = (3, 5, 4)
self.gtscore = True
self.use_label_smooth = False
self.scale_x_y = 1.
class TestYolov3LossNoGTScore(TestYolov3LossOp):
......@@ -254,6 +261,24 @@ class TestYolov3LossNoGTScore(TestYolov3LossOp):
self.gtbox_shape = (3, 5, 4)
self.gtscore = False
self.use_label_smooth = True
self.scale_x_y = 1.
class TestYolov3LossWithScaleXY(TestYolov3LossOp):
def initTestCase(self):
self.anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198,
373, 326
]
self.anchor_mask = [0, 1, 2]
self.class_num = 5
self.ignore_thresh = 0.7
self.downsample_ratio = 32
self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.gtscore = True
self.use_label_smooth = True
self.scale_x_y = 1.2
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册