From 1a7112997c9ebab0fdeadee32e6a546993b2cf78 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 30 Mar 2021 15:31:14 +0800
Subject: [PATCH] feat(opr-mm): add backend argument for remote send/recv

GitOrigin-RevId: 841a0e45ab2188a4a7414ff4a23b76e7b9852db7
---
 .../megengine/distributed/functional.py      |  2 ++
 imperative/src/impl/ops/io_remote.cpp        |  4 +--
 imperative/src/test/io_remote.cpp            |  4 +--
 src/core/include/megbrain/ir/ops.td          |  6 ++--
 src/opr-mm/impl/io_remote.cpp                | 35 ++++++++++---------
 src/opr-mm/impl/io_remote.oprdecl            |  6 +++-
 src/opr-mm/include/megbrain/opr/io_remote.h  | 16 +++++----
 src/opr-mm/test/io_remote.cpp                | 32 ++++++++---------
 8 files changed, 59 insertions(+), 46 deletions(-)

diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index b10ea5b3..f5c58db9 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -265,6 +265,7 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     op.key = key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
+    op.backend = get_backend()
     (dummy,) = apply(_RemoteSend(op), inp)

     for g in grad_keys.values():
@@ -313,6 +314,7 @@ def remote_recv(
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
     op.rank_from = src_rank
+    op.backend = get_backend()
     (ret,) = apply(_RemoteRecv(op), inp)

     if _isscalar:
diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp
index ed0398d7..ce401958 100644
--- a/imperative/src/impl/ops/io_remote.cpp
+++ b/imperative/src/impl/ops/io_remote.cpp
@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
     OperatorNodeConfig config{send.make_name()};
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
-                    send.key, inputs[0], group_client, true, config));
+                    send.key, inputs[0], group_client, true, send.backend, config));
     return opr;
 }

@@ -49,7 +49,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
     auto&& graph = inputs[0]->owner_graph();
     return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
             recv.key, inputs[0], *graph, group_client, config,
-            recv.shape, recv.dtype));
+            recv.shape, recv.dtype, recv.backend));
 }

 OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
diff --git a/imperative/src/test/io_remote.cpp b/imperative/src/test/io_remote.cpp
index 1e6054f0..05f3e899 100644
--- a/imperative/src/test/io_remote.cpp
+++ b/imperative/src/test/io_remote.cpp
@@ -34,7 +34,7 @@ TEST(TestImperative, IORemote) {

     auto run_send = [&](std::shared_ptr<HostTensorND> hnd) {
         auto def = imperative::RemoteSend::make(
-                "io_remote_test", server_addr, port, 1);
+                "io_remote_test", server_addr, port, 1, "nccl");
         auto inp = Tensor::make(*hnd);
         auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
     };
@@ -43,7 +43,7 @@
         auto def = imperative::RemoteRecv::make(
                 "io_remote_test", server_addr, port, 0,
                 CompNode::load("gpu1"), TensorShape{vector_size},
-                dtype::Float32());
+                dtype::Float32(), "nccl");
         auto inp = Tensor::make(*hnd);
         auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
         HostTensorND host_v;
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index f8e15f92..01c91da3 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -169,7 +169,8 @@ def RemoteSend : MgbHashableOp<"RemoteSend"> {
     MgbStringAttr:$key,
     MgbStringAttr:$addr,
     MgbUI32Attr:$port,
-    MgbUI32Attr:$rank_to
+    MgbUI32Attr:$rank_to,
+    MgbStringAttr:$backend
   );
 }

@@ -181,7 +182,8 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
     MgbUI32Attr:$rank_from,
     MgbCompNodeAttr:$cn,
     MgbTensorShapeAttr:$shape,
-    MgbDTypeAttr:$dtype
+    MgbDTypeAttr:$dtype,
+    MgbStringAttr:$backend
   );
 }

diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp
index ff5c7aa5..73d23b4e 100644
--- a/src/opr-mm/impl/io_remote.cpp
+++ b/src/opr-mm/impl/io_remote.cpp
@@ -24,8 +24,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);

 RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                        std::shared_ptr<GroupClient> group_client,
-                       bool is_grad, const OperatorNodeConfig& config) :
+                       bool is_grad, std::string backend, const OperatorNodeConfig& config) :
         Super(var->owner_graph(), config, "remote_send", {var}),
+        m_backend(backend),
         m_is_grad(is_grad) {
     m_key = key;
     m_group_client = group_client;
@@ -41,9 +42,9 @@ RemoteSend::RemoteSend(const std::string& key, VarNode* var,

 SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                            std::shared_ptr<GroupClient> group_client,
-                           bool is_grad, const OperatorNodeConfig& config) {
+                           bool is_grad, std::string backend, const OperatorNodeConfig& config) {
     return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
-                                                    is_grad, config);
+                                                    is_grad, backend, config);
 }

 void RemoteSend::scn_do_execute() {
@@ -64,7 +65,7 @@ void RemoteSend::scn_do_execute() {
     }

     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
+            reg_info.hash, m_key, 2, 0, get_megray_backend(m_backend), m_group_client);

     m_megray_ctx = get_megray_context(output(0)->comp_node());

@@ -122,7 +123,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
                             *opr.owner_graph(), opr.group_client(),
                             OperatorNodeConfig{opr.comp_node()}.name(
                                     opr.name() + ":grad_recv"),
-                            opr.input(0)->shape(), opr.input(0)->dtype())
+                            opr.input(0)->shape(), opr.input(0)->dtype(), opr.backend())
            .node();
 }
 #endif
@@ -134,9 +135,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
 RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
-                       const TensorShape& shape, DType dtype) :
+                       const TensorShape& shape, DType dtype, std::string backend) :
         Super(&graph, config, "remote_recv", {}),
-        m_shape(shape), m_dtype(dtype) {
+        m_shape(shape), m_dtype(dtype), m_backend(backend) {
     m_key = key;
     m_group_client = group_client;

@@ -150,9 +151,9 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
 RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
                        std::shared_ptr<GroupClient> group_client,
                        const OperatorNodeConfig& config,
-                       const TensorShape& shape, DType dtype) :
+                       const TensorShape& shape, DType dtype, std::string backend) :
         Super(&graph, config, "remote_recv", {}),
-        m_shape(shape), m_dtype(dtype) {
+        m_shape(shape), m_dtype(dtype), m_backend(backend) {
     m_key = key;
     m_group_client = group_client;

@@ -167,18 +168,18 @@ RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph&

 SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client, const OperatorNodeConfig& config,
-                           const TensorShape& shape, DType dtype) {
+                           const TensorShape& shape, DType dtype, std::string backend) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            key, graph, group_client, config, shape, dtype));
+            key, graph, group_client, config, shape, dtype, backend));
     return opr->output(0);
 }

 SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var,
                            cg::ComputingGraph& graph,
                            std::shared_ptr<GroupClient> group_client, const OperatorNodeConfig& config,
-                           const TensorShape& shape, DType dtype) {
+                           const TensorShape& shape, DType dtype, std::string backend) {
     auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
-            key, var.node(), graph, group_client, config, shape, dtype));
+            key, var.node(), graph, group_client, config, shape, dtype, backend));
     return opr->output(0);
 }

@@ -201,7 +202,7 @@ void RemoteRecv::scn_do_execute() {
     }

     m_megray_comm = MegRayCommBuilder::get_megray_comm(
-            reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
+            reg_info.hash, m_key, 2, 1, get_megray_backend(m_backend), m_group_client);

     m_megray_ctx = get_megray_context(output(0)->comp_node());

@@ -251,7 +252,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
     mgb_assert(inputs.size() == 1);
     auto&& opr = opr_.cast_final_safe<RemoteSend>();
     return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
-                            opr.is_grad(), config)
+                            opr.is_grad(), opr.backend(), config)
             .node()
             ->owner_opr();
 }
@@ -265,14 +266,14 @@
     if (inputs.size() == 1) {
         return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
                                 opr.group_client(), config, opr.shape(),
-                                opr.dtype())
+                                opr.dtype(), opr.backend())
                 .node()
                 ->owner_opr();
     } else {
         mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
         return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                                 opr.group_client(), config, opr.shape(),
-                                opr.dtype())
+                                opr.dtype(), opr.backend())
                 .node()
                 ->owner_opr();
     }
diff --git a/src/opr-mm/impl/io_remote.oprdecl b/src/opr-mm/impl/io_remote.oprdecl
index f903be60..dff7d508 100644
--- a/src/opr-mm/impl/io_remote.oprdecl
+++ b/src/opr-mm/impl/io_remote.oprdecl
@@ -9,6 +9,8 @@ decl_raw_opr(
         Doc('key', 'key to bind send-recv pair', 'str'),
         Doc('var', 'variable to be sent', ':class:`.SymbolVar`'),
         Doc('is_grad', 'whether the send', 'bool'),
+        Doc('backend', 'Backend for collective communication, nccl or ucx',
+            'str', '\'nccl\''),
     ]
 )

@@ -24,7 +26,9 @@ decl_raw_opr(
                 ':class:`.CompGraph`'),
         Doc('shape', 'output var shape'),
         Doc('dtype', 'data type of the output var; must match dtype at sender',
-            ':class:`numpy.dtype` compatible')
+            ':class:`numpy.dtype` compatible'),
+        Doc('backend', 'Backend for collective communication, nccl or ucx',
+            'str', '\'nccl\''),
     ]
 )

diff --git a/src/opr-mm/include/megbrain/opr/io_remote.h b/src/opr-mm/include/megbrain/opr/io_remote.h
index 04929b6a..db0ffb22 100644
--- a/src/opr-mm/include/megbrain/opr/io_remote.h
+++ b/src/opr-mm/include/megbrain/opr/io_remote.h
@@ -48,17 +48,19 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
     public:
         RemoteSend(const std::string& key, VarNode* var,
                    std::shared_ptr<GroupClient> group_client,
-                   bool is_grad, const OperatorNodeConfig& config);
+                   bool is_grad, std::string backend, const OperatorNodeConfig& config);

         static SymbolVar make(
                 const std::string& key, SymbolVar var,
                 std::shared_ptr<GroupClient> group_client,
-                bool is_grad, const OperatorNodeConfig& config = {});
+                bool is_grad, std::string backend, const OperatorNodeConfig& config = {});

+        const std::string& backend() const { return m_backend; }
         bool is_grad() const { return m_is_grad; }

     private:
         HostTensorND m_output_val;
+        std::string m_backend;
         bool m_is_grad;

         void scn_do_execute() override;
@@ -75,31 +77,33 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
         RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                    std::shared_ptr<GroupClient> group_client,
                    const OperatorNodeConfig& config, const TensorShape& shape,
-                   DType dtype);
+                   DType dtype, std::string backend);
         RemoteRecv(const std::string& key, VarNode* var,
                    cg::ComputingGraph& graph,
                    std::shared_ptr<GroupClient> group_client,
                    const OperatorNodeConfig& config, const TensorShape& shape,
-                   DType dtype);
+                   DType dtype, std::string backend);

         static SymbolVar make(
                 const std::string& key, cg::ComputingGraph& graph,
                 std::shared_ptr<GroupClient> group_client,
                 const OperatorNodeConfig& config, const TensorShape& shape,
-                DType dtype);
+                DType dtype, std::string backend);

         static SymbolVar make(
                 const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
                 std::shared_ptr<GroupClient> group_client,
                 const OperatorNodeConfig& config, const TensorShape& shape,
-                DType dtype);
+                DType dtype, std::string backend);

         const TensorShape& shape() const { return m_shape; }
         const DType& dtype() const { return m_dtype; }
+        const std::string& backend() const { return m_backend; }

     private:
         const TensorShape m_shape;
         const DType m_dtype;
+        const std::string m_backend;
         const CompNode m_comp_node;
         DeviceTensorND m_dev_buffer;

diff --git a/src/opr-mm/test/io_remote.cpp b/src/opr-mm/test/io_remote.cpp
index 95668f25..9c4769ea 100644
--- a/src/opr-mm/test/io_remote.cpp
+++ b/src/opr-mm/test/io_remote.cpp
@@ -33,10 +33,10 @@ TEST(TestOprIORemote, Identity) {

     auto graph = ComputingGraph::make();
     auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
-    auto xr = opr::RemoteSend::make("x", x, client, false);
+    auto xr = opr::RemoteSend::make("x", x, client, false, "nccl");
     auto y = opr::RemoteRecv::make("x", *graph.get(), client, {cn1},
                                    host_x->shape(),
-                                   host_x->dtype());
+                                   host_x->dtype(), "nccl");

     auto func =
             graph->compile({{xr, {}}, make_callback_copy(y, host_y)});
@@ -57,7 +57,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
         auto graph = ComputingGraph::make();
         sys::set_thread_name("sender");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             xr = opr::RemoteSend::make("x", x, client, false);
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl");
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -67,7 +67,7 @@
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(), client, {cns[0]},
                                        host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto func = graph->compile({make_callback_copy(x, host_x_get)});
         func->execute();
     };
@@ -91,7 +91,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
-             xr = opr::RemoteSend::make("x", x, client, false);
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl");
         auto func = graph->compile({{xr, {}}});
         func->execute();
     };
@@ -101,7 +101,7 @@
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(), client, {cns[0]},
                                        host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto func = graph->compile(
                 {make_callback_copy((x - 1) / 2, host_x_get)});
         func->execute();
@@ -126,12 +126,12 @@ TEST(TestOprIORemote, APlusB) {
         auto graph = ComputingGraph::make();
         auto z = opr::RemoteRecv::make("z", *graph.get(), client, {cns[0]},
                                        host_x->shape(),
-                                       host_x->dtype());
+                                       host_x->dtype(), "nccl");
         auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
              y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
-             xr = opr::RemoteSend::make("x", x, client, false)
+             xr = opr::RemoteSend::make("x", x, client, false, "nccl")
                           .rename("xr"),
-             yr = opr::RemoteSend::make("y", y, client, false)
+             yr = opr::RemoteSend::make("y", y, client, false, "nccl")
                           .rename("yr");
         auto func = graph->compile(
                 {{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
@@ -144,12 +144,12 @@
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("x", *graph.get(), client, {cns[1]},
                                        host_x->shape(),
-                                       host_x->dtype()),
+                                       host_x->dtype(), "nccl"),
              y = opr::RemoteRecv::make("y", *graph.get(), client, {cns[1]},
                                        host_y->shape(),
-                                       host_y->dtype()),
+                                       host_y->dtype(), "nccl"),
              z = x + y,
-             zr = opr::RemoteSend::make("z", z, client, false);
+             zr = opr::RemoteSend::make("z", z, client, false, "nccl");
         auto func = graph->compile({{zr, {}}});
         func->execute();
     };
@@ -178,10 +178,10 @@ TEST(TestOprIORemote, SendGrad) {
         sys::set_thread_name("sender");
         auto graph = ComputingGraph::make();
         auto x = opr::Host2DeviceCopy::make(*graph, host_x),
-             loss = opr::RemoteSend::make("loss", x, client, false);
+             loss = opr::RemoteSend::make("loss", x, client, false, "nccl");
         ASSERT_TRUE(!loss.shape().ndim &&
                     loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
-        loss = opr::RemoteSend::make("loss", x, client, true);
+        loss = opr::RemoteSend::make("loss", x, client, true, "nccl");
         auto gx = cg::grad(loss, x);
         set_priority(loss, 0);
         set_priority(gx, 1);
@@ -200,8 +200,8 @@
         auto graph = ComputingGraph::make();
         auto x = opr::RemoteRecv::make("loss", *graph.get(), client,
                                        {cns[1]}, host_x->shape(),
-                                       host_x->dtype());
-        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
+                                       host_x->dtype(), "nccl");
+        auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false, "nccl");
         auto func = graph->compile({{y, {}}});
         func->execute();
     };
--
GitLab
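
For context, a minimal sketch of how the backend plumbing added above is exercised from the Python API: remote_send/remote_recv now stamp op.backend with the value returned by get_backend() ("nccl" by default, "ucx" being the documented alternative). This sketch is not part of the patch; the launcher invocation, the remote_recv signature, and the availability of dist.get_rank() as used here are assumptions that may differ between MegEngine versions.

# Illustrative sketch only -- assumes MegEngine's distributed launcher and two GPUs;
# exact signatures (launcher arguments, remote_recv parameters) may differ by version.
import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.distributed.functional import remote_send, remote_recv


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    if rank == 0:
        # Sender: the _RemoteSend op now carries the backend selected for the
        # process group, so the graph-level RemoteSend uses the same transport.
        x = mge.tensor(np.arange(4, dtype="float32"))
        remote_send(x, dest_rank=1)
    else:
        # Receiver: shape and dtype must match what the sender transmitted.
        y = remote_recv(src_rank=0, shape=(4,), dtype="float32")
        print(y.numpy())


if __name__ == "__main__":
    worker()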