From 96ec586d28a6b25c13f9a87db6111f50f4f87eea Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 17 Aug 2020 11:44:59 +0800 Subject: [PATCH] fix(dnn): fix bool cvt GitOrigin-RevId: 2f883dcbe005a8a8cd43e9051d1d0a59b8e61c0e --- dnn/src/cuda/cond_take/opr_impl.cpp | 1 + dnn/src/cuda/type_cvt/kern.cu | 6 ++---- dnn/src/cuda/type_cvt/opr_impl.cpp | 1 + src/opr/test/basic_arith/others.cpp | 21 +++++++++++++++++++++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/dnn/src/cuda/cond_take/opr_impl.cpp b/dnn/src/cuda/cond_take/opr_impl.cpp index 4e5191c7..5dde8afc 100644 --- a/dnn/src/cuda/cond_take/opr_impl.cpp +++ b/dnn/src/cuda/cond_take/opr_impl.cpp @@ -59,6 +59,7 @@ CondTakeImpl::Output CondTakeImpl::exec( break; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_throw("bad mask dtype"); diff --git a/dnn/src/cuda/type_cvt/kern.cu b/dnn/src/cuda/type_cvt/kern.cu index 1d9131eb..5b4d75dd 100644 --- a/dnn/src/cuda/type_cvt/kern.cu +++ b/dnn/src/cuda/type_cvt/kern.cu @@ -111,8 +111,7 @@ struct TypeCvtOpFromQuantized< ctype_dest, ctype_src, typename std::enable_if< std::is_same::value || - std::is_same::value || - std::is_same::value>::type> { + std::is_same::value>::type> { ctype_dest* dest; CudaDTypeParam param; using src_vect_type = typename VectTypeTrait::vect_type; @@ -140,8 +139,7 @@ struct TypeCvtOpBetweenQuantized< ctype_dest, ctype_src, typename std::enable_if< std::is_same::value || - std::is_same::value || - std::is_same::value>::type> { + std::is_same::value>::type> { ctype_dest* dest; CudaDTypeParam src_param; CudaDTypeParam dst_param; diff --git a/dnn/src/cuda/type_cvt/opr_impl.cpp b/dnn/src/cuda/type_cvt/opr_impl.cpp index 685fe873..201c4b49 100644 --- a/dnn/src/cuda/type_cvt/opr_impl.cpp +++ b/dnn/src/cuda/type_cvt/opr_impl.cpp @@ -109,6 +109,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_assert_internal(0); diff --git a/src/opr/test/basic_arith/others.cpp b/src/opr/test/basic_arith/others.cpp index 69a3894b..a3a617af 100644 --- a/src/opr/test/basic_arith/others.cpp +++ b/src/opr/test/basic_arith/others.cpp @@ -568,6 +568,27 @@ TEST(TestOprBasicArith, TypeCvtBool) { ASSERT_EQ(TensorShape({3}), host_y.shape()); } +TEST(TestOprBasicArith, TypeCvtFromBool) { + auto graph = ComputingGraph::make(); + HostTensorGenerator gen; + auto host_x = gen({2}); + auto px = host_x->ptr(); + px[0] = true; + px[1] = false; + + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + y = opr::TypeCvt::make(x, dtype::Int32{}); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute(); + + auto py = host_y.ptr(); + for (size_t i = 0;i < 2;i ++) { + ASSERT_EQ(static_cast(px[i]), py[i]); + } + ASSERT_EQ(TensorShape({2}), host_y.shape()); +} + TEST(TestOprBasicArith, ElemwiseMemFwd) { auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; -- GitLab