diff --git a/dnn/src/cuda/cond_take/opr_impl.cpp b/dnn/src/cuda/cond_take/opr_impl.cpp index 4e5191c7d0ee1dec6a92057576bed041acff0d3a..5dde8afc3b3ed54a0ea03220db63baae231328b5 100644 --- a/dnn/src/cuda/cond_take/opr_impl.cpp +++ b/dnn/src/cuda/cond_take/opr_impl.cpp @@ -59,6 +59,7 @@ CondTakeImpl::Output CondTakeImpl::exec( break; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_throw("bad mask dtype"); diff --git a/dnn/src/cuda/type_cvt/kern.cu b/dnn/src/cuda/type_cvt/kern.cu index 1d9131eb14851659d1ef7b3dcaf093cc74f007ac..5b4d75dd1eac8c79c238241b92f2b98e66e818c2 100644 --- a/dnn/src/cuda/type_cvt/kern.cu +++ b/dnn/src/cuda/type_cvt/kern.cu @@ -111,8 +111,7 @@ struct TypeCvtOpFromQuantized< ctype_dest, ctype_src, typename std::enable_if< std::is_same::value || - std::is_same::value || - std::is_same::value>::type> { + std::is_same::value>::type> { ctype_dest* dest; CudaDTypeParam param; using src_vect_type = typename VectTypeTrait::vect_type; @@ -140,8 +139,7 @@ struct TypeCvtOpBetweenQuantized< ctype_dest, ctype_src, typename std::enable_if< std::is_same::value || - std::is_same::value || - std::is_same::value>::type> { + std::is_same::value>::type> { ctype_dest* dest; CudaDTypeParam src_param; CudaDTypeParam dst_param; diff --git a/dnn/src/cuda/type_cvt/opr_impl.cpp b/dnn/src/cuda/type_cvt/opr_impl.cpp index 685fe8738b100b2ac9922ace30d90c7ad20df875..201c4b496dc3e734057d92c3c17de21b91d17012 100644 --- a/dnn/src/cuda/type_cvt/opr_impl.cpp +++ b/dnn/src/cuda/type_cvt/opr_impl.cpp @@ -109,6 +109,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_assert_internal(0); diff --git a/src/opr/test/basic_arith/others.cpp b/src/opr/test/basic_arith/others.cpp index 69a3894bc81b67a3d29b8459d807e3de6b8c4a73..a3a617af7ef8a6844acb230150a31bf5212dd468 100644 --- a/src/opr/test/basic_arith/others.cpp +++ 
b/src/opr/test/basic_arith/others.cpp @@ -568,6 +568,27 @@ TEST(TestOprBasicArith, TypeCvtBool) { ASSERT_EQ(TensorShape({3}), host_y.shape()); } +TEST(TestOprBasicArith, TypeCvtFromBool) { /* Feeds a 2-element host tensor set to {true, false} through opr::TypeCvt targeting dtype::Int32 and verifies the copied-back values and shape. */ + auto graph = ComputingGraph::make(); + HostTensorGenerator gen; + auto host_x = gen({2}); + auto px = host_x->ptr(); + px[0] = true; + px[1] = false; /* overwrite the generated contents so both boolean values are covered */ + + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + y = opr::TypeCvt::make(x, dtype::Int32{}); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute(); + + auto py = host_y.ptr(); /* host_y was filled by the make_callback_copy callback during execute() */ + for (size_t i = 0;i < 2;i ++) { + ASSERT_EQ(static_cast(px[i]), py[i]); /* each output element must equal the cast of the matching input flag */ + } + ASSERT_EQ(TensorShape({2}), host_y.shape()); +} + TEST(TestOprBasicArith, ElemwiseMemFwd) { auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0;