diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py
index dad883c1070d0ea48a6dcd1324e6def9a109dc47..6f483eb6934d2573bc20511cb197a7ffa35f5c79 100644
--- a/imperative/python/test/unit/core/test_indexing_op.py
+++ b/imperative/python/test/unit/core/test_indexing_op.py
@@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile
 
 import numpy as np
 import pytest
+import torch
 from utils import make_tensor
 
 import megengine
@@ -13,6 +14,7 @@ import megengine.functional as F
 import megengine.jit as jit
 import megengine.random as rand
 import megengine.utils.comp_graph_tools as cgtools
+from megengine.autodiff import GradManager
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.ops import builtin
@@ -622,6 +624,25 @@ def test_advance_indexing_with_bool(test_varnode):
     )
 
 
+def test_advance_indexing_autodiff():
+    x = Tensor([2, 2, 3, 4, 5, 6, 7, 8, 2], dtype="float32")
+    gm = GradManager()
+    gm.attach(x)
+    with gm:
+        a = x + 1
+        a[x > 3] = 0.3
+        b = a + 1
+        gm.backward(b.sum())
+    torch_x = torch.tensor(
+        [2, 2, 3, 4, 5, 6, 7, 8, 2], dtype=torch.float32, requires_grad=True
+    )
+    a = torch_x + 1
+    a[torch_x > 3] = 0.3
+    b = a + 1
+    (b.sum()).backward()
+    np.testing.assert_equal(x.grad.numpy(), torch_x.grad.numpy())
+
+
 @pytest.mark.parametrize("symbolic", [True, False, None])
 def test_subtensor_on_empty_tensor(symbolic):
     np_x = np.array([], dtype=np.float32).reshape(10, 0, 10)
diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp
index 67adb0c6e08e79494995bafb3530800cad9b91bc..4b486f8d530d30844b12c9874f1894e69c014383 100644
--- a/imperative/src/impl/ops/tensor_manip.cpp
+++ b/imperative/src/impl/ops/tensor_manip.cpp
@@ -306,7 +306,7 @@ cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& in
 SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
-    layout_checker[0] = [](const TensorLayout& layout) {
+    layout_checker[1] = [](const TensorLayout& layout) {
         return layout.is_contiguous();
     };
     return layout_checker;