提交 618c77c0 编写于 作者: M Megvii Engine Team 提交者: “wenjuan”

feat(dnn): enable eye to support bool

GitOrigin-RevId: 76d874d5b7b32efbaf4254b49b43e6ad5ef1213d
上级 a8941af0
......@@ -39,6 +39,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) {
#define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
} // namespace eye
} // namespace cuda
......
......@@ -26,6 +26,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
cuda_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
......
......@@ -31,6 +31,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
......
......@@ -44,7 +44,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) {
template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
} // namespace eye
} // namespace rocm
} // namespace megdnn
......
......@@ -27,6 +27,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) {
hip_stream(handle())); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
}
......
......@@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode
def test_eye():
dtype = np.float32
dtypes = [np.float32, np.bool]
cases = [{"input": [10, 20]}, {"input": [30]}]
for case in cases:
np.testing.assert_allclose(
F.eye(case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(*case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
for dtype in dtypes:
for case in cases:
np.testing.assert_allclose(
F.eye(case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(*case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
def test_full():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册