From 4b0ecb5deb38f0dcc6e37f042bfc1be98e04ff8a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Jan 2022 13:15:07 +0800 Subject: [PATCH] fix(ops/recv): use std::vector to store shape to support scalar GitOrigin-RevId: e1dac3c9199539b03609838908626f15b70de555 --- imperative/src/impl/ops/io_remote.cpp | 7 ++++++- imperative/src/test/io_remote.cpp | 2 +- src/core/include/megbrain/ir/ops.td | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp index 03e4d58ab..255e65d75 100644 --- a/imperative/src/impl/ops/io_remote.cpp +++ b/imperative/src/impl/ops/io_remote.cpp @@ -47,8 +47,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( auto group_client = std::make_shared( ssprintf("%s:%d", recv.addr.data(), recv.port)); auto&& graph = inputs[0]->owner_graph(); + mgb_assert(!recv.shape.empty()); + TensorShape shape; + for (auto&& dim : recv.shape) { + shape[shape.ndim++] = dim; + } return graph->insert_opr(std::make_unique( - recv.key, inputs[0], *graph, group_client, config, recv.shape, recv.dtype, + recv.key, inputs[0], *graph, group_client, config, shape, recv.dtype, recv.backend)); } diff --git a/imperative/src/test/io_remote.cpp b/imperative/src/test/io_remote.cpp index 97a7b62dd..fa3a0a21a 100644 --- a/imperative/src/test/io_remote.cpp +++ b/imperative/src/test/io_remote.cpp @@ -42,7 +42,7 @@ TEST(TestImperative, IORemote) { auto run_recv = [&](std::shared_ptr hnd) { auto def = imperative::RemoteRecv::make( "io_remote_test", server_addr, port, 0, CompNode::load("gpu1"), - TensorShape{vector_size}, dtype::Float32(), "nccl"); + std::vector{(int32_t)vector_size}, dtype::Float32(), "nccl"); auto inp = Tensor::make(*hnd); auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); HostTensorND host_v; diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 30c795ad9..75eb7e3e3 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -284,7 +284,7 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> { MgbUI32Attr:$port, MgbUI32Attr:$rank_from, MgbCompNodeAttr:$cn, - MgbTensorShapeAttr:$shape, + MgbArrayAttr:$shape, MgbDTypeAttr:$dtype, MgbStringAttr:$backend ); -- GitLab