diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index cc1c5f416ed5617c0329311c88789e0b72755e96..ec63540178e82465c39188112e15061459cbdf72 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1015,7 +1015,7 @@ class IndexingOneHotBase : public OperatorBase { DEF_OPR_PARAM(Axis); protected: - void deduce_layout_fwd( + MGE_WIN_DECLSPEC_FUC void deduce_layout_fwd( const TensorLayout& src, const TensorLayout& index, TensorLayout& dst); void check_layout_fwd( const TensorLayout& src, const TensorLayout& index, diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index d32b9b61c2f161dd4c7164647109618acf8867b0..684dcdaffcf28d9e3ba5a82d689655b7900751b7 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1558,7 +1558,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: ) ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device) - op = builtin.IndexingSetOneHot(axis=inp.ndim) + op = builtin.IndexingSetOneHot(axis=inp.ndim, ndim=inp.ndim) (result,) = apply(op, zeros_tensor, inp, ones_tensor) return result @@ -1609,7 +1609,7 @@ def indexing_one_hot( array([1.], dtype=float32) """ assert isinstance(src, Tensor), "src must be of Tensor type" - op = builtin.IndexingOneHot(axis=axis) + op = builtin.IndexingOneHot(axis=axis, ndim=src.ndim) index = convert_single_value(index, dtype="int32", device=src.device) (result,) = apply(op, src, index) if not keepdims: diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 0bc05023d4b5cc6af1af8edbc31785a22253d378..24eabf545722113f324bb5ae697152c45e876aea 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -393,6 +393,8 @@ def split(inp, nsplits_or_sections, axis=0): def _get_idx(index, axis): index_dims = len(index.shape) idx = [] + if axis < 0: + axis += index_dims for i in range(index_dims): if i != axis: shape = [1] * index_dims @@ -457,21 +459,6 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: "But the input dims:{}, the index dims:{}".format(input_dims, index_dims) ) - if axis < 0 or axis >= input_dims: - raise ValueError( - "Index axis {} is output of bounds, should in range [0 {})".format( - axis, input_dims - ) - ) - - for i in range(input_dims): - if i != axis and input_shape[i] != index_shape[i]: - raise ValueError( - "The input {} and index {} must have the same size apart from axis {}".format( - input_shape, index_shape, axis - ) - ) - idx = _get_idx(index, axis) return inp[idx].reshape(index.shape) # pylint: disable=no-member @@ -524,7 +511,7 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: >>> inp = Tensor(np.zeros(shape=(3,5),dtype=np.float32)) >>> source = Tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]]) >>> index = Tensor([[0,2,0,2,1],[2,0,1,1,2]]) - >>> oup = F.scatter(inp, 0, index,source) + >>> oup = F.scatter(inp, 0, index, source) >>> oup.numpy() array([[0.9935, 0.0718, 0.2256, 0. , 0. ], [0. , 0. 
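Annotation (not part of the patch): with the new `ndim` attribute recorded on `IndexingOneHot`, a negative `axis` passed to `F.indexing_one_hot` can be resolved against `src.ndim`. A minimal usage sketch, matching the docstring example touched above:

```python
import numpy as np
import megengine as mge
import megengine.functional as F

src = mge.Tensor([[1.0, 2.0]])
index = mge.Tensor([0])

# axis=-1 is resolved against src.ndim (here: to axis 1) by the op itself.
out = F.indexing_one_hot(src, index, axis=-1)
np.testing.assert_equal(out.numpy(), np.array([1.0], dtype=np.float32))
```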
, 0.5939, 0.357 , 0.4396], @@ -540,13 +527,6 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: if input_dims != index_dims or input_dims != source_dims: raise ValueError("The input, source and index tensor must have same dimensions") - if axis < 0 or axis >= input_dims: - raise ValueError( - "Index axis {} is output of bounds, should in range [0 {})".format( - axis, input_dims - ) - ) - for i in range(source_dims): if source_shape[i] > input_shape[i]: raise ValueError( @@ -792,6 +772,8 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: >>> out.numpy().shape (2, 2, 9) """ + if start_axis < 0: + start_axis += len(inp.shape) target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,) if end_axis != -1: target_shape += (*inp.shape[end_axis + 1 :],) @@ -1158,6 +1140,5 @@ def cumsum(inp: Tensor, axis: int): [ 4 9 15]], dtype=int32, device=xpux:0) """ assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" - assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis) op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) return apply(op, inp)[0] diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index 3da9f6187befd31cac891b56391d51c851f0a400..be6c95f94f86a63f8e76e0316edb85a4810e1268 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -490,6 +490,84 @@ std::optional pixelShuffle_grad_rule( return imperative::apply(op, inputs); } +std::optional indexing_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& indexing = op.cast_final_safe(); + mgb_assert(inputs.size() == 2); + bool flag = inputs_require_grad[0]; + auto&& grad_op = IndexingSetOneHot::make(indexing.axis, indexing.ndim); + SmallVector inputs2; + if (flag) { + inputs2.push_back(get_shape(inputs[0])); + for (size_t i = 1; i < inputs.size(); ++i) { + inputs2.push_back(inputs[i]); + } + } + auto maker = CustomGradMaker(backward, inputs.size()); + maker.output_size(1).output_captured(0, false); + maker.backward([inputs = std::move(inputs2), + grad_op_ = std::move(grad_op)](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + SmallVector ret(1); + if (grad && inputs[0]) { + ValueRefList args_(inputs.size() + 1); + auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); + args_[0] = zeros; + args_[1] = inputs[1]; + args_[2] = grads[0]; + ret[0] = imperative::apply(*grad_op_, args_)[0]; + } + return ret; + }); + maker.finalize(); + return imperative::apply(op, inputs); +} + +std::optional indexing_set_one_hot_grad_rule( + const OpDef& op, Span inputs, Span inputs_require_grad, + CustomBackward& backward) { + auto&& indexingSetOneHot = op.cast_final_safe(); + mgb_assert(inputs.size() == 3); + SmallVector inputs2; + inputs2.push_back(get_shape(inputs[0])); + inputs2.push_back(inputs[1]); + inputs2.push_back(get_shape(inputs[2])); + auto maker = CustomGradMaker(backward, inputs.size()); + maker.output_size(1).output_captured(0, false); + maker.backward([inputs = std::move(inputs2), + &indexingSetOneHot](Span grads) { + mgb_assert(grads.size() == 1); + ValueRef grad = grads[0]; + SmallVector ret(3); + if (!grad) { + return ret; + } + if (inputs[0]) { + auto&& grad_op = IndexingSetOneHot::make( + indexingSetOneHot.axis, indexingSetOneHot.ndim); + ValueRefList args_(inputs.size()); + auto&& zeros = make_empty_tensor(grad.device(), inputs[2], 
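Annotation: `indexing_grad_rule` builds the gradient of `IndexingOneHot` by scattering the incoming gradient into a zero tensor through `IndexingSetOneHot`. A NumPy sketch of that idea (conceptual only, not the imperative API):

```python
import numpy as np

def indexing_one_hot_grad(src_shape, index, grad_out, axis):
    # Scatter the incoming gradient into a zero tensor at the indexed
    # positions -- the same job IndexingSetOneHot does in the grad rule.
    grad_src = np.zeros(src_shape, dtype=grad_out.dtype)
    np.put_along_axis(grad_src, np.expand_dims(index, axis),
                      np.expand_dims(grad_out, axis), axis)
    return grad_src

grad = indexing_one_hot_grad((1, 2), np.array([0]),
                             np.array([1.0], dtype=np.float32), axis=-1)
print(grad)  # [[1. 0.]] -- matches test_indexing below
```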
grad.dtype()); + args_[0] = grads[0]; + args_[1] = inputs[1]; + args_[2] = zeros; + ret[0] = imperative::apply(*grad_op, args_)[0]; + } + if (inputs[2]) { + auto&& grad_op = IndexingOneHot::make( + indexingSetOneHot.axis, indexingSetOneHot.ndim); + ValueRefList args_(inputs.size() - 1); + args_[0] = grads[0]; + args_[1] = inputs[1]; + ret[2] = imperative::apply(*grad_op, args_)[0]; + } + return ret; + }); + maker.finalize(); + return imperative::apply(op, inputs); +} + std::optional fastpathcopy_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { @@ -521,6 +599,10 @@ struct Init { CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule); CustomBackward::register_grad_rule( RemoveAxis::typeinfo(), removeAxis_grad_rule); + CustomBackward::register_grad_rule( + IndexingOneHot::typeinfo(), indexing_grad_rule); + CustomBackward::register_grad_rule( + IndexingSetOneHot::typeinfo(), indexing_set_one_hot_grad_rule); CustomBackward::register_grad_rule( FastpathCopy::typeinfo(), fastpathcopy_grad_rule); CustomBackward::register_grad_rule( diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 358aa3a51f8f68034ce4e0437657361fab63ab57..a404b7c4e1faf35ec8de993c0c5e091d0a1d7f13 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -8,11 +8,15 @@ import megengine as mge import megengine.distributed as dist import megengine.functional as F import megengine.module as M +from megengine import Tensor +from megengine.core import _imperative_rt from megengine.core._imperative_rt import CompNode, TensorAttr, imperative from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core.autodiff.grad import Grad +from megengine.core.ops import builtin from megengine.core.ops.builtin import Elemwise, Identity from megengine.functional.distributed import remote_recv, remote_send +from megengine.functional.tensor import ones, zeros def _elwise(mode): @@ -553,3 +557,46 @@ def test_matmul(): if ydim == 1 and transposeB == True: continue test_one(xdim, ydim, transposeA, transposeB) + + +def test_indexing(): + x = np.array([[1.0, 2.0]]).astype("float32") + x = mge.Tensor(x) + index = mge.Tensor([0]) + + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + + def f(x): + return F.indexing_one_hot(x, index, -1) + + y = f(x) + grad(y, F.ones_like(y)) + + np.testing.assert_equal(np.array([[1, 0]], dtype=np.float32), x.grad.numpy()) + + +def test_indexing_set_one_hot(): + x = mge.tensor(np.arange(1, 4, dtype=np.int32)) + + with Grad() as grad: + zeros_tensor = zeros((3, 4), dtype=x.dtype, device=x.device) + ones_tensor = ones((3, 1), dtype=x.dtype, device=x.device) + + grad.wrt(zeros_tensor, callback=save_to(zeros_tensor)) + grad.wrt(ones_tensor, callback=save_to(ones_tensor)) + + def f(x): + op = builtin.IndexingSetOneHot(axis=x.ndim, ndim=x.ndim) + (result,) = apply(op, zeros_tensor, x, ones_tensor) + return result + + y = f(x) + grad(y, F.ones_like(y)) + np.testing.assert_equal( + np.array([[1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]], dtype=np.int32), + zeros_tensor.grad.numpy(), + ) + np.testing.assert_equal( + np.array([[1], [1], [1]], dtype=np.int32), ones_tensor.grad.numpy(), + ) diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py index c09b0c882eefe223c66a9d4f82646c3375587ef6..8a5e9e8efd1356bda9ec6ff4ba48707d1ac2959e 100644 --- 
a/imperative/python/test/unit/core/test_function.py +++ b/imperative/python/test/unit/core/test_function.py @@ -6,9 +6,7 @@ import pytest import megengine.autodiff as ad import megengine.functional as F import megengine.optimizer as optimizer -from megengine import Parameter -from megengine import Tensor as tensor -from megengine import tensor +from megengine import Parameter, Tensor, tensor from megengine.autodiff import Function from megengine.module import Module diff --git a/imperative/python/test/unit/functional/test_loss.py b/imperative/python/test/unit/functional/test_loss.py index 15dde82df5c94d2564696a2fb4080e13d88b4c62..2a1950deb5451d835143fe15ba1539a069317e45 100644 --- a/imperative/python/test/unit/functional/test_loss.py +++ b/imperative/python/test/unit/functional/test_loss.py @@ -3,15 +3,15 @@ import numpy as np import pytest import megengine.functional as F -from megengine import tensor +import megengine.tensor as Tensor def test_cross_entropy_with_logits(): - data = tensor([[0, 50], [0, -150]]).astype(np.float32) - label = tensor([1, 0]).astype(np.int32) + data = Tensor([[0, 50], [0, -150]]).astype(np.float32) + label = Tensor([1, 0]).astype(np.int32) loss = F.nn.cross_entropy(data, label) np.testing.assert_allclose(loss.numpy(), 0.0) - label = tensor([0, 1]).astype(np.int32) + label = Tensor([0, 1]).astype(np.int32) loss = F.nn.cross_entropy(data, label) np.testing.assert_allclose(loss.numpy(), 100) @@ -35,19 +35,24 @@ def test_cross_entropy(): x[i, y[i]] += np.random.rand() * 2 x = softmax(x) l_ref = ref(x, y) - l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False) + l = F.nn.cross_entropy(Tensor(x, "float32"), Tensor(y, "int32"), with_logits=False) np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6) + l1 = F.nn.cross_entropy( + Tensor(x, "float32"), Tensor(y, "int32"), axis=-1, with_logits=False + ) + np.testing.assert_allclose(l1.numpy(), l_ref, 1e-6, 1e-6) + def test_cross_entropy_reduction(): logits = np.random.randn(16, 10) label = np.random.randint(10, size=[16]) - logits = tensor(logits, dtype="float32") - label = tensor(label, dtype="int32") + logits = Tensor(logits, dtype="float32") + label = Tensor(label, dtype="int32") perm = np.random.permutation(16) - logits_perm = tensor(logits[perm], dtype="float32") - label_perm = tensor(label[perm], dtype="int32") + logits_perm = Tensor(logits[perm], dtype="float32") + label_perm = Tensor(label[perm], dtype="int32") loss = F.nn.cross_entropy(logits, label, reduction="none") loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none") @@ -160,18 +165,18 @@ def _ctc_npy_single_seq(pred, label, blank): def test_ctc_loss(): def test_func(T, C, N): input = np.random.randn(T, N, C) - input = F.softmax(tensor(input), axis=-1).numpy() + input = F.softmax(Tensor(input), axis=-1).numpy() input_lengths = np.ones(N, dtype=np.int32) * T target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32) target = np.random.randint( low=1, high=C, size=(sum(target_lengths)), dtype=np.int32 ) - input_mge = tensor(input) - input_lengths_mge = tensor(input_lengths) + input_mge = Tensor(input) + input_lengths_mge = Tensor(input_lengths) - target_mge = tensor(target) - target_lengths_mge = tensor(target_lengths) + target_mge = Tensor(target) + target_lengths_mge = Tensor(target_lengths) blank = np.random.randint(C) for method in ["mean", "sum", "none"]: diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index 
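Annotation: the new assertion in `test_cross_entropy` checks that an explicit `axis=-1` matches the default class axis. The same equivalence, as a stand-alone sketch:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

logits = Tensor(np.random.randn(16, 10).astype("float32"))
label = Tensor(np.random.randint(10, size=[16]).astype("int32"))

# For 2-D logits the class axis is the last one, so the default axis and axis=-1 agree.
loss_default = F.nn.cross_entropy(logits, label)
loss_neg_axis = F.nn.cross_entropy(logits, label, axis=-1)
np.testing.assert_allclose(loss_default.numpy(), loss_neg_axis.numpy(), rtol=1e-6)
```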
a8f1032690295fb5b942731f2e7d450269994b54..d508a5058bacff2cb0e65c8a3b6b5cef3614166b 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -6,7 +6,7 @@ import pytest from utils import opr_test import megengine.functional as F -from megengine import jit, tensor +from megengine import Tensor, jit, tensor from megengine.core._imperative_rt.core2 import apply from megengine.core.ops import builtin @@ -61,37 +61,84 @@ def common_test_reduce(opr, ref_opr): def test_sum(): common_test_reduce(opr=F.sum, ref_opr=np.sum) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.sum(x, axis=-1) + np.testing.assert_equal(y.numpy(), np.array([6, 15]).astype(np.int32)) + def test_prod(): common_test_reduce(opr=F.prod, ref_opr=np.prod) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.prod(x, axis=-2) + np.testing.assert_equal(y.numpy(), np.array([4, 10, 18]).astype(np.int32)) + def test_mean(): common_test_reduce(opr=F.mean, ref_opr=np.mean) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.mean(x, axis=-2) + np.testing.assert_equal(y.numpy(), np.array([2.5, 3.5, 4.5]).astype(np.float32)) + def test_var(): common_test_reduce(opr=F.var, ref_opr=np.var) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.var(x, axis=-2) + np.testing.assert_equal(y.numpy(), np.array([2.25, 2.25, 2.25]).astype(np.float32)) + def test_std(): common_test_reduce(opr=F.std, ref_opr=np.std) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.std(x, axis=-2) + np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32)) + + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.std(x, axis=-2) + np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32)) + def test_min(): common_test_reduce(opr=F.min, ref_opr=np.min) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.min(x, axis=-1) + np.testing.assert_equal(y.numpy(), np.array([1, 4]).astype(np.int32)) + def test_max(): common_test_reduce(opr=F.max, ref_opr=np.max) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.max(x, axis=-1) + np.testing.assert_equal(y.numpy(), np.array([3, 6]).astype(np.int32)) + def test_argmin(): common_test_reduce(opr=F.argmin, ref_opr=np.argmin) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.argmin(x, axis=-1) + np.testing.assert_equal(y.numpy(), np.array([0, 0]).astype(np.int32)) + def test_argmax(): common_test_reduce(opr=F.argmax, ref_opr=np.argmax) + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.argmax(x, axis=-2) + np.testing.assert_equal(y.numpy(), np.array([1, 1, 1]).astype(np.int32)) + + +def test_norm(): + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.norm(x, axis=-1) + np.testing.assert_equal( + y.numpy().round(decimals=3), np.array([3.742, 8.775]).astype(np.float32) + ) def test_sqrt(): @@ -136,7 +183,7 @@ def test_sort_empty(is_symbolic): fn_ = fn data = np.random.random(shape).astype(np.float32) for _ in range(3): - outs = fn_(tensor(data)) + outs = fn_(Tensor(data)) ref_outs = (np.sort(data), np.argsort(data)) assert len(ref_outs) == len(outs) for i in range(len(outs)): @@ -146,6 +193,12 @@ def test_sort_empty(is_symbolic): def test_normalize(): + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.normalize(x, axis=-1) + np.testing.assert_equal( + y.numpy().round(decimals=1), + np.array([[0.3, 0.5, 0.8], [0.5, 0.6, 
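Annotation: the reduction tests above all exercise negative axes; the rule is simply `axis = axis + ndim` when `axis < 0`. A compact equivalence check:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))

# For a 2-D tensor, axis=-1 is axis=1 and axis=-2 is axis=0.
np.testing.assert_equal(F.sum(x, axis=-1).numpy(), F.sum(x, axis=1).numpy())
np.testing.assert_equal(F.max(x, axis=-2).numpy(), F.max(x, axis=0).numpy())
np.testing.assert_equal(F.argmin(x, axis=-1).numpy(), F.argmin(x, axis=1).numpy())
```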
0.7]]).astype(np.float32), + ) cases = [ {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2) @@ -177,11 +230,11 @@ def test_sum_neg_axis(): shape = (2, 3) data = np.random.random(shape).astype(np.float32) for axis in (-1, -2, (-2, 1), (-1, 0)): - get = F.sum(tensor(data), axis=axis) + get = F.sum(Tensor(data), axis=axis) ref = np.sum(data, axis=axis) np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) with pytest.raises(AssertionError): - F.sum(tensor(data), axis=(-1, 1)) + F.sum(Tensor(data), axis=(-1, 1)) def test_builtin_reduce(): @@ -204,18 +257,18 @@ def test_non_finite(): data = [] for i in range(2): data.append(np.random.random(shape).astype(np.float32)) - tensorList = [tensor(x) for x in data] + tensorList = [Tensor(x) for x in data] rst = F.math._check_non_finite(tensorList, 0.7) np.testing.assert_equal(rst.numpy(), [0]) for i in range(len(tensorList)): np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6) data[1][0][0][0][0] = float("inf") - rst = F.math._check_non_finite([tensor(x) for x in data], 0.7) + rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7) np.testing.assert_equal(rst.numpy(), [1]) data[1][0][0][0][0] = float("nan") - rst = F.math._check_non_finite([tensor(x) for x in data], 0.7) + rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7) np.testing.assert_equal(rst.numpy(), [1]) @@ -237,7 +290,7 @@ def test_topk(descending, sorted, inp1d, kth_only): return np.sort(x) res = F.topk( - tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only + Tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only ) values, indices = res @@ -268,7 +321,7 @@ def test_reduce_on_empty_tensor(is_trace): if is_trace: fn = jit.trace(symbolic=symbolic)(fn) for i in range(3): - out = fn(tensor(input, dtype=dtype), axis=axis).numpy() + out = fn(Tensor(input, dtype=dtype), axis=axis).numpy() out_ref = ref_fn(input.astype(dtype), axis=axis) np.testing.assert_equal(out, out_ref) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 26b807a8e4a4f6b62fd3a6732225c706c13b6761..69a2ed4302e13e96552964a9989dc1fd1069cba7 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -7,7 +7,7 @@ import pytest from utils import get_var_value, make_tensor, opr_test import megengine.functional as F -from megengine import tensor +from megengine import Tensor from megengine.core._trace_option import use_symbolic_shape from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor.utils import astensor1d @@ -30,7 +30,7 @@ def test_eye(): np.eye(*case["input"]).astype(dtype), ) np.testing.assert_allclose( - F.eye(tensor(case["input"]), dtype=dtype).numpy(), + F.eye(Tensor(case["input"]), dtype=dtype).numpy(), np.eye(*case["input"]).astype(dtype), ) @@ -60,7 +60,21 @@ def test_full(): values = [True, 4, 5.0] for value in values: np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value)) - assert F.full(shape, value).dtype == tensor(value).dtype + assert F.full(shape, value).dtype == Tensor(value).dtype + + +@pytest.mark.parametrize("is_varnode", [True, False]) +def test_cumsum(is_varnode): + if is_varnode: + network = Network() + else: + network = None + + x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32) + y = F.cumsum(x, -1) + np.testing.assert_equal( + y.numpy(), np.array([[1, 3, 6], [4, 9, 15]]).astype(np.int32) + ) 
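Annotation: `test_cumsum` relies on the removal of the `axis >= 0` assertion in `functional/tensor.py`; a negative axis is now forwarded to the `Cumsum` op directly. Equivalence sketch:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
# With the assertion gone, axis=-1 gives the same result as axis=1.
np.testing.assert_equal(F.cumsum(x, -1).numpy(), F.cumsum(x, 1).numpy())
np.testing.assert_equal(
    F.cumsum(x, -1).numpy(), np.array([[1, 3, 6], [4, 9, 15]], dtype=np.int32)
)
```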
@pytest.mark.parametrize("is_varnode", [True, False]) @@ -83,6 +97,14 @@ def test_concat(is_varnode): cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) + x1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3))) + x2 = Tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3))) + y = F.concat([x1, x2], axis=-1) + np.testing.assert_equal( + y.numpy(), + np.array([[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]).astype(np.float32), + ) + @pytest.mark.parametrize("is_varnode", [True, False]) def test_condtake(is_varnode): @@ -139,6 +161,20 @@ def test_stack(is_varnode): cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network ) + x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3))) + x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3))) + y = F.stack([x1, x2], axis=-1) + np.testing.assert_equal( + y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32) + ) + + x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3))) + x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3))) + y = F.stack([x1, x2], axis=-1) + np.testing.assert_equal( + y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32) + ) + @pytest.mark.parametrize("is_varnode", [True, False]) def test_split_basic(is_varnode): @@ -183,6 +219,12 @@ def test_split_basic(is_varnode): @pytest.mark.parametrize("symbolic", [None, False, True]) def test_split(symbolic): + x = Tensor(np.random.random((10, 20)), dtype=np.float32) + y = F.split(x, 3, axis=-1) + z = F.split(x, [6, 17], axis=-1) + assert str([i.numpy().shape for i in y]) == "[(10, 7), (10, 7), (10, 6)]" + assert str([i.numpy().shape for i in z]) == "[(10, 6), (10, 11), (10, 3)]" + inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32) inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32) @@ -208,12 +250,43 @@ def test_split(symbolic): fn = trace(symbolic=symbolic)(func) for i in range(3 if symbolic is not None else 1): ref_out = ref(*case) - out = fn(tensor(case[0]), case[1], case[2]) + out = fn(Tensor(case[0]), case[1], case[2]) assert len(ref_out) == len(out) for idx in range(len(ref_out)): np.testing.assert_equal(ref_out[idx], out[idx].numpy()) +def test_gather(): + x = Tensor([[1, 2], [3, 4], [5, 6],]) + index = Tensor([[0, 1], [1, 0], [1, 1]]) + y = F.gather(x, 1, index) + np.testing.assert_equal( + y.numpy(), np.array([[1, 2], [4, 3], [6, 6]]).astype(np.int32) + ) + + +def test_scatter(): + x = Tensor(np.zeros(shape=(3, 5), dtype=np.float32)) + source = Tensor( + [ + [0.9935, 0.9465, 0.2256, 0.8926, 0.4396], + [0.7723, 0.0718, 0.5939, 0.357, 0.4576], + ] + ) + index = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]]) + y = F.scatter(x, -2, index, source) + np.testing.assert_equal( + y.numpy().round(decimals=4), + np.array( + [ + [0.9935, 0.0718, 0.2256, 0.0, 0.0], + [0.0, 0.0, 0.5939, 0.357, 0.4396], + [0.7723, 0.9465, 0.0, 0.8926, 0.4576], + ] + ).astype(np.float32), + ) + + @pytest.mark.parametrize("is_varnode", [True, False]) def test_swapaxes(is_varnode): if is_varnode: @@ -221,7 +294,7 @@ def test_swapaxes(is_varnode): else: network = None - x = tensor(np.array([[1, 2, 3]], dtype=np.int32)) + x = Tensor(np.array([[1, 2, 3]], dtype=np.int32)) y = F.swapaxes(x, 0, 1) np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32)) @@ -280,15 +353,15 @@ def test_broadcast_auto_infer(is_varnode): def test_reshape_on_empty_tensor(is_trace): input1_shape = (100, 0, 1) output1_shape = (100, 0, 10) - data1 = 
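Annotation: `test_gather` above still uses a positive axis; with the bounds check removed and `_get_idx` now normalizing negative axes, a negative axis should behave the same way. A sketch under that assumption:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor([[1, 2], [3, 4], [5, 6]])
index = Tensor([[0, 1], [1, 0], [1, 1]])
# Assumed: axis=-1 is normalized to axis=1 inside _get_idx.
y = F.gather(x, -1, index)
np.testing.assert_equal(y.numpy(), np.array([[1, 2], [4, 3], [6, 6]], dtype=np.int32))
```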
tensor(np.random.random(input1_shape).astype(np.float32)) + data1 = Tensor(np.random.random(input1_shape).astype(np.float32)) input2_shape = (10, 0) output2_shape = (0,) - data2 = tensor(np.random.random(input2_shape).astype(np.float32)) + data2 = Tensor(np.random.random(input2_shape).astype(np.float32)) input3_shape = (10, 0, 10) output3_shape = (0, 1, 2, 3) - data3 = tensor(np.random.random(input3_shape).astype(np.float32)) + data3 = Tensor(np.random.random(input3_shape).astype(np.float32)) def comp(out, target_shp): assert out._tuple_shape == target_shp @@ -338,7 +411,7 @@ def test_reshape_shape_inference(is_varnode): def check_shape(output, target): source = output.shape - if isinstance(source, tensor): + if isinstance(source, Tensor): source = source.numpy() np.testing.assert_equal(source, target.shape) @@ -366,6 +439,10 @@ def test_squeeze(is_varnode): else: network = None + x = Tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) + y = F.squeeze(x, -1) + np.testing.assert_equal(y.numpy(), np.array([[[1, 2]]]).astype(np.int32)) + x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) xx = make_tensor(x, network) @@ -385,6 +462,12 @@ def test_expand_dims(is_varnode): else: network = None + x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) + y = F.expand_dims(x, -1) + np.testing.assert_equal( + y.numpy(), np.array([[[1], [2], [3]], [[4], [5], [6]]]).astype(np.int32) + ) + x = np.arange(6, dtype="float32").reshape(2, 3) xx = make_tensor(x, network) @@ -533,6 +616,22 @@ def test_flatten(is_varnode): else: network = None + inp_shape = (2, 2, 3, 3) + x = Tensor(np.arange(36, dtype=np.int32).reshape(inp_shape),) + y = F.flatten(x, -2, -1) + np.testing.assert_equal( + y.numpy(), + np.array( + [ + [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]], + [ + [18, 19, 20, 21, 22, 23, 24, 25, 26], + [27, 28, 29, 30, 31, 32, 33, 34, 35], + ], + ] + ).astype(np.int32), + ) + data0_shape = (2, 3, 4, 5) data1_shape = (4, 5, 6, 7) data0 = np.random.random(data0_shape).astype(np.float32) @@ -616,15 +715,15 @@ def test_broadcast(is_varnode): def test_broadcast_on_empty_tensor(is_trace): input1_shape = (100, 0, 1) output1_shape = (100, 0, 10) - data1 = tensor(np.random.random(input1_shape).astype(np.float32)) + data1 = Tensor(np.random.random(input1_shape).astype(np.float32)) input2_shape = (10, 0) output2_shape = (10, 10, 0) - data2 = tensor(np.random.random(input2_shape).astype(np.float32)) + data2 = Tensor(np.random.random(input2_shape).astype(np.float32)) input3_shape = (0, 0, 1, 10) output3_shape = (10, 0, 0, 10, 10) - data3 = tensor(np.random.random(input3_shape).astype(np.float32)) + data3 = Tensor(np.random.random(input3_shape).astype(np.float32)) def comp(out, target_shp): assert out._tuple_shape == target_shp @@ -705,7 +804,7 @@ def test_utils_astensor1d(is_varnode): def test_device(): - x = tensor([1, 2, 3], dtype="float32") + x = Tensor([1, 2, 3], dtype="float32") y1 = F.eye(x.shape, dtype="float32") y2 = F.eye(x.shape, dtype="float32", device=None) @@ -789,7 +888,7 @@ def test_copy_d2d(is_varnode): ) @pytest.mark.parametrize("is_symbolic", [None, True, False]) def test_copy_empty(shape, device_src, device_dst, is_symbolic): - inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src) + inp = Tensor(np.random.randn(*shape).astype("float32"), device=device_src) def func(inp): return F.copy(inp, device_dst) @@ -885,6 +984,12 @@ def test_roll(shape, shifts, axis, is_varnode): else: network = None + x = Tensor([[1, 2], [3, 4], [5, 6]], np.int32) + y = 
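Annotation: the squeeze/expand_dims/flatten tests above all normalize negative axes against the tensor rank. A round-trip sketch combining them:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
# expand_dims(-1) appends a trailing size-1 dim; squeeze(-1) removes it again.
assert F.expand_dims(x, -1).numpy().shape == (2, 3, 1)
assert F.squeeze(F.expand_dims(x, -1), -1).numpy().shape == (2, 3)
# flatten(-2, -1) merges only the last two dims.
assert F.flatten(Tensor(np.zeros((2, 2, 3, 3), np.int32)), -2, -1).numpy().shape == (2, 2, 9)
```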
F.roll(x, 1, -1) + np.testing.assert_equal( + y.numpy(), np.array([[2, 1], [4, 3], [6, 5]]).astype(np.int32) + ) + inp = np.random.randn(*shape).astype("float32") def func(inp): @@ -904,7 +1009,7 @@ def test_roll(shape, shifts, axis, is_varnode): ) @pytest.mark.parametrize("is_symbolic", [None, True, False]) def test_roll_empty_tensor(shape, shifts, axis, is_symbolic): - inp = tensor(np.random.randn(*shape).astype("float32")) + inp = Tensor(np.random.randn(*shape).astype("float32")) def func(inp): return F.roll(inp, shifts, axis) diff --git a/imperative/src/impl/ops/indexing.cpp b/imperative/src/impl/ops/indexing.cpp index 6d7a6d07ca0eb2ba420cefaa7c13b2ab42927d62..63ecd1ce05cd65d88077bf4cf02d6ade1d39323d 100644 --- a/imperative/src/impl/ops/indexing.cpp +++ b/imperative/src/impl/ops/indexing.cpp @@ -1,8 +1,10 @@ +#include "../dnn_op_helper.h" #include "megbrain/imperative/ops/autogen.h" #include "../op_trait.h" #include "megbrain/opr/indexing.h" +#include "megdnn/oprs/general.h" namespace mgb { namespace imperative { @@ -12,10 +14,8 @@ namespace indexing_one_hot { std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& input_descs) { - auto& op = def.cast_final_safe(); - + auto&& op = def.cast_final_safe(); mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs"); - auto comp_node = input_descs[0].comp_node; TensorLayout src = input_descs[0].layout, index = input_descs[1].layout; @@ -28,10 +28,15 @@ std::tuple, bool> infer_output_attrs_fallible( mgb_assert(src.ndim >= 2, "src ndim must be at least 2"); mgb_assert(src.is_contiguous(), "src should be contiguous"); mgb_assert( - op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src", op.axis); - + -static_cast(src.ndim) <= op.axis && + op.axis < static_cast(src.ndim), + "axis %d not exists in src", op.axis); + int real_axis = static_cast(op.axis); + if (real_axis < 0) { + real_axis += static_cast(src.ndim); + } TensorLayout dst = src; - dst.shape[op.axis] = 1; + dst.shape[real_axis] = 1; dst.init_contiguous_stride(); if (!index.ndim) { @@ -40,24 +45,128 @@ std::tuple, bool> infer_output_attrs_fallible( mgb_assert(index.is_contiguous(), "index should be all contiguous"); mgb_assert( - index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src"); - + index.eq_shape(src.remove_axis(real_axis)), + "index shape doesn't match src"); return {{{dst, comp_node}}, true}; } auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& op = static_cast(def); + auto&& op = def.cast_final_safe(); mgb_assert(inputs.size() == 2); + int real_axis = static_cast(op.axis); + if (real_axis < 0) { + real_axis += static_cast(op.ndim); + } OperatorNodeConfig config{op.make_name()}; - return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); + return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config); +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, SmallVector inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op = def.cast_final_safe(); + auto&& inp = inputs[0]; + auto&& index = inputs[1]; + TensorLayout layout = inp->layout(); + TensorLayout index_layout = index->layout(); + DnnOprCaller dnn_op(inp->comp_node()); + auto&& indexing_one_hot_param = dnn_op.op->param(); + int real_axis = static_cast(op.axis); + if (real_axis < 0) { + real_axis += static_cast(layout.ndim); + } + mgb_assert( + 0 <= real_axis && real_axis < static_cast(layout.ndim), + "Dimension out of range (expected to be in range of [%d, %d], but got 
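Annotation: the C++ changes above accept any `axis` in `[-ndim, ndim)` and fold a negative value into `axis + ndim`. The same rule in a few lines of Python, for reference only:

```python
def normalize_axis(axis: int, ndim: int) -> int:
    # Mirrors the new check: -ndim <= axis < ndim, negative values wrap around.
    assert -ndim <= axis < ndim, "axis {} does not exist in src".format(axis)
    return axis + ndim if axis < 0 else axis

assert normalize_axis(-1, 2) == 1
assert normalize_axis(1, 2) == 1
assert normalize_axis(0, 2) == 0
```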
%d)", + 0, static_cast(layout.ndim) - 1, op.axis); + indexing_one_hot_param = real_axis; + TensorLayout tlayout; + dnn_op.op->deduce_layout(layout, index_layout, tlayout); + TensorPtr out = Tensor::make(tlayout, inp->comp_node()); + megdnn::TensorND in = inp->dnn_tensor(); + megdnn::TensorND ind = index->dnn_tensor(); + TensorLayout m_layout( + {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)}, + dtype::Byte()); + auto dnn_workspace = dnn_op.create_workspace(m_layout); + dnn_op.op->exec(in, ind, out->dnn_tensor(), dnn_workspace); + return {out}; } OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) .infer_output_attrs_fallible(infer_output_attrs_fallible) .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) .fallback(); - } // namespace indexing_one_hot + +namespace indexing_set_one_hot { + +std::tuple, bool> infer_output_attrs_fallible( + const OpDef& def, const SmallVector& input_descs) { + mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs"); + auto comp_node = input_descs[0].comp_node; + TensorLayout src = input_descs[0].layout, index = input_descs[1].layout; + + mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32"); + + if (!src.ndim) { + return {{{{{}, src.dtype}, comp_node}}, false}; + } + mgb_assert(src.is_contiguous(), "src should be contiguous"); + return {{input_descs[0]}, true}; +} + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 3); + int real_axis = static_cast(op.axis); + if (real_axis < 0) { + real_axis += static_cast(op.ndim); + } + OperatorNodeConfig config{op.make_name()}; + return opr::IndexingSetOneHot::make( + inputs[0], inputs[1], inputs[2], real_axis, config); +} + +SmallVector apply_on_physical_tensor( + const OpDef& def, SmallVector inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op = def.cast_final_safe(); + auto&& inp = inputs[0]; + auto&& index = inputs[1]; + auto&& sub = inputs[2]; + TensorLayout layout = inp->layout(); + TensorLayout index_layout = index->layout(); + TensorLayout tlayout = sub->layout(); + mgb_assert(layout.is_contiguous()); + DnnOprCaller dnn_op(inp->comp_node()); + auto&& indexing_one_hot_param = dnn_op.op->param(); + int real_axis = static_cast(op.axis); + if (real_axis < 0) { + real_axis += static_cast(layout.ndim); + } + indexing_one_hot_param = real_axis; + TensorPtr out = Tensor::make(layout, inp->comp_node()); + out->dev_tensor().copy_from_fixlayout(inp->dev_tensor()); + megdnn::TensorND in = inp->dnn_tensor(); + megdnn::TensorND ind = index->dnn_tensor(); + megdnn::TensorND su = sub->dnn_tensor(); + TensorLayout m_layout( + {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)}, + dtype::Byte()); + auto dnn_workspace = dnn_op.create_workspace(m_layout); + dnn_op.op->exec(out->dnn_tensor(), ind, su, dnn_workspace); + return {out}; +} + +OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); +} // namespace indexing_set_one_hot + } // anonymous namespace } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 1708a0887e3dbd75dd61fcf301aa9010736973c0..6eb935e9d1f7673661d742a0a96bed5283986031 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ 
b/imperative/src/impl/ops/specializations.cpp @@ -372,21 +372,6 @@ OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallba } // namespace group_local } // namespace -namespace { -namespace indexing_set_one_hot { -auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& op = static_cast(def); - mgb_assert(inputs.size() == 3); - OperatorNodeConfig config{op.make_name()}; - return opr::IndexingSetOneHot::make( - inputs[0], inputs[1], inputs[2], op.param(), config); -} -OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) - .apply_on_var_node(apply_on_var_node) - .fallback(); -} // namespace indexing_set_one_hot -} // namespace - namespace { namespace typecvt { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 4d8dff62304abeca8436fb437fb840eb1c89acd0..d70cd20b896a696fdf59b7e40ed37829ce14afdb 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -108,9 +108,17 @@ def Remap: MgbHashableOp<"Remap", [RemapParam]>; def Resize: MgbHashableOp<"Resize", [ResizeParam]>; -def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>; +def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]> { + let extraArguments = (ins + MgbI32Attr:$ndim + ); +} -def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>; +def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]> { + let extraArguments = (ins + MgbI32Attr:$ndim + ); +} def Copy: MgbHashableOp<"Copy"> { let extraArguments = (ins
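Annotation: with the `ndim` extra argument added in `ops.td`, the generated Python builtins carry both attributes, so a negative axis can also be used when applying the op directly. A minimal sketch:

```python
import numpy as np
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin

src = Tensor(np.array([[1.0, 2.0]], dtype=np.float32))
index = Tensor(np.array([0], dtype=np.int32))

# axis=-1 plus ndim=2 resolves to the last axis; the op keeps that axis as size 1.
op = builtin.IndexingOneHot(axis=-1, ndim=src.ndim)
(out,) = apply(op, src, index)
np.testing.assert_equal(out.numpy(), np.array([[1.0]], dtype=np.float32))
```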