提交 281ecd0b 编写于 作者: M Megvii Engine Team

feat(xla): support IndexingMultiAxisVec and IndexingIncrMultiAxisVec

GitOrigin-RevId: ca13d142ef5f6d952350f7217f1aebc2ff644dd6
上级 5e013d8c
...@@ -409,11 +409,15 @@ def scatter( ...@@ -409,11 +409,15 @@ def scatter(
oshape, odtype = oup_var.shape, oup_var.dtype oshape, odtype = oup_var.shape, oup_var.dtype
else: else:
oshape, odtype = x.shape, x.dtype oshape, odtype = x.shape, x.dtype
indices = (
ir_utils.ir_constant(indices)
if not isinstance(indices, HLOTensor)
else indices.tensor
)
op = hlo.ScatterOp( op = hlo.ScatterOp(
ir_utils.make_ir_type_according_meta_tuple(oshape, odtype), ir_utils.make_ir_type_according_meta_tuple(oshape, odtype),
[x.tensor], [x.tensor],
ir_utils.ir_constant(indices), indices,
[y.tensor], [y.tensor],
scatter_dnums, scatter_dnums,
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted), indices_are_sorted=ir.BoolAttr.get(indices_are_sorted),
...@@ -424,7 +428,32 @@ def scatter( ...@@ -424,7 +428,32 @@ def scatter(
update = op.update_computation.blocks.append(scalar_type, scalar_type) update = op.update_computation.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(update): with ir.InsertionPoint(update):
if mode == "add":
add = hlo.AddOp(*update.arguments)
hlo.ReturnOp(add.results)
else:
hlo.ReturnOp((update.arguments[1],)) hlo.ReturnOp((update.arguments[1],))
return HLOTensor(op.results)
def gather(
x, indices, dnums, slice_sizes, indices_are_sorted=False, unique_indices=False,
):
gather_dnums = hlo.GatherDimensionNumbers.get(
collapsed_slice_dims=list(dnums.collapsed_slice_dims),
index_vector_dim=len(indices.shape) - 1,
offset_dims=list(dnums.offset_dims),
start_index_map=list(dnums.start_index_map),
)
op = hlo.GatherOp(
x.tensor,
indices.tensor,
gather_dnums,
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted),
slice_sizes=ir_utils.dense_int_elements(slice_sizes),
)
return HLOTensor(op.results) return HLOTensor(op.results)
...@@ -554,3 +583,68 @@ def indexing_set_one_hot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]] ...@@ -554,3 +583,68 @@ def indexing_set_one_hot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]
assert ctx.op.ndim == args[0].ndim, f"{ctx.op.ndim}, {args[0].shape}" assert ctx.op.ndim == args[0].ndim, f"{ctx.op.ndim}, {args[0].shape}"
return indexing_set_with_tensor_index(args[0], args[2], args[1], ctx.op.axis) return indexing_set_with_tensor_index(args[0], args[2], args[1], ctx.op.axis)
def convert_negative_index(indices: HLOTensor, max_indices: int):
max_i = HLOTensor(np.array([max_indices], dtype="int32"))
zero = HLOTensor(np.array([0], dtype="int32"))
zeros = zero.broadcast_to(indices.shape)
max_i = max_i.broadcast_to(indices.shape)
positive_indices = indices + max_i
mask = indices < zeros
return HLOTensor(
hlo.SelectOp(mask.tensor, positive_indices.tensor, indices.tensor).results
)
@register_lower_rule(mops.IndexingMultiAxisVec)
def vec_indexing_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(ctx.param["items"]) == 1
axis, _, _, _, is_index = ctx.param["items"][0]
assert is_index
inp = args[0]
indices = args[1]
indices = convert_negative_index(indices, inp.shape[axis])
offset_dims = tuple(i for i in range(len(inp.shape)) if i != axis)
collapsed_slice_dims = (axis,)
start_index_map = (axis,)
indices = indices.reshape(indices.shape + (1,))
slices_size = tuple(
(inp.shape[i] if i != axis else 1 for i in range(len(inp.shape)))
)
return gather(
inp,
indices,
GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map),
slices_size,
)
@register_lower_rule(mops.IndexingIncrMultiAxisVec)
def vec_indexing_incr_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
assert len(ctx.param["items"]) == 1
axis, _, _, _, is_index = ctx.param["items"][0]
assert is_index
inp = args[0]
indices = args[2]
indices = convert_negative_index(indices, inp.shape[axis])
indices = indices.reshape(indices.shape + (1,))
y = args[1]
offset_dims = tuple(i for i in range(len(inp.shape)) if i != axis)
collapsed_slice_dims = (axis,)
start_index_map = (axis,)
dnums = ScatterDimensionNumbers(
update_window_dims=offset_dims,
inserted_window_dims=collapsed_slice_dims,
scatter_dims_to_operand_dims=start_index_map,
)
out = scatter(
inp,
indices,
y,
dnums,
indices_are_sorted=False,
unique_indices=False,
mode="add",
)
return out
...@@ -149,3 +149,40 @@ def test_indexing_one_hot(): ...@@ -149,3 +149,40 @@ def test_indexing_one_hot():
tester((4, 8, 16), -1, False) tester((4, 8, 16), -1, False)
tester((4, 1, 16), -2, True) tester((4, 1, 16), -2, True)
tester((4, 1, 16), -2, False) tester((4, 1, 16), -2, False)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_index_multi_vec():
def tester(x_shape, index_type, dtype):
dtype = dtype or np.float32
x = tensor(np.random.randn(*x_shape), dtype=dtype)
max_val = x.shape[0]
ind = tensor(np.random.randint(-max_val + 1, max_val, 24).astype("int32"))
gm = GradManager()
rand_num = tensor(np.random.random(x[ind].shape).astype(dtype))
@jit.xla_trace(without_host=True, capture_as_const=True)
def func(inp, ind):
gm.attach([inp])
with gm:
x = inp
if index_type == "set":
x[ind] = tensor(rand_num)
else:
x = x[ind]
gm.backward((x * x).sum())
return x, inp.grad
mge_rsts = func(x, ind)
xla_rsts = func(x, ind)
for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
tester((3, 4, 5, 6), "get", np.float32)
tester((3, 4, 5, 6), "get", np.float16)
# tester((2,2,2,2), "set", np.float32)
# tester((3,4,5,6), "set", np.float16)
# tester((3,4,5,6), "set", np.float16)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册