From 452373decbf9f196d9c3f52fd21214e439a68ece Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Tue, 19 Feb 2019 10:31:28 +0800 Subject: [PATCH] resize box in input image scale. test=develop --- .../fluid/operators/detection/yolo_box_op.cc | 14 +++++++++ .../fluid/operators/detection/yolo_box_op.h | 23 +++++++++----- python/paddle/fluid/layers/detection.py | 27 +++++++++++----- python/paddle/fluid/tests/test_detection.py | 4 ++- .../fluid/tests/unittests/test_yolo_box_op.py | 31 +++++++++++-------- 5 files changed, 70 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index 4c2c5d1e6..f78a98067 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -23,12 +23,15 @@ class YoloBoxOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of YoloBoxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ImgSize"), + "Input(ImgSize) of YoloBoxOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Boxes"), "Output(Boxes) of YoloBoxOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Scores"), "Output(Scores) of YoloBoxOp should not be null."); auto dim_x = ctx->GetInputDim("X"); + auto dim_imgsize = ctx->GetInputDim("ImgSize"); auto anchors = ctx->Attrs().Get>("anchors"); int anchor_num = anchors.size() / 2; auto class_num = ctx->Attrs().Get("class_num"); @@ -39,6 +42,12 @@ class YoloBoxOp : public framework::OperatorWithKernel { dim_x[1], anchor_num * (5 + class_num), "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " "+ class_num))."); + PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2, + "Input(ImgSize) should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + dim_imgsize[0], dim_x[0], + "Input(ImgSize) dim[0] and Input(X) dim[0] should be same."); + PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2."); PADDLE_ENFORCE_GT(anchors.size(), 0, "Attr(anchors) length should be greater then 0."); PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, @@ -72,6 +81,11 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { "box locations, confidence score and classification one-hot" "keys of each anchor box. Generally, X should be the output" "of YOLOv3 network."); + AddInput("ImgSize", + "The image size tensor of YoloBox operator, " + "This is a 2-D tensor with shape of [N, 2]. This tensor holds" + "height and width of each input image using for resize output" + "box in input image scale."); AddOutput("Boxes", "The output tensor of detection boxes of YoloBox operator, " "This is a 3-D tensor with shape of [N, M, 4], N is the" diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h index 7a9ebf46d..0ea8c1786 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.h +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -32,12 +32,15 @@ static inline T sigmoid(T x) { template static inline Box GetYoloBox(const T* x, std::vector anchors, int i, int j, int an_idx, int grid_size, - int input_size, int index, int stride) { + int input_size, int index, int stride, + int img_height, int img_width) { Box b; - b.x = (i + sigmoid(x[index])) * input_size / grid_size; - b.y = (j + sigmoid(x[index + stride])) * input_size / grid_size; - b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx]; - b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1]; + b.x = (i + sigmoid(x[index])) * img_width / grid_size; + b.y = (j + sigmoid(x[index + stride])) * img_height / grid_size; + b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / + input_size; + b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height / + input_size; return b; } @@ -69,6 +72,7 @@ class YoloBoxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("X"); + auto* imgsize = ctx.Input("ImgSize"); auto* boxes = ctx.Output("Boxes"); auto* scores = ctx.Output("Scores"); auto anchors = ctx.Attr>("anchors"); @@ -87,6 +91,7 @@ class YoloBoxKernel : public framework::OpKernel { const int an_stride = (class_num + 5) * stride; const T* input_data = input->data(); + const int* imgsize_data = imgsize->data(); T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); memset(boxes_data, 0, boxes->numel() * sizeof(T)); T* scores_data = @@ -94,6 +99,9 @@ class YoloBoxKernel : public framework::OpKernel { memset(scores_data, 0, scores->numel() * sizeof(T)); for (int i = 0; i < n; i++) { + int img_height = imgsize_data[2 * i]; + int img_width = imgsize_data[2 * i + 1]; + for (int j = 0; j < an_num; j++) { for (int k = 0; k < h; k++) { for (int l = 0; l < w; l++) { @@ -106,8 +114,9 @@ class YoloBoxKernel : public framework::OpKernel { int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); - Box pred = GetYoloBox(input_data, anchors, l, k, j, h, - input_size, box_idx, stride); + Box pred = + GetYoloBox(input_data, anchors, l, k, j, h, input_size, box_idx, + stride, img_height, img_width); box_idx = (i * box_num + j * stride + k * w + l) * 4; CalcDetectionBox(boxes_data, pred, box_idx); diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 29020f824..b64e19320 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -611,12 +611,19 @@ def yolov3_loss(x, @templatedoc(op_type="yolo_box") -def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None): +def yolo_box(x, + img_size, + anchors, + class_num, + conf_thresh, + downsample_ratio, + name=None): """ ${comment} Args: x (Variable): ${x_comment} + img_size (Variable): ${img_size_comment} anchors (list|tuple): ${anchors_comment} class_num (int): ${class_num_comment} conf_thresh (float): ${conf_thresh_comment} @@ -643,16 +650,17 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None): helper = LayerHelper('yolo_box', **locals()) if not isinstance(x, Variable): - raise TypeError("Input x of yolov3_loss must be Variable") + raise TypeError("Input x of yolo_box must be Variable") + if not isinstance(img_size, Variable): + raise TypeError("Input img_size of yolo_box must be Variable") if not isinstance(anchors, list) and not isinstance(anchors, tuple): - raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + raise TypeError("Attr anchors of yolo_box must be list or tuple") if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): - raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") + raise TypeError("Attr anchor_mask of yolo_box must be list or tuple") if not isinstance(class_num, int): - raise TypeError("Attr class_num of yolov3_loss must be an integer") + raise TypeError("Attr class_num of yolo_box must be an integer") if not isinstance(conf_thresh, float): - raise TypeError( - "Attr ignore_thresh of yolov3_loss must be a float number") + raise TypeError("Attr ignore_thresh of yolo_box must be a float number") boxes = helper.create_variable_for_type_inference(dtype=x.dtype) scores = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -666,7 +674,10 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None): helper.append_op( type='yolo_box', - inputs={"X": x, }, + inputs={ + "X": x, + "ImgSize": img_size, + }, outputs={ 'Boxes': boxes, 'Scores': scores, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 9592bbe2e..b8743debe 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -484,7 +484,9 @@ class TestYoloDetection(unittest.TestCase): program = Program() with program_guard(program): x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') - boxes, scores = layers.yolo_box(x, [10, 13, 30, 13], 10, 0.01, 32) + img_size = layers.data(name='x', shape=[2], dtype='int32') + boxes, scores = layers.yolo_box(x, img_size, [10, 13, 30, 13], 10, + 0.01, 32) self.assertIsNotNone(boxes) self.assertIsNotNone(scores) diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index bed0be9a5..48465c8f6 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -25,7 +25,7 @@ def sigmoid(x): return 1.0 / (1.0 + np.exp(-1.0 * x)) -def YoloBox(x, attrs): +def YoloBox(x, img_size, attrs): n, c, h, w = x.shape anchors = attrs['anchors'] an_num = int(len(anchors) // 2) @@ -56,15 +56,14 @@ def YoloBox(x, attrs): pred_box = pred_box * (pred_conf > 0.).astype('float32') pred_box = pred_box.reshape((n, -1, 4)) - pred_box[:, :, : - 2], pred_box[:, :, 2: - 4] = pred_box[:, :, : - 2] - pred_box[:, :, 2: - 4] / 2., pred_box[:, :, : - 2] + pred_box[:, :, - 2: - 4] / 2.0 - pred_box = pred_box * input_size + pred_box[:, :, :2], pred_box[:, :, 2:4] = \ + pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \ + pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0 + # pred_box = pred_box * input_size + pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis] + pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis] + pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis] + pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis] return pred_box, pred_score.reshape((n, -1, class_num)) @@ -74,6 +73,7 @@ class TestYoloBoxOp(OpTest): self.initTestCase() self.op_type = 'yolo_box' x = np.random.random(self.x_shape).astype('float32') + img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32') self.attrs = { "anchors": self.anchors, @@ -82,8 +82,11 @@ class TestYoloBoxOp(OpTest): "downsample": self.downsample, } - self.inputs = {'X': x, } - boxes, scores = YoloBox(x, self.attrs) + self.inputs = { + 'X': x, + 'ImgSize': img_size, + } + boxes, scores = YoloBox(x, img_size, self.attrs) self.outputs = { "Boxes": boxes, "Scores": scores, @@ -95,10 +98,12 @@ class TestYoloBoxOp(OpTest): def initTestCase(self): self.anchors = [10, 13, 16, 30, 33, 23] an_num = int(len(self.anchors) // 2) + self.batch_size = 3 self.class_num = 2 self.conf_thresh = 0.5 self.downsample = 32 - self.x_shape = (3, an_num * (5 + self.class_num), 5, 5) + self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 5, 5) + self.imgsize_shape = (self.batch_size, 2) if __name__ == "__main__": -- GitLab