提交 4b0ecb5d 编写于 作者: M Megvii Engine Team

fix(ops/recv): use std::vector to store shape to support scalar

GitOrigin-RevId: e1dac3c9199539b03609838908626f15b70de555
上级 c3d63f14
......@@ -47,8 +47,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", recv.addr.data(), recv.port));
auto&& graph = inputs[0]->owner_graph();
mgb_assert(!recv.shape.empty());
TensorShape shape;
for (auto&& dim : recv.shape) {
shape[shape.ndim++] = dim;
}
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
recv.key, inputs[0], *graph, group_client, config, recv.shape, recv.dtype,
recv.key, inputs[0], *graph, group_client, config, shape, recv.dtype,
recv.backend));
}
......
......@@ -42,7 +42,7 @@ TEST(TestImperative, IORemote) {
auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) {
auto def = imperative::RemoteRecv::make(
"io_remote_test", server_addr, port, 0, CompNode::load("gpu1"),
TensorShape{vector_size}, dtype::Float32(), "nccl");
std::vector<int32_t>{(int32_t)vector_size}, dtype::Float32(), "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v;
......
......@@ -284,7 +284,7 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
MgbUI32Attr:$port,
MgbUI32Attr:$rank_from,
MgbCompNodeAttr:$cn,
MgbTensorShapeAttr:$shape,
MgbArrayAttr<MgbI32Attr>:$shape,
MgbDTypeAttr:$dtype,
MgbStringAttr:$backend
);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册