提交 4b0ecb5d 编写于 作者: M Megvii Engine Team

fix(ops/recv): use std::vector to store shape to support scalar

GitOrigin-RevId: e1dac3c9199539b03609838908626f15b70de555
上级 c3d63f14
...@@ -47,8 +47,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( ...@@ -47,8 +47,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto group_client = std::make_shared<opr::GroupClientProxy>( auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", recv.addr.data(), recv.port)); ssprintf("%s:%d", recv.addr.data(), recv.port));
auto&& graph = inputs[0]->owner_graph(); auto&& graph = inputs[0]->owner_graph();
mgb_assert(!recv.shape.empty());
TensorShape shape;
for (auto&& dim : recv.shape) {
shape[shape.ndim++] = dim;
}
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
recv.key, inputs[0], *graph, group_client, config, recv.shape, recv.dtype, recv.key, inputs[0], *graph, group_client, config, shape, recv.dtype,
recv.backend)); recv.backend));
} }
......
...@@ -42,7 +42,7 @@ TEST(TestImperative, IORemote) { ...@@ -42,7 +42,7 @@ TEST(TestImperative, IORemote) {
auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) {
auto def = imperative::RemoteRecv::make( auto def = imperative::RemoteRecv::make(
"io_remote_test", server_addr, port, 0, CompNode::load("gpu1"), "io_remote_test", server_addr, port, 0, CompNode::load("gpu1"),
TensorShape{vector_size}, dtype::Float32(), "nccl"); std::vector<int32_t>{(int32_t)vector_size}, dtype::Float32(), "nccl");
auto inp = Tensor::make(*hnd); auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v; HostTensorND host_v;
......
...@@ -284,7 +284,7 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> { ...@@ -284,7 +284,7 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
MgbUI32Attr:$port, MgbUI32Attr:$port,
MgbUI32Attr:$rank_from, MgbUI32Attr:$rank_from,
MgbCompNodeAttr:$cn, MgbCompNodeAttr:$cn,
MgbTensorShapeAttr:$shape, MgbArrayAttr<MgbI32Attr>:$shape,
MgbDTypeAttr:$dtype, MgbDTypeAttr:$dtype,
MgbStringAttr:$backend MgbStringAttr:$backend
); );
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册