diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index 2c612c9b4899c774d50423f76e2f83c577dccee8..8da5a66d515279540aa63712d773f65fbe22b0a0 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -180,7 +180,7 @@ def try_condtake(tensor, index): if index.dtype != np.bool_ or index.shape != tensor.shape: return [] if isinstance(index, np.ndarray): - (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) + (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) assert isinstance(index, (TensorBase, TensorWrapperBase)) if not isinstance(tensor, (TensorWrapperBase, TensorBase)): raise TypeError("input must be a tensor") diff --git a/imperative/python/test/unit/test_indexing_op.py b/imperative/python/test/unit/test_indexing_op.py index 70b2911f046883eca5d2fbd96b44b1191034ea1f..213819daedde65038cad6f3320326f49e7ecd577 100644 --- a/imperative/python/test/unit/test_indexing_op.py +++ b/imperative/python/test/unit/test_indexing_op.py @@ -522,6 +522,7 @@ def test_advance_indexing_with_bool(): b = np.array([[False, False], [False, False]]) aa = Tensor(a) bb = Tensor(b) + np.testing.assert_equal(a[b], aa[b].numpy()) np.testing.assert_equal(a[b], aa[bb].numpy()) b = np.array([False, False])