提交 ff60fdb8 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn): add bool type cvt on gpu

GitOrigin-RevId: ab0fecf368b86bd71035b086dea175a4b1181c21
上级 e8571cca
......@@ -73,6 +73,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src,
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
cb(::megdnn::dtype::Bool);
#undef cb
default:
megdnn_assert_internal(0);
......
......@@ -546,6 +546,28 @@ TEST(TestOprBasicArith, TypeCvt) {
ASSERT_EQ(TensorShape({3, 0}), host_y.shape());
}
TEST(TestOprBasicArith, TypeCvtBool) {
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Int32> gen;
auto host_x = gen({3});
auto px = host_x->ptr<int>();
px[0] = -1;
px[1] = 0;
px[2] = 1;
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::TypeCvt::make(x, dtype::Bool{});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
auto py = host_y.ptr<bool>();
for (size_t i = 0;i < 3;i ++) {
ASSERT_EQ(static_cast<bool>(px[i]), py[i]);
}
ASSERT_EQ(TensorShape({3}), host_y.shape());
}
TEST(TestOprBasicArith, ElemwiseMemFwd) {
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册