diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 14e4fe83ae4c83334728b7c42ceb7a733c68f817..527f77213abd337304933045d29d53c7adc4b2c1 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1454,7 +1454,7 @@ def indexing_one_hot( return result -def nms(boxes: Tensor, iou_thresh: float, scores: Optional[Tensor] = None) -> Tensor: +def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: r""" Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). @@ -1477,7 +1477,7 @@ def nms(boxes: Tensor, iou_thresh: float, scores: Optional[Tensor] = None) -> Te x[:,2:] = np.random.rand(100,2)*20 + 100 scores = tensor(np.random.rand(100)) inp = tensor(x) - result = F.nms(inp, iou_thresh=0.7, scores=scores) + result = F.nms(inp, scores, iou_thresh=0.7) print(result.numpy()) Outputs: @@ -1490,24 +1490,27 @@ def nms(boxes: Tensor, iou_thresh: float, scores: Optional[Tensor] = None) -> Te assert ( boxes.ndim == 2 and boxes.shape[1] == 4 ), "the expected shape of boxes is (N, 4)" + assert scores.ndim == 1, "the expected shape of scores is (N,)" + assert ( + boxes.shape[0] == scores.shape[0] + ), "number of boxes and scores are not matched" - sorted_idx = None - if not scores is None: - assert scores.ndim == 1, "the expected shape of scores is (N,)" - sorted_idx = argsort(scores, descending=True) - boxes = boxes[sorted_idx] + boxes = boxes.detach() + scores = scores.detach() + sorted_idx = argsort(scores, descending=True) + boxes = boxes[sorted_idx] max_output = boxes.shape[0] op = builtin.NMSKeep(iou_thresh, max_output) inp = utils.convert_inputs(boxes.reshape(1, -1, 4)) indices, count = apply(op, *inp) indices = indices[0][: count.item()] - ret = sorted_idx[indices] if sorted_idx is not None else indices - return ret + keep_inds = sorted_idx[indices] + return keep_inds def batched_nms( - boxes: Tensor, iou_thresh: float, idxs: Tensor, scores: Optional[Tensor] = None + boxes: Tensor, scores: Tensor, idxs: Tensor, iou_thresh: float, ) -> Tensor: r""" Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). @@ -1533,7 +1536,7 @@ def batched_nms( scores = tensor(np.random.rand(100)) idxs = tensor(np.random.randint(0, 10, 100)) inp = tensor(x) - result = F.batched_nms(inp, iou_thresh=0.6, idxs=idxs, scores=scores) + result = F.batched_nms(inp, scores, idxs, iou_thresh=0.6) print(result.numpy()) Outputs: @@ -1546,20 +1549,24 @@ def batched_nms( assert ( boxes.ndim == 2 and boxes.shape[1] == 4 ), "the expected shape of boxes is (N, 4)" + assert scores.ndim == 1, "the expected shape of scores is (N,)" + assert idxs.ndim == 1, "the expected shape of idxs is (N,)" + assert boxes.shape[0] == scores.shape[0] == idxs.shape[0] + + boxes = boxes.detach() + scores = scores.detach() + idxs = idxs.detach() max_coordinate = boxes.max() offsets = idxs.astype("float32") * (max_coordinate + 1) boxes = boxes + offsets.reshape(-1, 1).broadcast(boxes.shape[0], 4) - sorted_idx = None - if not scores is None: - assert scores.ndim == 1, "the expected shape of scores is (N,)" - sorted_idx = argsort(scores, descending=True) - boxes = boxes[sorted_idx] + sorted_idx = argsort(scores, descending=True) + boxes = boxes[sorted_idx] max_output = boxes.shape[0] op = builtin.NMSKeep(iou_thresh, max_output) inp = utils.convert_inputs(boxes.reshape(1, -1, 4)) indices, count = apply(op, *inp) indices = indices[0][: count.item()] - ret = sorted_idx[indices] if sorted_idx is not None else indices - return ret + keep_inds = sorted_idx[indices] + return keep_inds diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index ef1e3a76e89512631a02543ef9ef120ccbf2bafa..e63ed3462d18c98267b4f8d4817d563f66c889d9 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -231,6 +231,9 @@ def concat( [ 9. 10. 11.]] """ + if len(inps) == 1: + return inps[0] + dtype = dtype_promotion(inps) device = get_device(inps) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 58a582d0542742826ca32df7067ea2bf3cffc83b..3401ad102bd2c15d8fb6b325b0e9afeb8745a7a2 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -470,7 +470,7 @@ def test_nms(): ) inp = tensor(x) scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32) - result = F.nms(inp, iou_thresh=0.5, scores=scores) + result = F.nms(inp, scores=scores, iou_thresh=0.5) np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32)) @@ -489,7 +489,7 @@ def test_batched_nms(): inp = tensor(x) scores = tensor([0.6, 0.9, 0.5, 0.6, 0.8, 0.7], dtype=np.float32) idxs = tensor([0, 1, 0, 1, 0, 1], dtype=np.int32) - results = F.batched_nms(inp, iou_thresh=0.5, idxs=idxs, scores=scores) + results = F.batched_nms(inp, scores=scores, idxs=idxs, iou_thresh=0.5) np.testing.assert_equal(results.numpy(), np.array([1, 4, 5], dtype=np.int32))