diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl index a640d865e32c8ba5f99819d178074e707f9622f7..de8ec0338737c4c10fa924f0e39510f06a223cf1 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl @@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { #define cb0(_dtype) \ MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb0) + cb0(::megdnn::dtype::Bool) #undef cb0 #undef INST diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu index 74874f23d7559ea8514056011c9f1f49d4f584af..cfefef12ea05406976f1a90b7864275dacc16f64 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu @@ -39,6 +39,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { ((int*)0)[0] = 1; } +__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) { + __trap(); + ((int*)0)[0] = 1; +} + #define KERN_APPLY_OPR_OPR \ ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr #include "./kern_apply_opr_impl.cuinl" diff --git a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp index 4e86490510687ce6379d038976ba9a38ab756ffc..85b8346813bcc4bf3ab511c12c3013e37bb880af 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp @@ -120,6 +120,7 @@ void ExecImpl::dispatch_exec() { case DTypeTrait<_dtype>::enumv: \ return dispatch_exec_ctype::ctype>(); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_throw("bad dtype"); diff --git a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp index 52de1335e80c853ceaa4c7d5495e94d31469e55a..16ca74b0cd99d108e2b1148baa207ec2e98ed134 100644 --- a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp @@ -88,6 +88,7 @@ void dispatch_exec(HandleImpl *handle, } switch (data.layout.dtype.enumv()) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) default: megdnn_throw(megdnn_mangle("bad dtype")); } diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index e369ea08cb6f78c5b895b8bb144ee971f7ce3ae7..80478d63a7acdcb5e7acb29d1463d62cbd1aa9d6 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -519,6 +519,18 @@ def test_advance_indexing_with_bool(): np.testing.assert_equal(a[b], aa[bb].numpy()) np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy()) + a = np.array([[True, False], [False, True]]) + b = np.array([1]) + aa = Tensor(a) + bb = Tensor(b) + np.testing.assert_equal(a[b], aa[bb].numpy()) + b = np.array([[True, True], [False, True]]) + bb = Tensor(b) + np.testing.assert_equal(a[b], aa[bb].numpy()) + a[b] = False + aa[bb] = False + np.testing.assert_equal(a, aa.numpy()) + # XXX: trace does not expect empty condtake tensor if not use_tensor_shape(): a = np.ones((2, 2), dtype=np.int32)