diff --git a/paddle/fluid/operators/detection/iou_similarity_op.h b/paddle/fluid/operators/detection/iou_similarity_op.h index d8fb64329802b774f2e6c218d0cfca35079cbfe0..4dcd73e8aec0ba517b330ddf294b2c8df9a28ce4 100644 --- a/paddle/fluid/operators/detection/iou_similarity_op.h +++ b/paddle/fluid/operators/detection/iou_similarity_op.h @@ -18,7 +18,8 @@ limitations under the License. */ template inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, - T ymin2, T xmax2, T ymax2, bool normalized) { + T ymin2, T xmax2, T ymax2, bool normalized, + T eps) { constexpr T zero = static_cast(0); T area1; T area2; @@ -43,19 +44,21 @@ inline HOSTDEVICE T IOUSimilarity(T xmin1, T ymin1, T xmax1, T ymax1, T xmin2, inter_height = inter_height > zero ? inter_height : zero; inter_width = inter_width > zero ? inter_width : zero; T inter_area = inter_width * inter_height; - T union_area = area1 + area2 - inter_area; + T union_area = area1 + area2 - inter_area + eps; T sim_score = inter_area / union_area; return sim_score; } template struct IOUSimilarityFunctor { - IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols, bool normalized) + IOUSimilarityFunctor(const T* x, const T* y, T* z, int cols, bool normalized, + T eps) : x_(x), y_(y), z_(z), cols_(static_cast(cols)), - normalized_(normalized) {} + normalized_(normalized), + eps_(eps) {} inline HOSTDEVICE void operator()(size_t tid) const { size_t row_id = tid / cols_; @@ -72,7 +75,7 @@ struct IOUSimilarityFunctor { T y_max2 = y_[col_id * 4 + 3]; T sim = IOUSimilarity(x_min1, y_min1, x_max1, y_max1, x_min2, y_min2, - x_max2, y_max2, normalized_); + x_max2, y_max2, normalized_, eps_); z_[row_id * cols_ + col_id] = sim; } @@ -81,6 +84,7 @@ struct IOUSimilarityFunctor { T* z_; const size_t cols_; bool normalized_; + T eps_; }; namespace paddle { @@ -97,9 +101,10 @@ class IOUSimilarityKernel : public framework::OpKernel { int x_n = in_x->dims()[0]; int y_n = in_y->dims()[0]; + T eps = static_cast(1e-10); IOUSimilarityFunctor functor(in_x->data(), in_y->data(), out->mutable_data(ctx.GetPlace()), y_n, - normalized); + normalized, eps); platform::ForRange for_range( static_cast(ctx.device_context()), x_n * y_n); diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index f6c32a1332cc8e1a63673b0b025437be9074f6f2..6f2a3ca87623847f261f0111bdfd8c168bb24b0a 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -136,6 +136,10 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { "Whether clip output bonding box in Input(ImgSize) " "boundary. Default true.") .SetDefault(true); + AddAttr("scale_x_y", + "Scale the center point of decoded bounding " + "box. Default 1.0") + .SetDefault(1.); AddComment(R"DOC( This operator generates YOLO detection boxes from output of YOLOv3 network. diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu index b8476a7cf30d0ff15af83bcb5422e71f567bd583..6462b9f762a9ca31cd15aa5f1e3cc0bfbfa49d63 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cu +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -26,7 +26,8 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, T* scores, const float conf_thresh, const int* anchors, const int n, const int h, const int w, const int an_num, const int class_num, - const int box_num, int input_size, bool clip_bbox) { + const int box_num, int input_size, bool clip_bbox, + const float scale, const float bias) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; T box[4]; @@ -51,7 +52,7 @@ __global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); GetYoloBox(box, input, anchors, l, k, j, h, input_size, box_idx, - grid_num, img_height, img_width); + grid_num, img_height, img_width, scale, bias); box_idx = (i * box_num + j * grid_num + k * w + l) * 4; CalcDetectionBox(boxes, box, box_idx, img_height, img_width, clip_bbox); @@ -77,6 +78,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { float conf_thresh = ctx.Attr("conf_thresh"); int downsample_ratio = ctx.Attr("downsample_ratio"); bool clip_bbox = ctx.Attr("clip_bbox"); + float scale = ctx.Attr("scale_x_y"); + float bias = -0.5 * (scale - 1.); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -109,7 +112,7 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel { KeYoloBoxFw<<>>( input_data, imgsize_data, boxes_data, scores_data, conf_thresh, anchors_data, n, h, w, an_num, class_num, box_num, input_size, - clip_bbox); + clip_bbox, scale, bias); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index b9c378e01f44ed36b5b09477dc1c3dc27a10e3c5..388467d37ba644abe651e1080ce14dd0d0e704bc 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -29,9 +29,11 @@ template HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, int j, int an_idx, int grid_size, int input_size, int index, int stride, - int img_height, int img_width) { - box[0] = (i + sigmoid(x[index])) * img_width / grid_size; - box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; + int img_height, int img_width, float scale, + float bias) { + box[0] = (i + sigmoid(x[index]) * scale + bias) * img_width / grid_size; + box[1] = (j + sigmoid(x[index + stride]) * scale + bias) * img_height / + grid_size; box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / input_size; box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * @@ -89,6 +91,8 @@ class YoloBoxKernel : public framework::OpKernel { float conf_thresh = ctx.Attr("conf_thresh"); int downsample_ratio = ctx.Attr("downsample_ratio"); bool clip_bbox = ctx.Attr("clip_bbox"); + float scale = ctx.Attr("scale_x_y"); + float bias = -0.5 * (scale - 1.); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -131,7 +135,7 @@ class YoloBoxKernel : public framework::OpKernel { int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); GetYoloBox(box, input_data, anchors_data, l, k, j, h, input_size, - box_idx, stride, img_height, img_width); + box_idx, stride, img_height, img_width, scale, bias); box_idx = (i * box_num + j * stride + k * w + l) * 4; CalcDetectionBox(boxes_data, box, box_idx, img_height, img_width, clip_bbox); diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index ad13f334dae3659b183cdbba25d59cacef39c200..74f300357748a70fbf8edd7c6b6153bd09994fd0 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -215,6 +215,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("use_label_smooth", "Whether to use label smooth. Default True.") .SetDefault(true); + AddAttr("scale_x_y", + "Scale the center point of decoded bounding " + "box. Default 1.0") + .SetDefault(1.); AddComment(R"DOC( This operator generates yolov3 loss based on given predict result and ground truth boxes. diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.h b/paddle/fluid/operators/detection/yolov3_loss_op.h index b29ad1c920d8fcebb6e9c7dda8a99ebda3a6423e..1acfb2cf4e50fb8ad461d133a0546974f573e873 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.h +++ b/paddle/fluid/operators/detection/yolov3_loss_op.h @@ -73,10 +73,11 @@ static inline T sigmoid(T x) { template static inline Box GetYoloBox(const T* x, std::vector anchors, int i, int j, int an_idx, int grid_size, - int input_size, int index, int stride) { + int input_size, int index, int stride, + float scale, float bias) { Box b; - b.x = (i + sigmoid(x[index])) / grid_size; - b.y = (j + sigmoid(x[index + stride])) / grid_size; + b.x = (i + sigmoid(x[index]) * scale + bias) / grid_size; + b.y = (j + sigmoid(x[index + stride]) * scale + bias) / grid_size; b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] / input_size; b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] / input_size; return b; @@ -267,6 +268,8 @@ class Yolov3LossKernel : public framework::OpKernel { float ignore_thresh = ctx.Attr("ignore_thresh"); int downsample_ratio = ctx.Attr("downsample_ratio"); bool use_label_smooth = ctx.Attr("use_label_smooth"); + float scale = ctx.Attr("scale_x_y"); + float bias = -0.5 * (scale - 1.); const int n = input->dims()[0]; const int h = input->dims()[2]; @@ -325,8 +328,9 @@ class Yolov3LossKernel : public framework::OpKernel { // then ignore_thresh, ignore the objectness loss. int box_idx = GetEntryIndex(i, j, k * w + l, mask_num, an_stride, stride, 0); - Box pred = GetYoloBox(input_data, anchors, l, k, anchor_mask[j], - h, input_size, box_idx, stride); + Box pred = + GetYoloBox(input_data, anchors, l, k, anchor_mask[j], h, + input_size, box_idx, stride, scale, bias); T best_iou = 0; for (int t = 0; t < b; t++) { if (!gt_valid_mask_data[i * b + t]) { diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index ae92ede4eab47eebc4a8e5827014f04e7b6f2e62..8a2ecc36caa6bd0f9b3abb4969477e42c0b89d2d 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -932,7 +932,8 @@ def yolov3_loss(x, downsample_ratio, gt_score=None, use_label_smooth=True, - name=None): + name=None, + scale_x_y=1.): """ ${comment} @@ -958,6 +959,7 @@ def yolov3_loss(x, gt_score (Variable): mixup score of ground truth boxes, should be in shape of [N, B]. Default None. use_label_smooth (bool): ${use_label_smooth_comment} + scale_x_y (float): ${scale_x_y_comment} Returns: Variable: A 1-D tensor with shape [N], the value of yolov3 loss @@ -1030,6 +1032,7 @@ def yolov3_loss(x, "ignore_thresh": ignore_thresh, "downsample_ratio": downsample_ratio, "use_label_smooth": use_label_smooth, + "scale_x_y": scale_x_y, } helper.append_op( @@ -1052,7 +1055,8 @@ def yolo_box(x, conf_thresh, downsample_ratio, clip_bbox=True, - name=None): + name=None, + scale_x_y=1.): """ ${comment} @@ -1064,6 +1068,7 @@ def yolo_box(x, conf_thresh (float): ${conf_thresh_comment} downsample_ratio (int): ${downsample_ratio_comment} clip_bbox (bool): ${clip_bbox_comment} + scale_x_y (float): ${scale_x_y_comment} name (string): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` @@ -1112,6 +1117,7 @@ def yolo_box(x, "conf_thresh": conf_thresh, "downsample_ratio": downsample_ratio, "clip_bbox": clip_bbox, + "scale_x_y": scale_x_y, } helper.append_op( diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index cb798c8ed595d13dd8ff5e33323d6e796aaac6f9..0380a39dab9f259e102bae5f8d4fe4ee1375d930 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -527,6 +527,36 @@ class TestYoloDetection(unittest.TestCase): self.assertIsNotNone(boxes) self.assertIsNotNone(scores) + def test_yolov3_loss_with_scale(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + gt_box = layers.data(name='gt_box', shape=[10, 4], dtype='float32') + gt_label = layers.data(name='gt_label', shape=[10], dtype='int32') + gt_score = layers.data(name='gt_score', shape=[10], dtype='float32') + loss = layers.yolov3_loss( + x, + gt_box, + gt_label, [10, 13, 30, 13], [0, 1], + 10, + 0.7, + 32, + gt_score=gt_score, + use_label_smooth=False, + scale_x_y=1.2) + + self.assertIsNotNone(loss) + + def test_yolo_box_with_scale(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + img_size = layers.data(name='img_size', shape=[2], dtype='int32') + boxes, scores = layers.yolo_box( + x, img_size, [10, 13, 30, 13], 10, 0.01, 32, scale_x_y=1.2) + self.assertIsNotNone(boxes) + self.assertIsNotNone(scores) + class TestBoxClip(unittest.TestCase): def test_box_clip(self): diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index 82b84a665bf2e3093ccb10b27dd28666aa4ad19e..ef53d8cec34a2ed1ce3db013094452b2ab9e7108 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -33,6 +33,8 @@ def YoloBox(x, img_size, attrs): conf_thresh = attrs['conf_thresh'] downsample = attrs['downsample'] clip_bbox = attrs['clip_bbox'] + scale_x_y = attrs['scale_x_y'] + bias_x_y = -0.5 * (scale_x_y - 1.) input_size = downsample * h x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) @@ -40,8 +42,10 @@ def YoloBox(x, img_size, attrs): pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) - pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w - pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + pred_box[:, :, :, :, 0] = ( + grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w + pred_box[:, :, :, :, 1] = ( + grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] anchors_s = np.array( @@ -90,6 +94,7 @@ class TestYoloBoxOp(OpTest): "conf_thresh": self.conf_thresh, "downsample": self.downsample, "clip_bbox": self.clip_bbox, + "scale_x_y": self.scale_x_y, } self.inputs = { @@ -115,6 +120,7 @@ class TestYoloBoxOp(OpTest): self.clip_bbox = True self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1. class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): @@ -128,6 +134,21 @@ class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): self.clip_bbox = False self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1. + + +class TestYoloBoxOpScaleXY(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int(len(self.anchors) // 2) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = True + self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1.2 if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index de68d9ac3ae77902734bb88b75a4d1a632b48cc3..db73160c489b0584ab33b11061f0cc3f81f7da38 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -77,6 +77,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): ignore_thresh = attrs['ignore_thresh'] downsample_ratio = attrs['downsample_ratio'] use_label_smooth = attrs['use_label_smooth'] + scale_x_y = attrs['scale_x_y'] + bias_x_y = -0.5 * (scale_x_y - 1.) input_size = downsample_ratio * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float64') @@ -88,8 +90,10 @@ def YOLOv3Loss(x, gtbox, gtlabel, gtscore, attrs): pred_box = x[:, :, :, :, :4].copy() grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) - pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w - pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + pred_box[:, :, :, :, 0] = ( + grid_x + sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y + bias_x_y) / w + pred_box[:, :, :, :, 1] = ( + grid_y + sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h mask_anchors = [] for m in anchor_mask: @@ -180,6 +184,7 @@ class TestYolov3LossOp(OpTest): "ignore_thresh": self.ignore_thresh, "downsample_ratio": self.downsample_ratio, "use_label_smooth": self.use_label_smooth, + "scale_x_y": self.scale_x_y, } self.inputs = { @@ -222,6 +227,7 @@ class TestYolov3LossOp(OpTest): self.gtbox_shape = (3, 5, 4) self.gtscore = True self.use_label_smooth = True + self.scale_x_y = 1. class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): @@ -238,6 +244,7 @@ class TestYolov3LossWithoutLabelSmooth(TestYolov3LossOp): self.gtbox_shape = (3, 5, 4) self.gtscore = True self.use_label_smooth = False + self.scale_x_y = 1. class TestYolov3LossNoGTScore(TestYolov3LossOp): @@ -254,6 +261,24 @@ class TestYolov3LossNoGTScore(TestYolov3LossOp): self.gtbox_shape = (3, 5, 4) self.gtscore = False self.use_label_smooth = True + self.scale_x_y = 1. + + +class TestYolov3LossWithScaleXY(TestYolov3LossOp): + def initTestCase(self): + self.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, + 373, 326 + ] + self.anchor_mask = [0, 1, 2] + self.class_num = 5 + self.ignore_thresh = 0.7 + self.downsample_ratio = 32 + self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) + self.gtbox_shape = (3, 5, 4) + self.gtscore = True + self.use_label_smooth = True + self.scale_x_y = 1.2 if __name__ == "__main__":