From 99a11e388d69ff641a36c7cddddfcd49d3e73147 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Mon, 24 May 2021 18:56:56 +0800 Subject: [PATCH] enhance unittest for yolo_box (#33070) --- .../fluid/tests/unittests/test_yolo_box_op.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py index 844115d4ace..24c463ebfc9 100644 --- a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -36,7 +36,8 @@ def YoloBox(x, img_size, attrs): clip_bbox = attrs['clip_bbox'] scale_x_y = attrs['scale_x_y'] bias_x_y = -0.5 * (scale_x_y - 1.) - input_size = downsample * h + input_h = downsample * h + input_w = downsample * w x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) @@ -50,7 +51,7 @@ def YoloBox(x, img_size, attrs): anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] anchors_s = np.array( - [(an_w / input_size, an_h / input_size) for an_w, an_h in anchors]) + [(an_w / input_w, an_h / input_h) for an_w, an_h in anchors]) anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w @@ -191,5 +192,19 @@ class TestYoloBoxStatic(unittest.TestCase): assert boxes is not None and scores is not None +class TestYoloBoxOpHW(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int(len(self.anchors) // 2) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = False + self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1. + + if __name__ == "__main__": unittest.main() -- GitLab