提交 912d733e 编写于 作者: M Megvii Engine Team

fix(dnn): support bool for IndexingMultiAxisVec

GitOrigin-RevId: ddcfaa06b0a1eefb1964e28ebc0b8022e482eb5e
上级 dacc4854
......@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec {
#define cb0(_dtype) \
MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb0)
cb0(::megdnn::dtype::Bool)
#undef cb0
#undef INST
......
......@@ -39,6 +39,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) {
((int*)0)[0] = 1;
}
__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) {
__trap();
((int*)0)[0] = 1;
}
#define KERN_APPLY_OPR_OPR \
::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr
#include "./kern_apply_opr_impl.cuinl"
......
......@@ -120,6 +120,7 @@ void ExecImpl<Opr>::dispatch_exec() {
case DTypeTrait<_dtype>::enumv: \
return dispatch_exec_ctype<DTypeTrait<_dtype>::ctype>();
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_throw("bad dtype");
......
......@@ -88,6 +88,7 @@ void dispatch_exec(HandleImpl *handle,
}
switch (data.layout.dtype.enumv()) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
default:
megdnn_throw(megdnn_mangle("bad dtype"));
}
......
......@@ -519,6 +519,18 @@ def test_advance_indexing_with_bool():
np.testing.assert_equal(a[b], aa[bb].numpy())
np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy())
a = np.array([[True, False], [False, True]])
b = np.array([1])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
b = np.array([[True, True], [False, True]])
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
a[b] = False
aa[bb] = False
np.testing.assert_equal(a, aa.numpy())
# XXX: trace does not expect empty condtake tensor
if not use_tensor_shape():
a = np.ones((2, 2), dtype=np.int32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册