提交 618c77c0 编写于 作者: M Megvii Engine Team 提交者: “wenjuan”

feat(dnn): enable eye to support bool

GitOrigin-RevId: 76d874d5b7b32efbaf4254b49b43e6ad5ef1213d
上级 a8941af0
...@@ -39,6 +39,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) { ...@@ -39,6 +39,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) {
#define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t); #define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) #define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
} // namespace eye } // namespace eye
} // namespace cuda } // namespace cuda
......
...@@ -26,6 +26,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { ...@@ -26,6 +26,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
cuda_stream(handle())); \ cuda_stream(handle())); \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb #undef cb
} }
......
...@@ -31,6 +31,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { ...@@ -31,6 +31,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \ MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb #undef cb
} }
......
...@@ -44,7 +44,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) { ...@@ -44,7 +44,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) {
template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t); template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) #define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
} // namespace eye } // namespace eye
} // namespace rocm } // namespace rocm
} // namespace megdnn } // namespace megdnn
......
...@@ -27,6 +27,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { ...@@ -27,6 +27,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
hip_stream(handle())); \ hip_stream(handle())); \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb #undef cb
} }
......
...@@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode ...@@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode
def test_eye(): def test_eye():
dtype = np.float32 dtypes = [np.float32, np.bool]
cases = [{"input": [10, 20]}, {"input": [30]}] cases = [{"input": [10, 20]}, {"input": [30]}]
for case in cases: for dtype in dtypes:
np.testing.assert_allclose( for case in cases:
F.eye(case["input"], dtype=dtype).numpy(), np.testing.assert_allclose(
np.eye(*case["input"]).astype(dtype), F.eye(case["input"], dtype=dtype).numpy(),
) np.eye(*case["input"]).astype(dtype),
np.testing.assert_allclose( )
F.eye(*case["input"], dtype=dtype).numpy(), np.testing.assert_allclose(
np.eye(*case["input"]).astype(dtype), F.eye(*case["input"], dtype=dtype).numpy(),
) np.eye(*case["input"]).astype(dtype),
np.testing.assert_allclose( )
F.eye(tensor(case["input"]), dtype=dtype).numpy(), np.testing.assert_allclose(
np.eye(*case["input"]).astype(dtype), F.eye(tensor(case["input"]), dtype=dtype).numpy(),
) np.eye(*case["input"]).astype(dtype),
)
def test_full(): def test_full():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册