From 1f0cc891b0c3d7270d18883e75514bdd3529ecd7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 3 Nov 2021 13:52:44 +0800 Subject: [PATCH] feat(dnn): enable eye to support bool GitOrigin-RevId: 76d874d5b7b32efbaf4254b49b43e6ad5ef1213d --- dnn/src/cuda/eye/eye.cu | 1 + dnn/src/cuda/eye/opr_impl.cpp | 1 + dnn/src/naive/eye/opr_impl.cpp | 1 + dnn/src/rocm/eye/eye.cpp.hip | 2 +- dnn/src/rocm/eye/opr_impl.cpp | 1 + .../test/unit/functional/test_tensor.py | 29 ++++++++++--------- 6 files changed, 20 insertions(+), 15 deletions(-) diff --git a/dnn/src/cuda/eye/eye.cu b/dnn/src/cuda/eye/eye.cu index 706671434..75dcb60d0 100644 --- a/dnn/src/cuda/eye/eye.cu +++ b/dnn/src/cuda/eye/eye.cu @@ -39,6 +39,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) { #define INST(T) template void exec_internal(T*, size_t, size_t, int, cudaStream_t); #define cb(DType) INST(typename DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +cb(::megdnn::dtype::Bool) } // namespace eye } // namespace cuda diff --git a/dnn/src/cuda/eye/opr_impl.cpp b/dnn/src/cuda/eye/opr_impl.cpp index fd23b0ee9..a50ab28cb 100644 --- a/dnn/src/cuda/eye/opr_impl.cpp +++ b/dnn/src/cuda/eye/opr_impl.cpp @@ -26,6 +26,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { cuda_stream(handle())); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb } diff --git a/dnn/src/naive/eye/opr_impl.cpp b/dnn/src/naive/eye/opr_impl.cpp index 3e6dd1a9b..460409b68 100644 --- a/dnn/src/naive/eye/opr_impl.cpp +++ b/dnn/src/naive/eye/opr_impl.cpp @@ -31,6 +31,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(dst.ptr(), m, n)); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb } diff --git a/dnn/src/rocm/eye/eye.cpp.hip b/dnn/src/rocm/eye/eye.cpp.hip index dc569af66..4aa01992f 100644 --- a/dnn/src/rocm/eye/eye.cpp.hip +++ b/dnn/src/rocm/eye/eye.cpp.hip @@ -44,7 +44,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) { template void exec_internal(T*, size_t, size_t, int, hipStream_t); #define cb(DType) INST(typename DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) - +cb(::megdnn::dtype::Bool) } // namespace eye } // namespace rocm } // namespace megdnn diff --git a/dnn/src/rocm/eye/opr_impl.cpp b/dnn/src/rocm/eye/opr_impl.cpp index 6ea49197f..a6c4dbb1a 100644 --- a/dnn/src/rocm/eye/opr_impl.cpp +++ b/dnn/src/rocm/eye/opr_impl.cpp @@ -27,6 +27,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { hip_stream(handle())); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb } diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 3f014325f..d45944f40 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode def test_eye(): - dtype = np.float32 + dtypes = [np.float32, np.bool] cases = [{"input": [10, 20]}, {"input": [30]}] - for case in cases: - np.testing.assert_allclose( - F.eye(case["input"], dtype=dtype).numpy(), - np.eye(*case["input"]).astype(dtype), - ) - np.testing.assert_allclose( - F.eye(*case["input"], dtype=dtype).numpy(), - np.eye(*case["input"]).astype(dtype), - ) - np.testing.assert_allclose( - F.eye(tensor(case["input"]), dtype=dtype).numpy(), - np.eye(*case["input"]).astype(dtype), - ) + for dtype in dtypes: + for case in cases: + np.testing.assert_allclose( + F.eye(case["input"], dtype=dtype).numpy(), + np.eye(*case["input"]).astype(dtype), + ) + np.testing.assert_allclose( + F.eye(*case["input"], dtype=dtype).numpy(), + np.eye(*case["input"]).astype(dtype), + ) + np.testing.assert_allclose( + F.eye(tensor(case["input"]), dtype=dtype).numpy(), + np.eye(*case["input"]).astype(dtype), + ) def test_full(): -- GitLab