From ff60fdb82d86baa00a483421ad0a4c260687d60f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 13 Aug 2020 16:32:26 +0800 Subject: [PATCH] feat(dnn): add bool type cvt on gpu GitOrigin-RevId: ab0fecf368b86bd71035b086dea175a4b1181c21 --- dnn/src/cuda/type_cvt/opr_impl.cpp | 1 + src/opr/test/basic_arith/others.cpp | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/dnn/src/cuda/type_cvt/opr_impl.cpp b/dnn/src/cuda/type_cvt/opr_impl.cpp index 5dde7e5fb..685fe8738 100644 --- a/dnn/src/cuda/type_cvt/opr_impl.cpp +++ b/dnn/src/cuda/type_cvt/opr_impl.cpp @@ -73,6 +73,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb); + cb(::megdnn::dtype::Bool); #undef cb default: megdnn_assert_internal(0); diff --git a/src/opr/test/basic_arith/others.cpp b/src/opr/test/basic_arith/others.cpp index aec4238f5..69a3894bc 100644 --- a/src/opr/test/basic_arith/others.cpp +++ b/src/opr/test/basic_arith/others.cpp @@ -546,6 +546,28 @@ TEST(TestOprBasicArith, TypeCvt) { ASSERT_EQ(TensorShape({3, 0}), host_y.shape()); } +TEST(TestOprBasicArith, TypeCvtBool) { + auto graph = ComputingGraph::make(); + HostTensorGenerator gen; + auto host_x = gen({3}); + auto px = host_x->ptr(); + px[0] = -1; + px[1] = 0; + px[2] = 1; + + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + y = opr::TypeCvt::make(x, dtype::Bool{}); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute(); + + auto py = host_y.ptr(); + for (size_t i = 0;i < 3;i ++) { + ASSERT_EQ(static_cast(px[i]), py[i]); + } + ASSERT_EQ(TensorShape({3}), host_y.shape()); +} + TEST(TestOprBasicArith, ElemwiseMemFwd) { auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; -- GitLab