提交 03320a05 编写于 作者: M Megvii Engine Team

fix(functional): change to user-friendly nms api and fix concat

GitOrigin-RevId: 76b419027029874b8bedcafc40636f19ce754c74
上级 50d5421a
......@@ -1454,7 +1454,7 @@ def indexing_one_hot(
return result
def nms(boxes: Tensor, iou_thresh: float, scores: Optional[Tensor] = None) -> Tensor:
def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
r"""
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).
......@@ -1477,7 +1477,7 @@ def nms(boxes: Tensor, iou_thresh: float, scores: Optional[Tensor] = None) -> Te
x[:,2:] = np.random.rand(100,2)*20 + 100
scores = tensor(np.random.rand(100))
inp = tensor(x)
result = F.nms(inp, iou_thresh=0.7, scores=scores)
result = F.nms(inp, scores, iou_thresh=0.7)
print(result.numpy())
Outputs:
......@@ -1490,24 +1490,27 @@ def nms(boxes: Tensor, iou_thresh: float, scores: Optional[Tensor] = None) -> Te
assert (
boxes.ndim == 2 and boxes.shape[1] == 4
), "the expected shape of boxes is (N, 4)"
assert scores.ndim == 1, "the expected shape of scores is (N,)"
assert (
boxes.shape[0] == scores.shape[0]
), "number of boxes and scores are not matched"
sorted_idx = None
if not scores is None:
assert scores.ndim == 1, "the expected shape of scores is (N,)"
sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx]
boxes = boxes.detach()
scores = scores.detach()
sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx]
max_output = boxes.shape[0]
op = builtin.NMSKeep(iou_thresh, max_output)
inp = utils.convert_inputs(boxes.reshape(1, -1, 4))
indices, count = apply(op, *inp)
indices = indices[0][: count.item()]
ret = sorted_idx[indices] if sorted_idx is not None else indices
return ret
keep_inds = sorted_idx[indices]
return keep_inds
def batched_nms(
boxes: Tensor, iou_thresh: float, idxs: Tensor, scores: Optional[Tensor] = None
boxes: Tensor, scores: Tensor, idxs: Tensor, iou_thresh: float,
) -> Tensor:
r"""
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).
......@@ -1533,7 +1536,7 @@ def batched_nms(
scores = tensor(np.random.rand(100))
idxs = tensor(np.random.randint(0, 10, 100))
inp = tensor(x)
result = F.batched_nms(inp, iou_thresh=0.6, idxs=idxs, scores=scores)
result = F.batched_nms(inp, scores, idxs, iou_thresh=0.6)
print(result.numpy())
Outputs:
......@@ -1546,20 +1549,24 @@ def batched_nms(
assert (
boxes.ndim == 2 and boxes.shape[1] == 4
), "the expected shape of boxes is (N, 4)"
assert scores.ndim == 1, "the expected shape of scores is (N,)"
assert idxs.ndim == 1, "the expected shape of idxs is (N,)"
assert boxes.shape[0] == scores.shape[0] == idxs.shape[0]
boxes = boxes.detach()
scores = scores.detach()
idxs = idxs.detach()
max_coordinate = boxes.max()
offsets = idxs.astype("float32") * (max_coordinate + 1)
boxes = boxes + offsets.reshape(-1, 1).broadcast(boxes.shape[0], 4)
sorted_idx = None
if not scores is None:
assert scores.ndim == 1, "the expected shape of scores is (N,)"
sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx]
sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx]
max_output = boxes.shape[0]
op = builtin.NMSKeep(iou_thresh, max_output)
inp = utils.convert_inputs(boxes.reshape(1, -1, 4))
indices, count = apply(op, *inp)
indices = indices[0][: count.item()]
ret = sorted_idx[indices] if sorted_idx is not None else indices
return ret
keep_inds = sorted_idx[indices]
return keep_inds
......@@ -231,6 +231,9 @@ def concat(
[ 9. 10. 11.]]
"""
if len(inps) == 1:
return inps[0]
dtype = dtype_promotion(inps)
device = get_device(inps)
......
......@@ -470,7 +470,7 @@ def test_nms():
)
inp = tensor(x)
scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
result = F.nms(inp, iou_thresh=0.5, scores=scores)
result = F.nms(inp, scores=scores, iou_thresh=0.5)
np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
......@@ -489,7 +489,7 @@ def test_batched_nms():
inp = tensor(x)
scores = tensor([0.6, 0.9, 0.5, 0.6, 0.8, 0.7], dtype=np.float32)
idxs = tensor([0, 1, 0, 1, 0, 1], dtype=np.int32)
results = F.batched_nms(inp, iou_thresh=0.5, idxs=idxs, scores=scores)
results = F.batched_nms(inp, scores=scores, idxs=idxs, iou_thresh=0.5)
np.testing.assert_equal(results.numpy(), np.array([1, 4, 5], dtype=np.int32))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册