未验证 提交 99a11e38 编写于 作者: W wangguanzhong 提交者: GitHub

enhance unittest for yolo_box (#33070)

上级 b8e4ec7d
...@@ -36,7 +36,8 @@ def YoloBox(x, img_size, attrs): ...@@ -36,7 +36,8 @@ def YoloBox(x, img_size, attrs):
clip_bbox = attrs['clip_bbox'] clip_bbox = attrs['clip_bbox']
scale_x_y = attrs['scale_x_y'] scale_x_y = attrs['scale_x_y']
bias_x_y = -0.5 * (scale_x_y - 1.) bias_x_y = -0.5 * (scale_x_y - 1.)
input_size = downsample * h input_h = downsample * h
input_w = downsample * w
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
...@@ -50,7 +51,7 @@ def YoloBox(x, img_size, attrs): ...@@ -50,7 +51,7 @@ def YoloBox(x, img_size, attrs):
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
anchors_s = np.array( anchors_s = np.array(
[(an_w / input_size, an_h / input_size) for an_w, an_h in anchors]) [(an_w / input_w, an_h / input_h) for an_w, an_h in anchors])
anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w
...@@ -191,5 +192,19 @@ class TestYoloBoxStatic(unittest.TestCase): ...@@ -191,5 +192,19 @@ class TestYoloBoxStatic(unittest.TestCase):
assert boxes is not None and scores is not None assert boxes is not None and scores is not None
class TestYoloBoxOpHW(TestYoloBoxOp):
def initTestCase(self):
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 32
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
self.clip_bbox = False
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 9)
self.imgsize_shape = (self.batch_size, 2)
self.scale_x_y = 1.
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册