提交 9c561134 编写于 作者: M Megvii Engine Team

feat(mge): do not export F.nn.nms

GitOrigin-RevId: 7a0a328c37ff45d10e75a9880a35a03ddbd96a44
上级 56d624f5
...@@ -43,7 +43,6 @@ __all__ = [ ...@@ -43,7 +43,6 @@ __all__ = [
"logsoftmax", "logsoftmax",
"matmul", "matmul",
"max_pool2d", "max_pool2d",
"nms",
"one_hot", "one_hot",
"prelu", "prelu",
"roi_align", "roi_align",
...@@ -1482,7 +1481,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: ...@@ -1482,7 +1481,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
x[:,2:] = np.random.rand(100,2)*20 + 100 x[:,2:] = np.random.rand(100,2)*20 + 100
scores = tensor(np.random.rand(100)) scores = tensor(np.random.rand(100))
inp = tensor(x) inp = tensor(x)
result = F.nms(inp, scores, iou_thresh=0.7) result = F.nn.nms(inp, scores, iou_thresh=0.7)
print(result.numpy()) print(result.numpy())
Outputs: Outputs:
......
...@@ -357,7 +357,7 @@ def test_nms(): ...@@ -357,7 +357,7 @@ def test_nms():
) )
inp = tensor(x) inp = tensor(x)
scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32) scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
result = F.nms(inp, scores=scores, iou_thresh=0.5) result = F.nn.nms(inp, scores=scores, iou_thresh=0.5)
np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32)) np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册