提交 96ec586d 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(dnn): fix bool cvt

GitOrigin-RevId: 2f883dcbe005a8a8cd43e9051d1d0a59b8e61c0e
上级 f26cd398
......@@ -59,6 +59,7 @@ CondTakeImpl::Output CondTakeImpl::exec(
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_throw("bad mask dtype");
......
......@@ -111,8 +111,7 @@ struct TypeCvtOpFromQuantized<
ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_quint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
std::is_same<ctype_src, dt_quint8>::value>::type> {
ctype_dest* dest;
CudaDTypeParam<ctype_src> param;
using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type;
......@@ -140,8 +139,7 @@ struct TypeCvtOpBetweenQuantized<
ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_quint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
std::is_same<ctype_src, dt_quint8>::value>::type> {
ctype_dest* dest;
CudaDTypeParam<ctype_src> src_param;
CudaDTypeParam<ctype_dest> dst_param;
......
......@@ -109,6 +109,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_assert_internal(0);
......
......@@ -568,6 +568,27 @@ TEST(TestOprBasicArith, TypeCvtBool) {
ASSERT_EQ(TensorShape({3}), host_y.shape());
}
TEST(TestOprBasicArith, TypeCvtFromBool) {
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Bool> gen;
auto host_x = gen({2});
auto px = host_x->ptr<bool>();
px[0] = true;
px[1] = false;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::TypeCvt::make(x, dtype::Int32{});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
auto py = host_y.ptr<int>();
for (size_t i = 0;i < 2;i ++) {
ASSERT_EQ(static_cast<int>(px[i]), py[i]);
}
ASSERT_EQ(TensorShape({2}), host_y.shape());
}
TEST(TestOprBasicArith, ElemwiseMemFwd) {
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册