提交 2f06d580 编写于 作者: M Megvii Engine Team

feat(xla): add topk and sort for xla

GitOrigin-RevId: 0e881f30429a8d849ad9cdd0e0f47c3e0921ff97
上级 b0470e73
...@@ -192,6 +192,7 @@ class TraceResult: ...@@ -192,6 +192,7 @@ class TraceResult:
dtype_to_str = { dtype_to_str = {
"float16": "f16", "float16": "f16",
"float32": "f32", "float32": "f32",
"int8": "i8",
"int32": "i32", "int32": "i32",
"int64": "i64", "int64": "i64",
"uint8": "u8", "uint8": "u8",
...@@ -417,6 +418,10 @@ def f32_attr(i): ...@@ -417,6 +418,10 @@ def f32_attr(i):
return ir.FloatAttr.get(ir.F32Type.get(), i) return ir.FloatAttr.get(ir.F32Type.get(), i)
def bool_attr(i):
return ir.BoolAttr.get(i)
def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr: def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr:
lhs_prec = str(lhs_prec) lhs_prec = str(lhs_prec)
rhs_prec = str(rhs_prec) rhs_prec = str(rhs_prec)
......
...@@ -66,7 +66,7 @@ def _hslice_with_step_is_one(inp, slices): ...@@ -66,7 +66,7 @@ def _hslice_with_step_is_one(inp, slices):
def _hslice_with_any_step(inp, slices): def _hslice_with_any_step(inp, slices):
""" """
if inp_shape is N-dim, slices should contain N slice, slice can not None if inp_shape is N-dim, slices should contain N slice, slice can not None.
for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)]
""" """
starts = [int(sl.start) for sl in slices] starts = [int(sl.start) for sl in slices]
...@@ -83,7 +83,7 @@ def _hslice_with_any_step(inp, slices): ...@@ -83,7 +83,7 @@ def _hslice_with_any_step(inp, slices):
def index_with_slices(inp, slices): def index_with_slices(inp, slices):
""" """
if inp_shape is N-dim, slices should contain N slice, slice can be None if inp_shape is N-dim, slices should contain N slice, slice can be None.
for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] or [None, None] for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] or [None, None]
""" """
assert isinstance(slices, Sequence), f"{slices}" assert isinstance(slices, Sequence), f"{slices}"
......
...@@ -4,9 +4,13 @@ import numpy as np ...@@ -4,9 +4,13 @@ import numpy as np
from ...core._imperative_rt import ops as mops from ...core._imperative_rt import ops as mops
from .. import ir_utils from .. import ir_utils
from ..ir_utils import i64_attr from ..ir_utils import bool_attr, i64_attr
from ..lib.mlir import ir
from ..lib.mlir.dialects import chlo, hlo from ..lib.mlir.dialects import chlo, hlo
from ..utils import flatten_list
from .hlotensor import HLOTensor from .hlotensor import HLOTensor
from .indexing import ScatterDimensionNumbers, scatter
from .tensor import concat, expand_dims, fill, iota
from .utils import _can_broadcast_to, _shape_equal, register_lower_rule from .utils import _can_broadcast_to, _shape_equal, register_lower_rule
...@@ -241,5 +245,192 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): ...@@ -241,5 +245,192 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
).transpose(permutation) ).transpose(permutation)
def _sort_according_to_key(key, *vals, axis=-1, descending=True, is_stable=True):
"""
sort key and vals in the specified axis, return the sorted key and vals.
key and vals should have the same shape, then we reorder both key and vals according
to the value of the key.
example 1: (implement argsort)
inp: 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2
[[ 1.7783 -1.8184 1.0701]
[-0.0712 -1.4623 1.3243]]
[[0 1 2]
[0 1 2]]
axis: -1
descend: True
return: after reorder, 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2
[[ 1.7783 1.0701 -1.8184]
[ 1.3243 -0.0712 -1.4623]]
[[0 2 1]
[2 0 1]]
example 2:
inp:
[[0 2 1]
[2 0 1]]
[[ 1.7783 1.0701 -1.8184]
[ 1.3243 -0.0712 -1.4623]]
axis: -1
descend: False
return:
[[0 1 2]
[0 1 2]]
[[ 1.7783 -1.8184 1.0701]
[-0.0712 -1.4623 1.3243]]
"""
for val in vals:
assert _shape_equal(
key.shape, val.shape
), f"sort key and vals shape mismatch: {key.shape}, {val.shape}"
axis = axis + key.ndim if axis < 0 else axis
sorted_key = ir_utils.make_ir_type_according_meta(key.shape, key.dtype)
sorted_vals = [
ir_utils.make_ir_type_according_meta(val.shape, val.dtype) for val in vals
]
sort_op = hlo.SortOp(
[sorted_key, *sorted_vals],
[key.tensor, *[val.tensor for val in vals]],
dimension=i64_attr(axis),
is_stable=bool_attr(is_stable),
)
key_type = ir_utils.make_ir_type_according_meta(tuple(), key.dtype)
val_types = [
ir_utils.make_ir_type_according_meta(tuple(), val.dtype) for val in vals
]
arg_types = [key_type] + val_types
comparator = sort_op.comparator.blocks.append(
*flatten_list(zip(arg_types, arg_types))
)
with ir.InsertionPoint(comparator):
lhs = HLOTensor(comparator.arguments[0])
rhs = HLOTensor(comparator.arguments[1])
if descending:
hlo.ReturnOp([(lhs > rhs).tensor])
else:
hlo.ReturnOp([(lhs < rhs).tensor])
assert len(sort_op.results) == len(vals) + 1, f"{len(vals)}, {len(sort_op.results)}"
return (HLOTensor(ret) for ret in sort_op.results)
def argsort(inp, axis=-1, descending=True, is_stable=True):
"""
sort inp in the specfic axis, and return the sorted value and index
for example:
inp:
[[ 1.7783 -1.8184 1.0701]
[-0.0712 -1.4623 1.3243]]
axis: -1
descend: True
return:
[[ 1.7783 1.0701 -1.8184]
[ 1.3243 -0.0712 -1.4623]]
[[0 2 1]
[2 0 1]]
"""
axis = axis + inp.ndim if axis < 0 else axis
idx = iota(np.int32, inp.shape, axis)
return _sort_according_to_key(
inp, idx, axis=axis, descending=descending, is_stable=is_stable
)
@register_lower_rule(mops.Argsort)
def argsort_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert (
len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 2
), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
assert ctx.op.order in [
mops.Argsort.Order.DESCENDING,
mops.Argsort.Order.ASCENDING,
], f"{ctx.op.order}"
descending = ctx.op.order == mops.Argsort.Order.DESCENDING
axis = args[0].ndim - 1 # megengine only support sort in the last dimension
return argsort(args[0], axis, descending, is_stable=True)
@register_lower_rule("ArgsortBackward")
def argsort_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert (
len(args) == 3 and len(ctx.vars_in) == 3 and len(ctx.vars_out) == 1
), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
dy, idx, x = args[0], args[1], args[2]
if _shape_equal(x.shape, dy.shape):
# for argsort backward
_, dx = _sort_according_to_key(
idx, dy, axis=-1, descending=False, is_stable=True
)
else:
# for topk backward, only support axis=-1 and the dx is 2d tensor
dx = fill(0, ctx.vars_out[0].shape, ctx.vars_out[0].dtype)
expander = iota(np.int32, idx.shape, dimension=0)
idx = expand_dims(idx, -1)
expander = expand_dims(expander, -1)
idx = concat([expander, idx], -1)
dnums = ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1),
)
dx = scatter(dx, idx, dy, dnums, unique_indices=True)
return dx
def topk(inp, k, descending=True, kth_only=False, no_sort=False): def topk(inp, k, descending=True, kth_only=False, no_sort=False):
return [HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results] """
do topk in the last dimension of inp, for example:
inp.shape = (2, 3, 4), k = 2, out_shape = (2, 3, 2)
"""
assert k > 0, f"k of topk must bigger than 0, get {k}"
assert no_sort == False, f"no_sort must be False now"
assert kth_only == False, f"kth_only is not support now"
if descending == True:
out, idx = [
HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results
]
else:
inp = -inp
out, idx = [
HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results
]
out = -out
return out, idx
@register_lower_rule(mops.TopK)
def topk_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert (
len(args) == 2 and len(ctx.vars_in) == 2
), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
assert isinstance(
ctx.vars_in[1].bound_data, np.ndarray
), f"{ctx.vars_in[1].bound_data}"
k = int(ctx.vars_in[1].bound_data)
descending = True if k < 0 else False
k = -k if k < 0 else k
if ctx.op.mode == mops.TopK.Mode.VALUE_IDX_SORTED:
assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
kth_only, no_sort = False, False
elif ctx.op.mode == mops.TopK.Mode.VALUE_IDX_NOSORT:
assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
kth_only, no_sort = False, True
else:
assert (
ctx.op.mode == mops.TopK.Mode.KTH_ONLY
), f"invalid mode for topk, {ctx.op.mode}"
kth_only, no_sort = True, False
assert len(ctx.vars_out) == 1, f"{len(ctx.vars_out)}"
return topk(args[0], k, descending, kth_only, no_sort)
...@@ -79,14 +79,13 @@ def transpose(inp, permutation): ...@@ -79,14 +79,13 @@ def transpose(inp, permutation):
def expand_dims(inp, axis): def expand_dims(inp, axis):
assert isinstance(axis, int), f"only int axis supported, get {axis}" assert isinstance(axis, int), f"only int axis supported, get {axis}"
axis = (axis + inp.ndim) if axis < 0 else axis assert (
assert axis >= 0 and axis <= inp.ndim, f"invalid axis {axis} for {inp.shape}" axis >= -inp.ndim - 1 and axis <= inp.ndim
), f"invalid axis {axis} for {inp.shape}"
dst_shape = [] dst_shape = list(inp.shape)
for i in range(inp.ndim): insert_pos = axis if axis >= 0 else (axis + inp.ndim + 1)
if i == axis: dst_shape.insert(insert_pos, 1)
dst_shape.append(1)
dst_shape.append(inp.shape[i])
return inp.reshape(tuple(dst_shape)) return inp.reshape(tuple(dst_shape))
...@@ -94,14 +93,29 @@ def expand_dims(inp, axis): ...@@ -94,14 +93,29 @@ def expand_dims(inp, axis):
@register_lower_rule(mops.Dimshuffle) @register_lower_rule(mops.Dimshuffle)
def dim_shuffle_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): def dim_shuffle_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
permutation = ctx.op.pattern # mge dimshuffle can do transpose and broadcast simutaneously
return transpose(args[0], permutation) # for example:
# case1: (16, 32, 64) with pattern [0, 2, 1] -> (16, 64, 32)
# case2: (16, 32, 64) with pattern [0, -1, 2, -1, 1] -> (16, 1, 64, 1, 32)
# case3: (16, 1, 64, 1, 32) with pattern [0, 4, 2] -> (16, 32, 64)
pattern = ctx.op.pattern
inp = args[0]
if len(pattern) == inp.ndim:
permutation = pattern
return transpose(inp, permutation)
elif len(pattern) > inp.ndim:
permutation = [item for item in pattern if item != -1]
return transpose(inp, permutation).reshape(ctx.vars_out[0].shape)
else:
permutation = [i for i in range(inp.ndim) if i not in pattern] + list(pattern)
return transpose(inp, permutation).reshape(ctx.vars_out[0].shape)
def concat(inps, axis): def concat(inps, axis):
assert len(inps) > 0, f"concat inputs should not be empty" assert len(inps) > 0, f"concat inputs should not be empty"
if axis < 0: if axis < 0:
axis = axis + inps[0].ndim[0] axis = axis + inps[0].ndim
hlo_inps = [inp.tensor for inp in inps] hlo_inps = [inp.tensor for inp in inps]
...@@ -175,6 +189,21 @@ def fill(value, shape, dtype): ...@@ -175,6 +189,21 @@ def fill(value, shape, dtype):
return broadcast_to(HLOTensor(value, dtype=dtype), shape) return broadcast_to(HLOTensor(value, dtype=dtype), shape)
def iota(dtype, shape, dimension):
"""
do some thing like arange.
for example:
shape = (2, 3), dimension=1, output is [[0, 1, 2], [0, 1, 2]]
shape = (2, 3), dimension=-1, output is [[0, 0, 0], [1, 1, 1]]
"""
dimension = dimension + len(shape) if dimension < 0 else dimension
ret = hlo.IotaOp(
ir_utils.make_ir_type_according_meta(shape, dtype), ir_utils.i64_attr(dimension)
).results
assert len(ret) == 1, f"{len(ret)}"
return HLOTensor(ret[0])
@register_lower_rule(mops.Fill) @register_lower_rule(mops.Fill)
def fill_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): def fill_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1 assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
......
...@@ -31,7 +31,6 @@ def test_matmul(): ...@@ -31,7 +31,6 @@ def test_matmul():
return out, lhs.grad, rhs.grad return out, lhs.grad, rhs.grad
mge_rsts = func(lhs, rhs, dout) mge_rsts = func(lhs, rhs, dout)
mge_rsts[0].numpy()
xla_rsts = func(lhs, rhs, dout) xla_rsts = func(lhs, rhs, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
...@@ -79,3 +78,109 @@ def test_matmul(): ...@@ -79,3 +78,109 @@ def test_matmul():
tester((1, 2, 8, 7), (4, 2, 2, 9, 8), True, True) tester((1, 2, 8, 7), (4, 2, 2, 9, 8), True, True)
tester((1, 8, 7), (4, 3, 2, 8, 9), True, False) tester((1, 8, 7), (4, 3, 2, 8, 9), True, False)
tester((1, 8, 7), (4, 3, 1, 9, 8), True, True) tester((1, 8, 7), (4, 3, 1, 9, 8), True, True)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_sort_and_argsort():
def tester(ishape, descending, dtype=None):
dtype = dtype or np.float32
inp1 = tensor(np.random.randn(*ishape), dtype=dtype)
inp2 = tensor(np.random.randn(*ishape), dtype=dtype)
dout = tensor(np.random.randn(*ishape), dtype=dtype)
gm = GradManager()
@jit.xla_trace(without_host=True)
def func(inp1, inp2, dout):
gm.attach([inp1, inp2])
with gm:
out, idx1 = F.sort(inp1, descending)
idx2 = F.argsort(inp2, -descending)
gm.backward(out, dout)
return out, idx1, idx2, inp1.grad
mge_rsts = func(inp1, inp2, dout)
xla_rsts = func(inp1, inp2, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
for descending in [True, False]:
tester((16, 32), descending)
tester((16, 1), descending)
tester((1, 16), descending)
tester((1, 1), descending)
tester((16,), descending)
tester((1,), descending)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk():
def tester(ishape, k, descending, kth_only, no_sort, dtype=None):
dtype = dtype or np.float32
inp = tensor(np.random.randn(*ishape), dtype=dtype)
out, _ = F.topk(inp, k, descending, kth_only, no_sort)
dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype)
gm = GradManager()
@jit.xla_trace(without_host=True)
def func(inp, dout):
gm.attach([inp])
with gm:
out, index = F.topk(inp, k, descending, kth_only, no_sort)
gm.backward(out, dout)
return out, index, inp.grad
mge_rsts = func(inp, dout)
xla_rsts = func(inp, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
for descending in [True, False]:
tester((2, 16,), 1, descending, False, False)
tester((2, 16,), 8, descending, False, False)
tester((1, 16,), 1, descending, False, False)
tester((1, 16,), 5, descending, False, False)
tester((16,), 8, descending, False, False)
tester((16,), 8, descending, False, False)
tester((1,), 1, descending, False, False)
tester((1,), 1, descending, False, False)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk_accuracy():
def tester(batch, nr_class, topk, dtype=None):
dtype = dtype or np.float32
logits = tensor(np.random.uniform(0, 1, (batch, nr_class)), dtype=dtype)
target = tensor(np.random.randint(0, nr_class, (batch,), np.int32))
out = F.topk_accuracy(logits, target, topk)
dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype)
gm = GradManager()
@jit.xla_trace(without_host=True)
def func(logits, target, dout):
gm.attach([logits])
with gm:
out = F.topk_accuracy(logits, target, topk)
gm.backward(out, dout)
return [out]
mge_rsts = func(logits, target, dout)
xla_rsts = func(logits, target, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
tester(32, 1000, 10)
tester(32, 1, 1)
tester(1, 1000, 10)
tester(1, 1, 1)
...@@ -113,6 +113,13 @@ def test_transpose(): ...@@ -113,6 +113,13 @@ def test_transpose():
tester((2, 3, 1), (0, 1, 2)) tester((2, 3, 1), (0, 1, 2))
tester((2, 3, 1, 4), (3, 1, 0, 2)) tester((2, 3, 1, 4), (3, 1, 0, 2))
tester((1,), ("x", 0))
# tester((1,), (0, 'x')) # bug for mge
tester((1, 2), ("x", 0, 1))
tester((1, 2), (0, "x", 1))
# tester((1, 2), (0, 1, 'x')) # bug for mge
tester((16, 32, 64), (0, "x", 2, "x", 1))
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") @pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") @pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册