提交 2f06d580 编写于 作者: M Megvii Engine Team

feat(xla): add topk and sort for xla

GitOrigin-RevId: 0e881f30429a8d849ad9cdd0e0f47c3e0921ff97
上级 b0470e73
......@@ -192,6 +192,7 @@ class TraceResult:
dtype_to_str = {
"float16": "f16",
"float32": "f32",
"int8": "i8",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
......@@ -417,6 +418,10 @@ def f32_attr(i):
return ir.FloatAttr.get(ir.F32Type.get(), i)
def bool_attr(i):
    """Return the boolean IR attribute that ``ir.BoolAttr.get`` builds from ``i``."""
    attr = ir.BoolAttr.get(i)
    return attr
def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr:
lhs_prec = str(lhs_prec)
rhs_prec = str(rhs_prec)
......@@ -66,7 +66,7 @@ def _hslice_with_step_is_one(inp, slices):
def _hslice_with_any_step(inp, slices):
if inp_shape is N-dim, slices should contain N slice, slice can not None
if inp_shape is N-dim, slices should contain N slices, and no slice can be None.
for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)]
starts = [int(sl.start) for sl in slices]
......@@ -83,7 +83,7 @@ def _hslice_with_any_step(inp, slices):
def index_with_slices(inp, slices):
if inp_shape is N-dim, slices should contain N slice, slice can be None
if inp_shape is N-dim, slices should contain N slices; an individual slice can be None.
for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] or [None, None]
assert isinstance(slices, Sequence), f"{slices}"
......@@ -4,9 +4,13 @@ import numpy as np
from ...core._imperative_rt import ops as mops
from .. import ir_utils
from ..ir_utils import i64_attr
from ..ir_utils import bool_attr, i64_attr
from ..lib.mlir import ir
from ..lib.mlir.dialects import chlo, hlo
from ..utils import flatten_list
from .hlotensor import HLOTensor
from .indexing import ScatterDimensionNumbers, scatter
from .tensor import concat, expand_dims, fill, iota
from .utils import _can_broadcast_to, _shape_equal, register_lower_rule
......@@ -241,5 +245,192 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
def _sort_according_to_key(key, *vals, axis=-1, descending=True, is_stable=True):
sort key and vals along the specified axis and return the sorted key and vals.
key and vals must all have the same shape; both key and vals are reordered according
to the values of the key.
example 1: (implement argsort)
inp: 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2
[[ 1.7783 -1.8184 1.0701]
[-0.0712 -1.4623 1.3243]]
[[0 1 2]
[0 1 2]]
axis: -1
descending: True
return: after reorder, 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2
[[ 1.7783 1.0701 -1.8184]
[ 1.3243 -0.0712 -1.4623]]
[[0 2 1]
[2 0 1]]
example 2:
[[0 2 1]
[2 0 1]]
[[ 1.7783 1.0701 -1.8184]
[ 1.3243 -0.0712 -1.4623]]
axis: -1
descending: False
[[0 1 2]
[0 1 2]]
[[ 1.7783 -1.8184 1.0701]
[-0.0712 -1.4623 1.3243]]
for val in vals:
assert _shape_equal(
key.shape, val.shape
), f"sort key and vals shape mismatch: {key.shape}, {val.shape}"
axis = axis + key.ndim if axis < 0 else axis
sorted_key = ir_utils.make_ir_type_according_meta(key.shape, key.dtype)
sorted_vals = [
ir_utils.make_ir_type_according_meta(val.shape, val.dtype) for val in vals
sort_op = hlo.SortOp(
[sorted_key, *sorted_vals],
[key.tensor, *[val.tensor for val in vals]],
key_type = ir_utils.make_ir_type_according_meta(tuple(), key.dtype)
val_types = [
ir_utils.make_ir_type_according_meta(tuple(), val.dtype) for val in vals
arg_types = [key_type] + val_types
comparator = sort_op.comparator.blocks.append(
*flatten_list(zip(arg_types, arg_types))
with ir.InsertionPoint(comparator):
lhs = HLOTensor(comparator.arguments[0])
rhs = HLOTensor(comparator.arguments[1])
if descending:
hlo.ReturnOp([(lhs > rhs).tensor])
hlo.ReturnOp([(lhs < rhs).tensor])
assert len(sort_op.results) == len(vals) + 1, f"{len(vals)}, {len(sort_op.results)}"
return (HLOTensor(ret) for ret in sort_op.results)
def argsort(inp, axis=-1, descending=True, is_stable=True):
sort inp along the specified axis, and return the sorted values and their indices
for example:
[[ 1.7783 -1.8184 1.0701]
[-0.0712 -1.4623 1.3243]]
axis: -1
descending: True
[[ 1.7783 1.0701 -1.8184]
[ 1.3243 -0.0712 -1.4623]]
[[0 2 1]
[2 0 1]]
axis = axis + inp.ndim if axis < 0 else axis
idx = iota(np.int32, inp.shape, axis)
return _sort_according_to_key(
inp, idx, axis=axis, descending=descending, is_stable=is_stable
def argsort_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert (
len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 2
), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
assert ctx.op.order in [
], f"{ctx.op.order}"
descending = ctx.op.order == mops.Argsort.Order.DESCENDING
axis = args[0].ndim - 1 # megengine only support sort in the last dimension
return argsort(args[0], axis, descending, is_stable=True)
def argsort_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert (
len(args) == 3 and len(ctx.vars_in) == 3 and len(ctx.vars_out) == 1
), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
dy, idx, x = args[0], args[1], args[2]
if _shape_equal(x.shape, dy.shape):
# for argsort backward
_, dx = _sort_according_to_key(
idx, dy, axis=-1, descending=False, is_stable=True
# for topk backward, only support axis=-1 and the dx is 2d tensor
dx = fill(0, ctx.vars_out[0].shape, ctx.vars_out[0].dtype)
expander = iota(np.int32, idx.shape, dimension=0)
idx = expand_dims(idx, -1)
expander = expand_dims(expander, -1)
idx = concat([expander, idx], -1)
dnums = ScatterDimensionNumbers(
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1),
dx = scatter(dx, idx, dy, dnums, unique_indices=True)
return dx
def topk(inp, k, descending=True, kth_only=False, no_sort=False):
return [HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results]
do topk in the last dimension of inp, for example:
inp.shape = (2, 3, 4), k = 2, out_shape = (2, 3, 2)
assert k > 0, f"k of topk must bigger than 0, get {k}"
assert no_sort == False, f"no_sort must be False now"
assert kth_only == False, f"kth_only is not support now"
if descending == True:
out, idx = [
HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results
inp = -inp
out, idx = [
HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results
out = -out
return out, idx
def topk_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert (
len(args) == 2 and len(ctx.vars_in) == 2
), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
assert isinstance(
ctx.vars_in[1].bound_data, np.ndarray
), f"{ctx.vars_in[1].bound_data}"
k = int(ctx.vars_in[1].bound_data)
descending = True if k < 0 else False
k = -k if k < 0 else k
if ctx.op.mode == mops.TopK.Mode.VALUE_IDX_SORTED:
assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
kth_only, no_sort = False, False
elif ctx.op.mode == mops.TopK.Mode.VALUE_IDX_NOSORT:
assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
kth_only, no_sort = False, True
assert (
ctx.op.mode == mops.TopK.Mode.KTH_ONLY
), f"invalid mode for topk, {ctx.op.mode}"
kth_only, no_sort = True, False
assert len(ctx.vars_out) == 1, f"{len(ctx.vars_out)}"
return topk(args[0], k, descending, kth_only, no_sort)
......@@ -79,14 +79,13 @@ def transpose(inp, permutation):
def expand_dims(inp, axis):
assert isinstance(axis, int), f"only int axis supported, get {axis}"
axis = (axis + inp.ndim) if axis < 0 else axis
assert axis >= 0 and axis <= inp.ndim, f"invalid axis {axis} for {inp.shape}"
assert (
axis >= -inp.ndim - 1 and axis <= inp.ndim
), f"invalid axis {axis} for {inp.shape}"
dst_shape = []
for i in range(inp.ndim):
if i == axis:
dst_shape = list(inp.shape)
insert_pos = axis if axis >= 0 else (axis + inp.ndim + 1)
dst_shape.insert(insert_pos, 1)
return inp.reshape(tuple(dst_shape))
......@@ -94,14 +93,29 @@ def expand_dims(inp, axis):
def dim_shuffle_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
permutation = ctx.op.pattern
return transpose(args[0], permutation)
# mge dimshuffle can do transpose and broadcast simultaneously
# for example:
# case1: (16, 32, 64) with pattern [0, 2, 1] -> (16, 64, 32)
# case2: (16, 32, 64) with pattern [0, -1, 2, -1, 1] -> (16, 1, 64, 1, 32)
# case3: (16, 1, 64, 1, 32) with pattern [0, 4, 2] -> (16, 32, 64)
pattern = ctx.op.pattern
inp = args[0]
if len(pattern) == inp.ndim:
permutation = pattern
return transpose(inp, permutation)
elif len(pattern) > inp.ndim:
permutation = [item for item in pattern if item != -1]
return transpose(inp, permutation).reshape(ctx.vars_out[0].shape)
permutation = [i for i in range(inp.ndim) if i not in pattern] + list(pattern)
return transpose(inp, permutation).reshape(ctx.vars_out[0].shape)
def concat(inps, axis):
assert len(inps) > 0, f"concat inputs should not be empty"
if axis < 0:
axis = axis + inps[0].ndim[0]
axis = axis + inps[0].ndim
hlo_inps = [inp.tensor for inp in inps]
......@@ -175,6 +189,21 @@ def fill(value, shape, dtype):
return broadcast_to(HLOTensor(value, dtype=dtype), shape)
def iota(dtype, shape, dimension):
do something like arange.
for example:
shape = (2, 3), dimension=1, output is [[0, 1, 2], [0, 1, 2]]
shape = (2, 3), dimension=0, output is [[0, 0, 0], [1, 1, 1]]
(a negative dimension counts from the end, so dimension=-1 is the same as dimension=1 here)
dimension = dimension + len(shape) if dimension < 0 else dimension
ret = hlo.IotaOp(
ir_utils.make_ir_type_according_meta(shape, dtype), ir_utils.i64_attr(dimension)
assert len(ret) == 1, f"{len(ret)}"
return HLOTensor(ret[0])
def fill_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
......@@ -31,7 +31,6 @@ def test_matmul():
return out, lhs.grad, rhs.grad
mge_rsts = func(lhs, rhs, dout)
xla_rsts = func(lhs, rhs, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
......@@ -79,3 +78,109 @@ def test_matmul():
tester((1, 2, 8, 7), (4, 2, 2, 9, 8), True, True)
tester((1, 8, 7), (4, 3, 2, 8, 9), True, False)
tester((1, 8, 7), (4, 3, 1, 9, 8), True, True)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_sort_and_argsort():
def tester(ishape, descending, dtype=None):
dtype = dtype or np.float32
inp1 = tensor(np.random.randn(*ishape), dtype=dtype)
inp2 = tensor(np.random.randn(*ishape), dtype=dtype)
dout = tensor(np.random.randn(*ishape), dtype=dtype)
gm = GradManager()
def func(inp1, inp2, dout):
gm.attach([inp1, inp2])
with gm:
out, idx1 = F.sort(inp1, descending)
idx2 = F.argsort(inp2, -descending)
gm.backward(out, dout)
return out, idx1, idx2, inp1.grad
mge_rsts = func(inp1, inp2, dout)
xla_rsts = func(inp1, inp2, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
for descending in [True, False]:
tester((16, 32), descending)
tester((16, 1), descending)
tester((1, 16), descending)
tester((1, 1), descending)
tester((16,), descending)
tester((1,), descending)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk():
def tester(ishape, k, descending, kth_only, no_sort, dtype=None):
dtype = dtype or np.float32
inp = tensor(np.random.randn(*ishape), dtype=dtype)
out, _ = F.topk(inp, k, descending, kth_only, no_sort)
dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype)
gm = GradManager()
def func(inp, dout):
with gm:
out, index = F.topk(inp, k, descending, kth_only, no_sort)
gm.backward(out, dout)
return out, index, inp.grad
mge_rsts = func(inp, dout)
xla_rsts = func(inp, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
for descending in [True, False]:
tester((2, 16,), 1, descending, False, False)
tester((2, 16,), 8, descending, False, False)
tester((1, 16,), 1, descending, False, False)
tester((1, 16,), 5, descending, False, False)
tester((16,), 8, descending, False, False)
tester((16,), 8, descending, False, False)
tester((1,), 1, descending, False, False)
tester((1,), 1, descending, False, False)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk_accuracy():
def tester(batch, nr_class, topk, dtype=None):
dtype = dtype or np.float32
logits = tensor(np.random.uniform(0, 1, (batch, nr_class)), dtype=dtype)
target = tensor(np.random.randint(0, nr_class, (batch,), np.int32))
out = F.topk_accuracy(logits, target, topk)
dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype)
gm = GradManager()
def func(logits, target, dout):
with gm:
out = F.topk_accuracy(logits, target, topk)
gm.backward(out, dout)
return [out]
mge_rsts = func(logits, target, dout)
xla_rsts = func(logits, target, dout)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
tester(32, 1000, 10)
tester(32, 1, 1)
tester(1, 1000, 10)
tester(1, 1, 1)
......@@ -113,6 +113,13 @@ def test_transpose():
tester((2, 3, 1), (0, 1, 2))
tester((2, 3, 1, 4), (3, 1, 0, 2))
tester((1,), ("x", 0))
# tester((1,), (0, 'x')) # bug for mge
tester((1, 2), ("x", 0, 1))
tester((1, 2), (0, "x", 1))
# tester((1, 2), (0, 1, 'x')) # bug for mge
tester((16, 32, 64), (0, "x", 2, "x", 1))
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册