From 912d733ea9fcb7d4ac38c950a1bd22e340432b83 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 10 Oct 2020 12:46:37 +0800 Subject: [PATCH] fix(dnn): support bool for IndexingMultiAxisVec GitOrigin-RevId: ddcfaa06b0a1eefb1964e28ebc0b8022e482eb5e --- .../kern_apply_opr_impl.cuinl | 1 + .../indexing_multi_axis_vec/kern_apply_opr_incr.cu | 5 +++++ dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp | 1 + dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp | 1 + imperative/python/test/unit/core/test_indexing_op.py | 12 ++++++++++++ 5 files changed, 20 insertions(+) diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl index a640d865e..de8ec0338 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl @@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { #define cb0(_dtype) \ MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb0) + cb0(::megdnn::dtype::Bool) #undef cb0 #undef INST diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu index 74874f23d..cfefef12e 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu @@ -39,6 +39,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { ((int*)0)[0] = 1; } +__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) { + __trap(); + ((int*)0)[0] = 1; +} + #define KERN_APPLY_OPR_OPR \ ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr #include "./kern_apply_opr_impl.cuinl" diff --git a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp index 4e8649051..85b834681 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp @@ -120,6 +120,7 @@ void ExecImpl::dispatch_exec() { case DTypeTrait<_dtype>::enumv: \ return dispatch_exec_ctype::ctype>(); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_throw("bad dtype"); diff --git a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp index 52de1335e..16ca74b0c 100644 --- a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp @@ -88,6 +88,7 @@ void dispatch_exec(HandleImpl *handle, } switch (data.layout.dtype.enumv()) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) default: megdnn_throw(megdnn_mangle("bad dtype")); } diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index e369ea08c..80478d63a 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -519,6 +519,18 @@ def test_advance_indexing_with_bool(): np.testing.assert_equal(a[b], aa[bb].numpy()) np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy()) + a = np.array([[True, False], [False, True]]) + b = np.array([1]) + aa = Tensor(a) + bb = Tensor(b) + np.testing.assert_equal(a[b], aa[bb].numpy()) + b = np.array([[True, True], [False, True]]) + bb = Tensor(b) + np.testing.assert_equal(a[b], aa[bb].numpy()) + a[b] = False + aa[bb] = False + np.testing.assert_equal(a, aa.numpy()) + # XXX: trace does not expect empty condtake tensor if not use_tensor_shape(): a = np.ones((2, 2), dtype=np.int32) -- GitLab