提交 92f7cceb 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix cond_take when index is numpy array

GitOrigin-RevId: 5fb93740f49ff1b6283ca8d0e30e5d417d66717d
上级 f4927db2
...@@ -180,7 +180,7 @@ def try_condtake(tensor, index): ...@@ -180,7 +180,7 @@ def try_condtake(tensor, index):
if index.dtype != np.bool_ or index.shape != tensor.shape: if index.dtype != np.bool_ or index.shape != tensor.shape:
return [] return []
if isinstance(index, np.ndarray): if isinstance(index, np.ndarray):
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (TensorBase, TensorWrapperBase)) assert isinstance(index, (TensorBase, TensorWrapperBase))
if not isinstance(tensor, (TensorWrapperBase, TensorBase)): if not isinstance(tensor, (TensorWrapperBase, TensorBase)):
raise TypeError("input must be a tensor") raise TypeError("input must be a tensor")
......
...@@ -522,6 +522,7 @@ def test_advance_indexing_with_bool(): ...@@ -522,6 +522,7 @@ def test_advance_indexing_with_bool():
b = np.array([[False, False], [False, False]]) b = np.array([[False, False], [False, False]])
aa = Tensor(a) aa = Tensor(a)
bb = Tensor(b) bb = Tensor(b)
np.testing.assert_equal(a[b], aa[b].numpy())
np.testing.assert_equal(a[b], aa[bb].numpy()) np.testing.assert_equal(a[b], aa[bb].numpy())
b = np.array([False, False]) b = np.array([False, False])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册