提交 912d733e 编写于 作者: M Megvii Engine Team

fix(dnn): support bool for IndexingMultiAxisVec

GitOrigin-RevId: ddcfaa06b0a1eefb1964e28ebc0b8022e482eb5e
上级 dacc4854
...@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { ...@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec {
#define cb0(_dtype) \ #define cb0(_dtype) \
MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb0) MEGDNN_FOREACH_COMPUTING_DTYPE(cb0)
cb0(::megdnn::dtype::Bool)
#undef cb0 #undef cb0
#undef INST #undef INST
......
...@@ -39,6 +39,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { ...@@ -39,6 +39,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) {
((int*)0)[0] = 1; ((int*)0)[0] = 1;
} }
// atomicAdd overload for dt_bool: atomic accumulation on booleans is not
// supported, so this stub aborts any kernel that reaches it. Mirrors the
// dt_int16 stub defined just above in this file.
__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) {
// abort the kernel; the failure surfaces at the next synchronizing CUDA call
__trap();
// null-pointer store after __trap() — unreachable; presumably kept to match
// the sibling unsupported-dtype stubs (see the dt_int16 overload above) and
// to guarantee a fault even if __trap() were compiled out — TODO confirm
((int*)0)[0] = 1;
}
#define KERN_APPLY_OPR_OPR \ #define KERN_APPLY_OPR_OPR \
::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr
#include "./kern_apply_opr_impl.cuinl" #include "./kern_apply_opr_impl.cuinl"
......
...@@ -120,6 +120,7 @@ void ExecImpl<Opr>::dispatch_exec() { ...@@ -120,6 +120,7 @@ void ExecImpl<Opr>::dispatch_exec() {
case DTypeTrait<_dtype>::enumv: \ case DTypeTrait<_dtype>::enumv: \
return dispatch_exec_ctype<DTypeTrait<_dtype>::ctype>(); return dispatch_exec_ctype<DTypeTrait<_dtype>::ctype>();
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb #undef cb
default: default:
megdnn_throw("bad dtype"); megdnn_throw("bad dtype");
......
...@@ -88,6 +88,7 @@ void dispatch_exec(HandleImpl *handle, ...@@ -88,6 +88,7 @@ void dispatch_exec(HandleImpl *handle,
} }
switch (data.layout.dtype.enumv()) { switch (data.layout.dtype.enumv()) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
default: default:
megdnn_throw(megdnn_mangle("bad dtype")); megdnn_throw(megdnn_mangle("bad dtype"));
} }
......
...@@ -519,6 +519,18 @@ def test_advance_indexing_with_bool(): ...@@ -519,6 +519,18 @@ def test_advance_indexing_with_bool():
np.testing.assert_equal(a[b], aa[bb].numpy()) np.testing.assert_equal(a[b], aa[bb].numpy())
np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy()) np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy())
a = np.array([[True, False], [False, True]])
b = np.array([1])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
b = np.array([[True, True], [False, True]])
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
a[b] = False
aa[bb] = False
np.testing.assert_equal(a, aa.numpy())
# XXX: trace does not expect empty condtake tensor # XXX: trace does not expect empty condtake tensor
if not use_tensor_shape(): if not use_tensor_shape():
a = np.ones((2, 2), dtype=np.int32) a = np.ones((2, 2), dtype=np.int32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册