提交 cd8ab9e3 编写于 作者: M Megvii Engine Team

test(mgb/opr-mm): add io_remote test

GitOrigin-RevId: c47b6156fe671e87c698fbb3793f0b4be50a01fc
上级 4e0054f7
......@@ -59,7 +59,31 @@ const auto recv_tag = RemoteIOBase::Type::RECV;
} // anonymous namespace
TEST(TestOprIORemote, Identity) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
auto host_x = gen({28, 28});
HostTensorND host_y;
auto client = std::make_shared<MockGroupClient>();
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
client, {cn1}, host_x->shape(),
host_x->dtype());
auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});
func->execute();
MGB_ASSERT_TENSOR_EQ(*host_x, host_y);
}
TEST(TestOprIORemote, IdentityMultiThread) {
auto cns = load_multiple_xpus(2);
HostTensorGenerator<> gen;
auto host_x = gen({2, 3}, cns[1]);
......@@ -67,6 +91,7 @@ TEST(TestOprIORemote, Identity) {
auto client = std::make_shared<MockGroupClient>();
auto sender = [&]() {
auto graph = ComputingGraph::make();
sys::set_thread_name("sender");
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册