diff --git a/imperative/python/megengine/xla/rules/indexing.py b/imperative/python/megengine/xla/rules/indexing.py index 194f8854ee1d92fceb2db09ebfd82fa8eb32a4a8..5619c5dc8e211f80f212b5b609353369495e9f75 100644 --- a/imperative/python/megengine/xla/rules/indexing.py +++ b/imperative/python/megengine/xla/rules/indexing.py @@ -409,11 +409,15 @@ def scatter( oshape, odtype = oup_var.shape, oup_var.dtype else: oshape, odtype = x.shape, x.dtype - + indices = ( + ir_utils.ir_constant(indices) + if not isinstance(indices, HLOTensor) + else indices.tensor + ) op = hlo.ScatterOp( ir_utils.make_ir_type_according_meta_tuple(oshape, odtype), [x.tensor], - ir_utils.ir_constant(indices), + indices, [y.tensor], scatter_dnums, indices_are_sorted=ir.BoolAttr.get(indices_are_sorted), @@ -424,7 +428,32 @@ def scatter( update = op.update_computation.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(update): - hlo.ReturnOp((update.arguments[1],)) + if mode == "add": + add = hlo.AddOp(*update.arguments) + hlo.ReturnOp(add.results) + else: + hlo.ReturnOp((update.arguments[1],)) + + return HLOTensor(op.results) + + +def gather( + x, indices, dnums, slice_sizes, indices_are_sorted=False, unique_indices=False, +): + gather_dnums = hlo.GatherDimensionNumbers.get( + collapsed_slice_dims=list(dnums.collapsed_slice_dims), + index_vector_dim=len(indices.shape) - 1, + offset_dims=list(dnums.offset_dims), + start_index_map=list(dnums.start_index_map), + ) + + op = hlo.GatherOp( + x.tensor, + indices.tensor, + gather_dnums, + indices_are_sorted=ir.BoolAttr.get(indices_are_sorted), + slice_sizes=ir_utils.dense_int_elements(slice_sizes), + ) return HLOTensor(op.results) @@ -554,3 +583,68 @@ def indexing_set_one_hot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]] assert ctx.op.ndim == args[0].ndim, f"{ctx.op.ndim}, {args[0].shape}" return indexing_set_with_tensor_index(args[0], args[2], args[1], ctx.op.axis) + + +def convert_negative_index(indices: HLOTensor, max_indices: int): + max_i = HLOTensor(np.array([max_indices], dtype="int32")) + zero = HLOTensor(np.array([0], dtype="int32")) + zeros = zero.broadcast_to(indices.shape) + max_i = max_i.broadcast_to(indices.shape) + positive_indices = indices + max_i + mask = indices < zeros + return HLOTensor( + hlo.SelectOp(mask.tensor, positive_indices.tensor, indices.tensor).results + ) + + +@register_lower_rule(mops.IndexingMultiAxisVec) +def vec_indexing_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): + assert len(ctx.param["items"]) == 1 + axis, _, _, _, is_index = ctx.param["items"][0] + assert is_index + inp = args[0] + indices = args[1] + indices = convert_negative_index(indices, inp.shape[axis]) + offset_dims = tuple(i for i in range(len(inp.shape)) if i != axis) + collapsed_slice_dims = (axis,) + start_index_map = (axis,) + indices = indices.reshape(indices.shape + (1,)) + slices_size = tuple( + (inp.shape[i] if i != axis else 1 for i in range(len(inp.shape))) + ) + return gather( + inp, + indices, + GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map), + slices_size, + ) + + +@register_lower_rule(mops.IndexingIncrMultiAxisVec) +def vec_indexing_incr_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]): + assert len(ctx.param["items"]) == 1 + axis, _, _, _, is_index = ctx.param["items"][0] + assert is_index + inp = args[0] + indices = args[2] + indices = convert_negative_index(indices, inp.shape[axis]) + indices = indices.reshape(indices.shape + (1,)) + y = args[1] + offset_dims = tuple(i for i in range(len(inp.shape)) if i != axis) + collapsed_slice_dims = (axis,) + start_index_map = (axis,) + dnums = ScatterDimensionNumbers( + update_window_dims=offset_dims, + inserted_window_dims=collapsed_slice_dims, + scatter_dims_to_operand_dims=start_index_map, + ) + out = scatter( + inp, + indices, + y, + dnums, + indices_are_sorted=False, + unique_indices=False, + mode="add", + ) + return out diff --git a/imperative/python/test/unit/xla/functional/test_xla_indexing.py b/imperative/python/test/unit/xla/functional/test_xla_indexing.py index bad632d5895eca4e73d31c6f45ff98953856a5b9..5abc61785c0d8cf86bf4bc12124084d9cf17f4be 100644 --- a/imperative/python/test/unit/xla/functional/test_xla_indexing.py +++ b/imperative/python/test/unit/xla/functional/test_xla_indexing.py @@ -149,3 +149,40 @@ def test_indexing_one_hot(): tester((4, 8, 16), -1, False) tester((4, 1, 16), -2, True) tester((4, 1, 16), -2, False) + + +@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38") +@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now") +@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now") +def test_index_multi_vec(): + def tester(x_shape, index_type, dtype): + dtype = dtype or np.float32 + x = tensor(np.random.randn(*x_shape), dtype=dtype) + max_val = x.shape[0] + ind = tensor(np.random.randint(-max_val + 1, max_val, 24).astype("int32")) + gm = GradManager() + rand_num = tensor(np.random.random(x[ind].shape).astype(dtype)) + + @jit.xla_trace(without_host=True, capture_as_const=True) + def func(inp, ind): + gm.attach([inp]) + with gm: + x = inp + if index_type == "set": + x[ind] = tensor(rand_num) + else: + x = x[ind] + gm.backward((x * x).sum()) + return x, inp.grad + + mge_rsts = func(x, ind) + xla_rsts = func(x, ind) + for mge_rst, xla_rst in zip(mge_rsts, xla_rsts): + np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5) + + tester((3, 4, 5, 6), "get", np.float32) + tester((3, 4, 5, 6), "get", np.float16) + + # tester((2,2,2,2), "set", np.float32) + # tester((3,4,5,6), "set", np.float16) + # tester((3,4,5,6), "set", np.float16)