提交 92f7cceb 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix cond_take when index is numpy array

GitOrigin-RevId: 5fb93740f49ff1b6283ca8d0e30e5d417d66717d
上级 f4927db2
......@@ -180,7 +180,7 @@ def try_condtake(tensor, index):
if index.dtype != np.bool_ or index.shape != tensor.shape:
return []
if isinstance(index, np.ndarray):
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (TensorBase, TensorWrapperBase))
if not isinstance(tensor, (TensorWrapperBase, TensorBase)):
raise TypeError("input must be a tensor")
......
......@@ -522,6 +522,7 @@ def test_advance_indexing_with_bool():
b = np.array([[False, False], [False, False]])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[b].numpy())
np.testing.assert_equal(a[b], aa[bb].numpy())
b = np.array([False, False])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册