diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 9c79629699981e086d822e96ae2178285ae64953..05165bf5ad910166c37672a0d72d802bcdfc7aa0 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -673,7 +673,7 @@ def topk( :param descending: if True, return the largest elements instead. Default: False :param kth_only: if True, only the k-th element will be returned. Default: False :param no_sort: if True, the returned elements can be unordered. Default: False - :return: tuple of two tensors `(topk_tensor, indices_of_int32)`. + :return: tuple of two tensors ``(topk_tensor, indices_of_int32)`` Examples: @@ -695,7 +695,7 @@ def topk( """ if descending: - inp = -inp + k = -k if kth_only: mode = "kth_only" @@ -709,21 +709,25 @@ def topk( (k,) = Const(k, dtype="int32", device=inp.device)() if len(inp.shape) == 1: - inp = inp.reshape(1, -1) - res = apply(op, inp, k) if kth_only: - tns = res[0] + (tns,) = apply(op, expand_dims(inp, 0), k) + # FIXME: + # could use a dedicated kernel + # gradient may be routed to other indices if k-th value is not unique + ind = argmax((tns == inp).astype("int8")) + tns = squeeze(tns, 0) else: - tns, ind = res[0][0], res[1][0] + tns, ind = apply(op, expand_dims(inp, 0), k) + tns = squeeze(tns, 0) + ind = squeeze(ind, 0) else: - res = apply(op, inp, k) if kth_only: - tns = res + (tns,) = apply(op, inp, k) + # FIXME: same as above + ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1) else: - tns, ind = res[0], res[1] + tns, ind = apply(op, inp, k) - if descending: - tns = -tns return tns, ind diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index a14e8b54de7ae19d558d31a49c9638123110e4bb..d013dfbd7d8014d72dc6fc73aba8e65f371ae6e2 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -168,3 +168,39 @@ def test_has_inf(): data[0][0][0][0] = float("inf") rst = F.math._has_inf(tensor(data)) np.testing.assert_equal(rst.numpy(), [1]) + + +@pytest.mark.parametrize("descending", [True, False]) +@pytest.mark.parametrize("sorted", [True, False]) +@pytest.mark.parametrize("inp1d", [True, False]) +@pytest.mark.parametrize("kth_only", [True, False]) +def test_topk(descending, sorted, inp1d, kth_only): + k = 3 + if inp1d: + data = np.random.permutation(7) + else: + data = np.random.permutation(5 * 7).reshape(5, 7) + data = data.astype(np.int32) + + def np_sort(x): + if descending: + return np.sort(x)[..., ::-1] + return np.sort(x) + + res = F.topk( + tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only + ) + + values, indices = res + values = values.numpy() + indices = indices.numpy() + if kth_only: + np.testing.assert_equal( + values, np.take_along_axis(data, indices[..., None], -1).squeeze(-1) + ) + np.testing.assert_equal(values, np_sort(data)[..., k - 1]) + else: + np.testing.assert_equal(values, np.take_along_axis(data, indices, -1)) + if not sorted: + values = np_sort(values) + np.testing.assert_equal(values, np_sort(data)[..., :k])