From 2f06d580b9d69415df34d114bbf67a61e006192c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 Jul 2023 20:35:44 +0800 Subject: [PATCH] feat(xla): add topk and sort for xla GitOrigin-RevId: 0e881f30429a8d849ad9cdd0e0f47c3e0921ff97 --- imperative/python/megengine/xla/ir_utils.py | 5 + .../python/megengine/xla/rules/indexing.py | 4 +- imperative/python/megengine/xla/rules/math.py | 195 +++++++++++++++++- .../python/megengine/xla/rules/tensor.py | 49 ++++- .../test/unit/xla/functional/test_xla_math.py | 107 +++++++++- .../unit/xla/functional/test_xla_tensor.py | 7 + 6 files changed, 352 insertions(+), 15 deletions(-) diff --git a/imperative/python/megengine/xla/ir_utils.py b/imperative/python/megengine/xla/ir_utils.py index a8e80b287..f45cd44fb 100644 --- a/imperative/python/megengine/xla/ir_utils.py +++ b/imperative/python/megengine/xla/ir_utils.py @@ -192,6 +192,7 @@ class TraceResult: dtype_to_str = { "float16": "f16", "float32": "f32", + "int8": "i8", "int32": "i32", "int64": "i64", "uint8": "u8", @@ -417,6 +418,10 @@ def f32_attr(i): return ir.FloatAttr.get(ir.F32Type.get(), i) +def bool_attr(i): + return ir.BoolAttr.get(i) + + def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr: lhs_prec = str(lhs_prec) rhs_prec = str(rhs_prec) diff --git a/imperative/python/megengine/xla/rules/indexing.py b/imperative/python/megengine/xla/rules/indexing.py index 5619c5dc8..1ceba258e 100644 --- a/imperative/python/megengine/xla/rules/indexing.py +++ b/imperative/python/megengine/xla/rules/indexing.py @@ -66,7 +66,7 @@ def _hslice_with_step_is_one(inp, slices): def _hslice_with_any_step(inp, slices): """ - if inp_shape is N-dim, slices should contain N slice, slice can not None + if inp_shape is N-dim, slices should contain N slice, slice can not None. 
for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] """ starts = [int(sl.start) for sl in slices] @@ -83,7 +83,7 @@ def _hslice_with_any_step(inp, slices): def index_with_slices(inp, slices): """ - if inp_shape is N-dim, slices should contain N slice, slice can be None + if inp_shape is N-dim, slices should contain N slice, slice can be None. for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] or [None, None] """ assert isinstance(slices, Sequence), f"{slices}" diff --git a/imperative/python/megengine/xla/rules/math.py b/imperative/python/megengine/xla/rules/math.py index 7e8dd2044..af127308c 100644 --- a/imperative/python/megengine/xla/rules/math.py +++ b/imperative/python/megengine/xla/rules/math.py @@ -4,9 +4,13 @@ import numpy as np from ...core._imperative_rt import ops as mops from .. import ir_utils -from ..ir_utils import i64_attr +from ..ir_utils import bool_attr, i64_attr +from ..lib.mlir import ir from ..lib.mlir.dialects import chlo, hlo +from ..utils import flatten_list from .hlotensor import HLOTensor +from .indexing import ScatterDimensionNumbers, scatter +from .tensor import concat, expand_dims, fill, iota from .utils import _can_broadcast_to, _shape_equal, register_lower_rule @@ -241,5 +245,192 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): ).transpose(permutation) +def _sort_according_to_key(key, *vals, axis=-1, descending=True, is_stable=True): + """ + sort key and vals in the specified axis, return the sorted key and vals. + key and vals should have the same shape, then we reorder both key and vals according + to the value of the key. 
+ + example 1: (implement argsort) + inp: 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2 + [[ 1.7783 -1.8184 1.0701] + [-0.0712 -1.4623 1.3243]] + [[0 1 2] + [0 1 2]] + axis: -1 + descend: True + return: after reorder, 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2 + [[ 1.7783 1.0701 -1.8184] + [ 1.3243 -0.0712 -1.4623]] + [[0 2 1] + [2 0 1]] + + example 2: + inp: + [[0 2 1] + [2 0 1]] + [[ 1.7783 1.0701 -1.8184] + [ 1.3243 -0.0712 -1.4623]] + axis: -1 + descend: False + return: + [[0 1 2] + [0 1 2]] + [[ 1.7783 -1.8184 1.0701] + [-0.0712 -1.4623 1.3243]] + """ + + for val in vals: + assert _shape_equal( + key.shape, val.shape + ), f"sort key and vals shape mismatch: {key.shape}, {val.shape}" + + axis = axis + key.ndim if axis < 0 else axis + sorted_key = ir_utils.make_ir_type_according_meta(key.shape, key.dtype) + sorted_vals = [ + ir_utils.make_ir_type_according_meta(val.shape, val.dtype) for val in vals + ] + + sort_op = hlo.SortOp( + [sorted_key, *sorted_vals], + [key.tensor, *[val.tensor for val in vals]], + dimension=i64_attr(axis), + is_stable=bool_attr(is_stable), + ) + + key_type = ir_utils.make_ir_type_according_meta(tuple(), key.dtype) + val_types = [ + ir_utils.make_ir_type_according_meta(tuple(), val.dtype) for val in vals + ] + arg_types = [key_type] + val_types + + comparator = sort_op.comparator.blocks.append( + *flatten_list(zip(arg_types, arg_types)) + ) + with ir.InsertionPoint(comparator): + lhs = HLOTensor(comparator.arguments[0]) + rhs = HLOTensor(comparator.arguments[1]) + + if descending: + hlo.ReturnOp([(lhs > rhs).tensor]) + else: + hlo.ReturnOp([(lhs < rhs).tensor]) + + assert len(sort_op.results) == len(vals) + 1, f"{len(vals)}, {len(sort_op.results)}" + return (HLOTensor(ret) for ret in sort_op.results) + + +def argsort(inp, axis=-1, descending=True, is_stable=True): + """ + sort inp in the specfic axis, and return the sorted value and index + for example: + inp: + [[ 1.7783 -1.8184 1.0701] + [-0.0712 -1.4623 1.3243]] + axis: -1 + descend: True + 
return: + [[ 1.7783 1.0701 -1.8184] + [ 1.3243 -0.0712 -1.4623]] + [[0 2 1] + [2 0 1]] + """ + axis = axis + inp.ndim if axis < 0 else axis + idx = iota(np.int32, inp.shape, axis) + return _sort_according_to_key( + inp, idx, axis=axis, descending=descending, is_stable=is_stable + ) + + +@register_lower_rule(mops.Argsort) +def argsort_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): + assert ( + len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 2 + ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}" + assert ctx.op.order in [ + mops.Argsort.Order.DESCENDING, + mops.Argsort.Order.ASCENDING, + ], f"{ctx.op.order}" + descending = ctx.op.order == mops.Argsort.Order.DESCENDING + axis = args[0].ndim - 1 # megengine only support sort in the last dimension + return argsort(args[0], axis, descending, is_stable=True) + + +@register_lower_rule("ArgsortBackward") +def argsort_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): + assert ( + len(args) == 3 and len(ctx.vars_in) == 3 and len(ctx.vars_out) == 1 + ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}" + dy, idx, x = args[0], args[1], args[2] + if _shape_equal(x.shape, dy.shape): + # for argsort backward + _, dx = _sort_according_to_key( + idx, dy, axis=-1, descending=False, is_stable=True + ) + else: + # for topk backward, only support axis=-1 and the dx is 2d tensor + dx = fill(0, ctx.vars_out[0].shape, ctx.vars_out[0].dtype) + expander = iota(np.int32, idx.shape, dimension=0) + idx = expand_dims(idx, -1) + expander = expand_dims(expander, -1) + idx = concat([expander, idx], -1) + + dnums = ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1), + ) + dx = scatter(dx, idx, dy, dnums, unique_indices=True) + return dx + + def topk(inp, k, descending=True, kth_only=False, no_sort=False): - return [HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results] + """ + do topk in the last dimension 
of inp, for example: + inp.shape = (2, 3, 4), k = 2, out_shape = (2, 3, 2) + """ + assert k > 0, f"k of topk must bigger than 0, get {k}" + assert no_sort == False, f"no_sort must be False now" + assert kth_only == False, f"kth_only is not support now" + + if descending == True: + out, idx = [ + HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results + ] + else: + inp = -inp + out, idx = [ + HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results + ] + out = -out + + return out, idx + + +@register_lower_rule(mops.TopK) +def topk_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): + assert ( + len(args) == 2 and len(ctx.vars_in) == 2 + ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}" + assert isinstance( + ctx.vars_in[1].bound_data, np.ndarray + ), f"{ctx.vars_in[1].bound_data}" + k = int(ctx.vars_in[1].bound_data) + + descending = True if k < 0 else False + k = -k if k < 0 else k + + if ctx.op.mode == mops.TopK.Mode.VALUE_IDX_SORTED: + assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}" + kth_only, no_sort = False, False + elif ctx.op.mode == mops.TopK.Mode.VALUE_IDX_NOSORT: + assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}" + kth_only, no_sort = False, True + else: + assert ( + ctx.op.mode == mops.TopK.Mode.KTH_ONLY + ), f"invalid mode for topk, {ctx.op.mode}" + kth_only, no_sort = True, False + assert len(ctx.vars_out) == 1, f"{len(ctx.vars_out)}" + + return topk(args[0], k, descending, kth_only, no_sort) diff --git a/imperative/python/megengine/xla/rules/tensor.py b/imperative/python/megengine/xla/rules/tensor.py index 64e21510c..555f32932 100644 --- a/imperative/python/megengine/xla/rules/tensor.py +++ b/imperative/python/megengine/xla/rules/tensor.py @@ -79,14 +79,13 @@ def transpose(inp, permutation): def expand_dims(inp, axis): assert isinstance(axis, int), f"only int axis supported, get {axis}" - axis = (axis + inp.ndim) if axis < 0 else axis - assert axis >= 0 and axis <= inp.ndim, f"invalid axis {axis} 
for {inp.shape}" + assert ( + axis >= -inp.ndim - 1 and axis <= inp.ndim + ), f"invalid axis {axis} for {inp.shape}" - dst_shape = [] - for i in range(inp.ndim): - if i == axis: - dst_shape.append(1) - dst_shape.append(inp.shape[i]) + dst_shape = list(inp.shape) + insert_pos = axis if axis >= 0 else (axis + inp.ndim + 1) + dst_shape.insert(insert_pos, 1) return inp.reshape(tuple(dst_shape)) @@ -94,14 +93,29 @@ def expand_dims(inp, axis): @register_lower_rule(mops.Dimshuffle) def dim_shuffle_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 - permutation = ctx.op.pattern - return transpose(args[0], permutation) + # mge dimshuffle can do transpose and broadcast simultaneously + # for example: + # case1: (16, 32, 64) with pattern [0, 2, 1] -> (16, 64, 32) + # case2: (16, 32, 64) with pattern [0, -1, 2, -1, 1] -> (16, 1, 64, 1, 32) + # case3: (16, 1, 64, 1, 32) with pattern [0, 4, 2] -> (16, 32, 64) + + pattern = ctx.op.pattern + inp = args[0] + if len(pattern) == inp.ndim: + permutation = pattern + return transpose(inp, permutation) + elif len(pattern) > inp.ndim: + permutation = [item for item in pattern if item != -1] + return transpose(inp, permutation).reshape(ctx.vars_out[0].shape) + else: + permutation = [i for i in range(inp.ndim) if i not in pattern] + list(pattern) + return transpose(inp, permutation).reshape(ctx.vars_out[0].shape) def concat(inps, axis): assert len(inps) > 0, f"concat inputs should not be empty" if axis < 0: - axis = axis + inps[0].ndim[0] + axis = axis + inps[0].ndim hlo_inps = [inp.tensor for inp in inps] @@ -175,6 +189,21 @@ def fill(value, shape, dtype): return broadcast_to(HLOTensor(value, dtype=dtype), shape) +def iota(dtype, shape, dimension): + """ + do some thing like arange. 
+ for example: + shape = (2, 3), dimension=1, output is [[0, 1, 2], [0, 1, 2]] + shape = (2, 3), dimension=0, output is [[0, 0, 0], [1, 1, 1]] + """ + dimension = dimension + len(shape) if dimension < 0 else dimension + ret = hlo.IotaOp( + ir_utils.make_ir_type_according_meta(shape, dtype), ir_utils.i64_attr(dimension) + ).results + assert len(ret) == 1, f"{len(ret)}" + return HLOTensor(ret[0]) + + @register_lower_rule(mops.Fill) def fill_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 diff --git a/imperative/python/test/unit/xla/functional/test_xla_math.py b/imperative/python/test/unit/xla/functional/test_xla_math.py index 09f763727..6afd5b6e9 100644 --- a/imperative/python/test/unit/xla/functional/test_xla_math.py +++ b/imperative/python/test/unit/xla/functional/test_xla_math.py @@ -31,7 +31,6 @@ def test_matmul(): return out, lhs.grad, rhs.grad mge_rsts = func(lhs, rhs, dout) - mge_rsts[0].numpy() xla_rsts = func(lhs, rhs, dout) for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): @@ -79,3 +78,109 @@ def test_matmul(): tester((1, 2, 8, 7), (4, 2, 2, 9, 8), True, True) tester((1, 8, 7), (4, 3, 2, 8, 9), True, False) tester((1, 8, 7), (4, 3, 1, 9, 8), True, True) + + +@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") +@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") +@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now") +def test_sort_and_argsort(): + def tester(ishape, descending, dtype=None): + dtype = dtype or np.float32 + inp1 = tensor(np.random.randn(*ishape), dtype=dtype) + inp2 = tensor(np.random.randn(*ishape), dtype=dtype) + dout = tensor(np.random.randn(*ishape), dtype=dtype) + + gm = GradManager() + + @jit.xla_trace(without_host=True) + def func(inp1, inp2, dout): + gm.attach([inp1, inp2]) + with gm: + out, idx1 = F.sort(inp1, descending) + idx2 = F.argsort(inp2, -descending) + 
gm.backward(out, dout) + return out, idx1, idx2, inp1.grad + + mge_rsts = func(inp1, inp2, dout) + xla_rsts = func(inp1, inp2, dout) + + for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): + np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5) + + for descending in [True, False]: + tester((16, 32), descending) + tester((16, 1), descending) + tester((1, 16), descending) + tester((1, 1), descending) + tester((16,), descending) + tester((1,), descending) + + +@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") +@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") +@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now") +def test_topk(): + def tester(ishape, k, descending, kth_only, no_sort, dtype=None): + dtype = dtype or np.float32 + inp = tensor(np.random.randn(*ishape), dtype=dtype) + out, _ = F.topk(inp, k, descending, kth_only, no_sort) + dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype) + + gm = GradManager() + + @jit.xla_trace(without_host=True) + def func(inp, dout): + gm.attach([inp]) + with gm: + out, index = F.topk(inp, k, descending, kth_only, no_sort) + gm.backward(out, dout) + return out, index, inp.grad + + mge_rsts = func(inp, dout) + xla_rsts = func(inp, dout) + + for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): + np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5) + + for descending in [True, False]: + tester((2, 16,), 1, descending, False, False) + tester((2, 16,), 8, descending, False, False) + tester((1, 16,), 1, descending, False, False) + tester((1, 16,), 5, descending, False, False) + tester((16,), 8, descending, False, False) + tester((16,), 8, descending, False, False) + tester((1,), 1, descending, False, False) + tester((1,), 1, descending, False, False) + + +@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") +@pytest.mark.skipif(platform.system() != "Linux", reason="only 
support linux now") +@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now") +def test_topk_accuracy(): + def tester(batch, nr_class, topk, dtype=None): + dtype = dtype or np.float32 + logits = tensor(np.random.uniform(0, 1, (batch, nr_class)), dtype=dtype) + target = tensor(np.random.randint(0, nr_class, (batch,), np.int32)) + out = F.topk_accuracy(logits, target, topk) + dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype) + + gm = GradManager() + + @jit.xla_trace(without_host=True) + def func(logits, target, dout): + gm.attach([logits]) + with gm: + out = F.topk_accuracy(logits, target, topk) + gm.backward(out, dout) + return [out] + + mge_rsts = func(logits, target, dout) + xla_rsts = func(logits, target, dout) + + for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): + np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5) + + tester(32, 1000, 10) + tester(32, 1, 1) + tester(1, 1000, 10) + tester(1, 1, 1) diff --git a/imperative/python/test/unit/xla/functional/test_xla_tensor.py b/imperative/python/test/unit/xla/functional/test_xla_tensor.py index b82cd506f..83964a689 100644 --- a/imperative/python/test/unit/xla/functional/test_xla_tensor.py +++ b/imperative/python/test/unit/xla/functional/test_xla_tensor.py @@ -113,6 +113,13 @@ def test_transpose(): tester((2, 3, 1), (0, 1, 2)) tester((2, 3, 1, 4), (3, 1, 0, 2)) + tester((1,), ("x", 0)) + # tester((1,), (0, 'x')) # bug for mge + tester((1, 2), ("x", 0, 1)) + tester((1, 2), (0, "x", 1)) + # tester((1, 2), (0, 1, 'x')) # bug for mge + tester((16, 32, 64), (0, "x", 2, "x", 1)) + @pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") @pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") -- GitLab