From ac51f7807d895a010805f3cac6cdfbd12ad9ce62 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 1 Jun 2022 21:10:13 +0800 Subject: [PATCH] feat(mge/distributed): add support for batch send recv op GitOrigin-RevId: eb3d712704f7a1d0abc6c611cec7c93ad3f5e8bf --- .../python/megengine/distributed/__init__.py | 1 + .../megengine/distributed/functional.py | 38 ++++- .../python/megengine/distributed/group.py | 7 + imperative/python/src/common.cpp | 17 ++ imperative/python/src/common.h | 1 + imperative/python/src/tensor.cpp | 8 + imperative/python/src/transformation.h | 3 +- .../test/unit/distributed/test_distributed.py | 29 ++++ imperative/src/impl/ops/io_remote.cpp | 160 +++++++++++++++++- .../src/impl/transformations/group_comm.cpp | 67 ++++++++ .../megbrain/imperative/ops/io_remote.h | 11 ++ .../imperative/transformations/group_comm.h | 44 +++++ imperative/src/test/io_remote.cpp | 4 +- src/opr-mm/impl/group_manager.cpp | 22 +++ src/opr-mm/impl/megray_helper.cpp | 11 +- src/opr-mm/impl/mm_handler.cpp | 58 +++++++ .../include/megbrain/opr/group_manager.h | 15 ++ .../include/megbrain/opr/megray_helper.h | 1 + src/opr-mm/include/megbrain/opr/mm_handler.h | 34 +++- src/opr-mm/proto/mm_handler.proto | 12 ++ src/opr-mm/test/mock_client.h | 6 + 21 files changed, 531 insertions(+), 18 deletions(-) create mode 100644 imperative/src/impl/transformations/group_comm.cpp create mode 100644 imperative/src/include/megbrain/imperative/ops/io_remote.h create mode 100644 imperative/src/include/megbrain/imperative/transformations/group_comm.h diff --git a/imperative/python/megengine/distributed/__init__.py b/imperative/python/megengine/distributed/__init__.py index 99bff4f7..55ee93c6 100644 --- a/imperative/python/megengine/distributed/__init__.py +++ b/imperative/python/megengine/distributed/__init__.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from mprop import mproperty +from ..core._imperative_rt.core2 import group_end, group_start from . 
import group from .group import ( WORLD, diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index 9c9d9a1c..02304bfb 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- -from typing import Optional, Tuple +from typing import Optional import numpy as np from ..core._imperative_rt.core2 import apply from ..core.autodiff.grad import Function, _grad_manager_dict -from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend -from ..core.tensor.utils import isscalar +from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend from ..device import get_default_device, what_is_xpu from ..tensor import Tensor from . import group @@ -843,16 +842,13 @@ def remote_send(inp: Tensor, dest_rank: int): """ group = _SendRecvGroup(get_rank(), dest_rank) _bcast_shape_dtype(group, inp) - _bcast_tracer_state(group, inp) - op = RemoteSend() op.key = group.key op.addr, op.port = get_mm_server_addr() op.rank_to = dest_rank op.backend = _backend() out = _RemoteSend(op)(inp) - _save_output_for_autodiff(inp, out) @@ -900,6 +896,34 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor op.addr, op.port = get_mm_server_addr() op.rank_from = src_rank op.backend = _backend() - ret = _RemoteRecv(op)(inp) return ret + + +def _remote_send_nobackward(inp: Tensor, dest_rank: int): + op = RemoteSend() + op.key = "b{}->{}".format(get_rank(), dest_rank) + op.addr, op.port = get_mm_server_addr() + op.rank_to = dest_rank + op.backend = _backend() + apply(op, inp) + + +def _remote_recv_nobackward( + src_rank: int, device: Optional[str] = None, inp=None, shape=None, dtype=None, +): + op = RemoteRecv() + op.key = "b{}->{}".format(src_rank, get_rank()) + if device is None: + device = get_default_device() + op.cn = device + if inp is None: + inp = Tensor(0, device=device) 
+ assert shape is not None and dtype is not None + op.shape = shape + op.dtype = dtype + op.addr, op.port = get_mm_server_addr() + op.rank_from = src_rank + op.backend = _backend() + ret = apply(op, inp)[0] + return ret diff --git a/imperative/python/megengine/distributed/group.py b/imperative/python/megengine/distributed/group.py index 58ddfc4e..dd961293 100644 --- a/imperative/python/megengine/distributed/group.py +++ b/imperative/python/megengine/distributed/group.py @@ -160,6 +160,13 @@ def init_process_group( set_default_device("{}{}".format(device_type, device)) seed(int(time.time()) + rank) + if backend == "nccl": + # init nccl env + from ..core._imperative_rt.common import init_nccl_env + + group_barrier() + init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0) + def _set_machine_ranks(ranks) -> None: global _sd diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index 5f0d1ffe..61fa8e0a 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -8,6 +8,9 @@ #include "megbrain/comp_node.h" #include "megbrain/graph.h" #include "megbrain/imperative/physical_tensor.h" +#if MGB_ENABLE_OPR_MM +#include "megbrain/opr/mm_handler.h" +#endif #if MEGDNN_WITH_CUDA #include "cuda_sm_gen.h" @@ -46,6 +49,18 @@ void set_default_device(const std::string& device) { default_device = device; } +void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) { +#if MGB_ENABLE_OPR_MM + auto&& help = mgb::opr::BatchSendRecvHelper::getInstance(); + bool res = help->init(nranks, rank, ip, port, root); + auto p = help->get(std::string("init_all_cards")); +#else + mgb_throw( + MegBrainError, + "MegEngine compiled without MM opr, doesn't support init_nccl_env"); +#endif +} + std::string get_default_device() { return default_device; } @@ -252,6 +267,8 @@ void init_common(py::module m) { m.def("what_is_xpu", [] { return CompNode::Locator::parse("xpux").to_physical().type; }); + 
m.def("init_nccl_env", &init_nccl_env); + init_npy_num_bfloat16(m); init_npy_num_intbx(m); init_dtypes(m); diff --git a/imperative/python/src/common.h b/imperative/python/src/common.h index 8260f306..2761f012 100644 --- a/imperative/python/src/common.h +++ b/imperative/python/src/common.h @@ -8,3 +8,4 @@ void set_default_device(const std::string& device); std::string get_default_device(); extern pybind11::handle py_comp_node_type; +void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root); diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index e24bd598..ef67ad30 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -9,6 +9,7 @@ #include "megbrain/imperative/transformations/dtype_promote.h" #include "megbrain/imperative/transformations/eval.h" #include "megbrain/imperative/transformations/format.h" +#include "megbrain/imperative/transformations/group_comm.h" #include "megbrain/imperative/transformations/lazy.h" #include "megbrain/imperative/transformations/scalar.h" #include "megbrain/imperative/transformations/symbol.h" @@ -947,6 +948,13 @@ void init_tensor(py::module m) { m.def("enable_cupti", &cupti::enable); m.def("disable_cupti", &cupti::disable); m.def("cupti_available", &cupti::available); + + static std::unique_ptr> group_comm_guard; + m.def("group_start", []() { + auto commtrans = std::make_shared(); + group_comm_guard = transformations.register_at(commtrans); + }); + m.def("group_end", []() { group_comm_guard.reset(); }); m.def("sync", [channel]() { if (channel->check_available()) { channel->sync(); diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index 10d5df56..eab8bb0d 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -16,6 +16,7 @@ struct TransformationManager { public: enum Segment { ModuleTrace, + GroupComm, DTypePromote, DimExpansion, Format, @@ -26,7 +27,7 @@ 
public: Eval, }; - std::array>, 9> segments; + std::array>, 10> segments; private: template diff --git a/imperative/python/test/unit/distributed/test_distributed.py b/imperative/python/test/unit/distributed/test_distributed.py index d3146e9a..4c151743 100644 --- a/imperative/python/test/unit/distributed/test_distributed.py +++ b/imperative/python/test/unit/distributed/test_distributed.py @@ -237,3 +237,32 @@ def test_get_cuda_compute_capability(): assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0 worker() + + +@pytest.mark.require_ngpu(3) +@pytest.mark.isolated_distributed +def test_batch_send_recv(): + import megengine.distributed.functional as DF + + @dist.launcher(n_gpus=3) + def worker(): + rank = dist.get_rank() + dist.group_start() + for i in range(3): + tensor = mge.tensor(np.ones(10000)) * rank + if i == 2: + tensor *= i + DF._remote_send_nobackward(tensor, (rank + 1) % 3) + DF._remote_recv_nobackward( + src_rank=(rank + 1) % 3, dtype="float32", shape=(10000,) + ) + DF._remote_send_nobackward(tensor, (rank - 1) % 3) + recv = DF._remote_recv_nobackward( + src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,) + ) + if i == 2: + recv2 = recv + dist.group_end() + np.testing.assert_equal(recv2.numpy(), (rank - 1) % 3 * 2 * np.ones(10000)) + + worker() diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp index 6dd097fa..86e0cb35 100644 --- a/imperative/src/impl/ops/io_remote.cpp +++ b/imperative/src/impl/ops/io_remote.cpp @@ -1,14 +1,19 @@ +#include "megbrain/imperative/ops/io_remote.h" #include "megbrain_build_config.h" #if MGB_ENABLE_OPR_MM +#include +#include +#include +#include "../blob_manager_impl.h" #include "../op_trait.h" #include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/opr/io_remote.h" +#include "megbrain/opr/megray_helper.h" #include "megbrain/opr/mm_handler.h" #endif // MGB_ENABLE_OPR_MM - #include "megbrain/imperative/ops/autogen.h" - +#include 
"megbrain/imperative/proxy_graph_detail.h" namespace mgb { namespace imperative { @@ -46,15 +51,164 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( recv.backend)); } +TensorPtr megray_recv_tensor( + std::shared_ptr megray_comm, TensorLayout& layout, + CompNode cn, uint32_t rank_from) { + DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, layout); + auto megray_ctx = mgb::opr::get_megray_context(cn); + size_t data_size = layout.total_nr_elems(); + auto status = megray_comm->recv( + out.raw_ptr(), data_size, mgb::opr::get_megray_dtype(layout.dtype), + rank_from, megray_ctx); + mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed"); + return Tensor::make(out); +} + +void megray_send_tensor( + std::shared_ptr megray_comm, const TensorPtr& src, + uint32_t rank_to) { + auto&& tensor = src->dev_tensor(); + auto&& ishp = src->shape(); + size_t data_size = ishp.total_nr_elems(); + auto megray_ctx = mgb::opr::get_megray_context(src->comp_node()); + auto status = megray_comm->send( + src->dev_tensor().raw_ptr(), data_size, + mgb::opr::get_megray_dtype(src->layout().dtype), rank_to, megray_ctx); + mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed"); +} + +TensorLayout create_layout(const std::vector& shape, DType dtype) { + TensorShape tshape; + tshape.ndim = shape.size(); + mgb_assert(tshape.ndim <= TensorLayout::MAX_NDIM); + std::copy(shape.begin(), shape.end(), tshape.shape); + return TensorLayout(tshape, dtype); +} + +std::tuple, bool> infer_output_attrs_fallible_remote_send( + const OpDef& def, const SmallVector& input_descs) { + auto&& dtype = input_descs[0].layout.dtype; + auto&& cn = input_descs[0].comp_node; + return {{{TensorLayout({0}, dtype), cn}}, true}; +} + +SmallVector apply_on_physical_tensor_remote_send( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op = def.cast_final_safe(); + auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( + 
std::string("init_all_cards")); + if (!megray_comm) { + return proxy_graph_detail::apply_on_physical_tensor( + def, inputs, output_descs, validated); + } + mgb_assert(megray_comm != nullptr); + megray_send_tensor(megray_comm, inputs[0], op.rank_to); + TensorLayout layout({0}, inputs[0]->dtype()); + DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag( + inputs[0]->comp_node(), layout); + return {Tensor::make(out)}; +} + +std::tuple, bool> infer_output_attrs_fallible_remote_recv( + const OpDef& def, const SmallVector& input_descs) { + auto& op = def.cast_final_safe(); + return {{{create_layout(op.shape, op.dtype), op.cn}}, true}; +} + +SmallVector apply_on_physical_tensor_remote_recv( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op = def.cast_final_safe(); + auto layout = create_layout(op.shape, op.dtype); + auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( + std::string("init_all_cards")); + if (!megray_comm) { + return proxy_graph_detail::apply_on_physical_tensor( + def, inputs, output_descs, validated); + } + auto&& out = megray_recv_tensor(megray_comm, layout, op.cn, op.rank_from); + return {out}; +} + +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + SmallVector layout_checker(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + layout_checker[i] = [](const TensorLayout& layout) { + return layout.is_contiguous(); + }; + } + return layout_checker; +} + OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend) .apply_on_var_node(apply_on_var_node_remote_send) + .apply_on_physical_tensor(apply_on_physical_tensor_remote_send) + .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_send) + .get_input_layout_constraint(get_input_layout_constraint) .fallback(); OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) .apply_on_var_node(apply_on_var_node_remote_recv) + 
.apply_on_physical_tensor(apply_on_physical_tensor_remote_recv) + .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_recv) + .get_input_layout_constraint(get_input_layout_constraint) .fallback(); -} // anonymous namespace + +SmallVector apply_on_physical_tensor_batch_send_recv( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto&& op = def.cast_final_safe(); + auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get( + std::string("init_all_cards")); + mgb_assert(megray_comm != nullptr); + megray_comm->group_start(); + SmallVector outputs; + size_t ind = 0; + for (auto&& op_ : op.op_list) { + if (op_->same_type()) { + auto&& send_op = op_->cast_final_safe(); + auto&& tensor = inputs[ind]; + megray_send_tensor(megray_comm, tensor, send_op.rank_to); + ind++; + } else { + mgb_assert(op_->same_type()); + auto&& recv_op = op_->cast_final_safe(); + auto layout = create_layout(recv_op.shape, recv_op.dtype); + auto&& out = megray_recv_tensor( + megray_comm, layout, recv_op.cn, recv_op.rank_from); + outputs.push_back(out); + } + } + megray_comm->group_end(); + return outputs; +} + +std::tuple, bool> +infer_output_attrs_fallible_batch_send_recv( + const OpDef& def, const SmallVector& input_descs) { + auto& op = def.cast_final_safe(); + SmallVector output_descs; + for (auto&& op_ : op.op_list) { + if (op_->same_type()) { + auto&& recv_op = op_->cast_final_safe(); + output_descs.push_back( + {create_layout(recv_op.shape, recv_op.dtype), recv_op.cn}); + } + } + return {output_descs, true}; +} + +OP_TRAIT_REG(BatchSendRecvOp, BatchSendRecvOp) + .apply_on_physical_tensor(apply_on_physical_tensor_batch_send_recv) + .infer_output_attrs_fallible(infer_output_attrs_fallible_batch_send_recv) + .get_input_layout_constraint(get_input_layout_constraint) + .fallback(); +} // namespace + #endif // MGB_ENABLE_OPR_MM +MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchSendRecvOp); } // namespace imperative } // namespace mgb diff 
--git a/imperative/src/impl/transformations/group_comm.cpp b/imperative/src/impl/transformations/group_comm.cpp new file mode 100644 index 00000000..5bf61cec --- /dev/null +++ b/imperative/src/impl/transformations/group_comm.cpp @@ -0,0 +1,67 @@ +#include "megbrain/imperative/transformations/group_comm.h" +#include "megbrain/imperative/blob_manager.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/ops/io_remote.h" +namespace mgb { +namespace imperative { + +ValueRefList GroupCommTransformation::apply_transformation( + const Operator& op, Span inputs) { + for (auto inp : inputs) { + mgb_assert( + !inp.is(m_value_type), "Can not use PlaceholderValue as apply input"); + } + if (auto* apply_op = op.as()) { + if (apply_op->op().same_type()) { + auto&& send_op = apply_op->op().cast_final_safe(); + if (send_op.key[0] == 'b') { + send_inputs.push_back(inputs[0]); + record_ops.push_back(send_op.shared_from_this()); + return {}; + } + } + if (apply_op->op().same_type()) { + auto&& recv_op = apply_op->op().cast_final_safe(); + if (recv_op.key[0] == 'b') { + record_ops.push_back(recv_op.shared_from_this()); + auto rst = m_value_type.make(); + recv_tensors.push_back(rst); + auto outputs = ValueRefList(1); + outputs[0] = rst; + return outputs; + } + } + return imperative::apply(op, inputs); + } else { + return imperative::apply(op, inputs); + } +} + +ValueRefList GroupCommTransformation::execute_batch_op() { + auto batch_op = BatchSendRecvOp::make(record_ops); + auto outputs = imperative::apply(*batch_op, send_inputs); + return outputs; +} + +void GroupCommTransformation::on_unregister() noexcept { + auto rst = execute_batch_op(); + mgb_assert(rst.size() == recv_tensors.size()); + for (size_t i = 0; i < rst.size(); i++) { + auto v = recv_tensors[i].lock(); + if (v != ValueRef::nil) { + v.reset(rst[i]); + } + } +} + +GroupCommTransformation::~GroupCommTransformation() { + for (auto&& recv : recv_tensors) { + mgb_assert( + recv.lock() == 
ValueRef::nil, + "Some PlaceholderValues are not reset after GroupCommTransformation " + "destroyed!"); + }; +} + +} // namespace imperative +} // namespace mgb \ No newline at end of file diff --git a/imperative/src/include/megbrain/imperative/ops/io_remote.h b/imperative/src/include/megbrain/imperative/ops/io_remote.h new file mode 100644 index 00000000..3e364aed --- /dev/null +++ b/imperative/src/include/megbrain/imperative/ops/io_remote.h @@ -0,0 +1,11 @@ +#pragma once +#include "megbrain/imperative/op_def.h" +namespace mgb::imperative { +struct BatchSendRecvOp final : OpDefImplBase { + SmallVector> op_list; + BatchSendRecvOp() = default; + BatchSendRecvOp(SmallVector> op_list) : op_list{op_list} {} + MGB_DYN_TYPE_OBJ_FINAL_DECL; +}; + +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/include/megbrain/imperative/transformations/group_comm.h b/imperative/src/include/megbrain/imperative/transformations/group_comm.h new file mode 100644 index 00000000..17e59794 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/transformations/group_comm.h @@ -0,0 +1,44 @@ +/** + * \file imperative/src/include/megbrain/imperative/transformations/group_comm.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "megbrain/imperative/basic_operators.h" +#include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/ops/autogen.h" + +namespace mgb::imperative { + +class PlaceholderValue final : public ObjectValue { +public: + std::string to_string() const override { return ssprintf("PlaceholderValue"); } + void clear() override {} +}; + +class GroupCommTransformation final : public Transformation { +private: + SmallVector send_inputs; + std::vector recv_tensors; + SmallVector> record_ops; + ObjectType m_value_type{"PlaceholderValue"}; + +public: + GroupCommTransformation() = default; + ValueRefList apply_transformation( + const Operator& op, Span inputs) override; + ValueRefList execute_batch_op(); + ValueRef unwrap(ValueRef value) override { return value; } + std::string name() const override { return "GroupCommTransformation"; } + void on_unregister() noexcept override; + ~GroupCommTransformation(); +}; + +} // namespace mgb::imperative diff --git a/imperative/src/test/io_remote.cpp b/imperative/src/test/io_remote.cpp index 072e4e2e..0d03727e 100644 --- a/imperative/src/test/io_remote.cpp +++ b/imperative/src/test/io_remote.cpp @@ -1,4 +1,5 @@ #include "./helper.h" +#include "megbrain/comp_node_env.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/mm_handler.h" @@ -47,7 +48,4 @@ TEST(TestImperative, IORemote) { t0.join(); t1.join(); } - // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - -// ./imperative_test --gtest_filter TestIORemote diff --git a/src/opr-mm/impl/group_manager.cpp b/src/opr-mm/impl/group_manager.cpp index 55516665..cdffbbcb 100644 --- a/src/opr-mm/impl/group_manager.cpp +++ b/src/opr-mm/impl/group_manager.cpp @@ -151,6 +151,28 @@ void GroupManager::bcast_addr( } } +void GroupManager::bcast_nccluniqueid( + const std::string& key, std::string& id, uint32_t size, uint32_t rank, + uint32_t root) { + std::unique_lock lk{m_key2nccl_id_mtx}; + if (rank == root) { + 
m_key2nccl_id[key] = id; + } + m_key2nccl_id_size[key]++; + if (m_key2nccl_id_size[key] == size) { + m_key2nccl_id_flag[key] = true; + m_bcast_cv.notify_all(); + } else { + m_bcast_cv.wait(lk, [&] { return m_key2nccl_id_flag.count(key) > 0; }); + } + id = m_key2nccl_id[key]; + m_key2nccl_id_size[key]--; + if (m_key2nccl_id_size[key] == 0) { + m_key2nccl_id.erase(key); + m_key2nccl_id_flag.erase(key); + } +} + void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) { auto&& group = get_group(key); group.set_output_shape(key, shape); diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp index c562f186..d3591a37 100644 --- a/src/opr-mm/impl/megray_helper.cpp +++ b/src/opr-mm/impl/megray_helper.cpp @@ -67,6 +67,15 @@ void MegRayCommBuilder::emplace( m_megray_comms.emplace(hash, comm); } +void MegRayCommBuilder::remove( + uint64_t hash, std::shared_ptr comm) { + std::unique_lock lk(m_map_mtx); + auto it = m_megray_comms.find(hash); + if (it != m_megray_comms.end()) { + m_megray_comms.erase(hash); + } +} + std::shared_ptr MegRayCommBuilder::get_megray_comm( uint64_t hash, std::string key, uint32_t size, uint32_t rank, MegRay::Backend backend, std::shared_ptr group_client) { @@ -104,5 +113,3 @@ std::shared_ptr MegRayCommBuilder::get_megray_comm( MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr; std::mutex MegRayCommBuilder::sm_instance_mtx; - -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr-mm/impl/mm_handler.cpp b/src/opr-mm/impl/mm_handler.cpp index 7cbb8a3e..1510c148 100644 --- a/src/opr-mm/impl/mm_handler.cpp +++ b/src/opr-mm/impl/mm_handler.cpp @@ -45,6 +45,7 @@ public: RUNSERVER(get_output_shape); RUNSERVER(bcast_addr); RUNSERVER(group_barrier); + RUNSERVER(bcast_nccluniqueid); mgb_assert(false, "invalid rpc request"); } @@ -53,6 +54,7 @@ private: void set_output_shape(void* input_ptr, size_t input_len, std::string* output); void get_output_shape(void* 
input_ptr, size_t input_len, std::string* output); void bcast_addr(void* input_ptr, size_t input_len, std::string* output); + void bcast_nccluniqueid(void* input_ptr, size_t input_len, std::string* output); void group_barrier(void* input_ptr, size_t input_len, std::string* output); private: @@ -116,6 +118,15 @@ void GroupServerProxy::bcast_addr( rsp.SerializeToString(output); } +void GroupServerProxy::bcast_nccluniqueid( + void* input_ptr, size_t input_len, std::string* output) { + INFO_INIT(mm_handler, BcastNcclUniqueId); + std::string id = req.id(); + m_mgr.bcast_nccluniqueid(req.key(), id, req.size(), req.rank(), req.root()); + rsp.set_id(id); + rsp.SerializeToString(output); +} + void GroupServerProxy::group_barrier( void* input_ptr, size_t input_len, std::string* output) { INFO_INIT(mm_handler, GroupBarrier); @@ -201,6 +212,19 @@ void GroupClientProxy::bcast_addr( port = rsp.port(); } +void GroupClientProxy::bcast_nccluniqueid( + const std::string& key, std::string& id, uint32_t size, uint32_t rank, + uint32_t root) { + INFO_INIT(mm_handler, bcast_nccluniqueid, BcastNcclUniqueId); + req.set_id(id.data(), id.size()); + req.set_key(key.data(), key.size()); + req.set_size(size); + req.set_rank(rank); + req.set_root(root); + SOLVE_REQUEST(func_name, req, rsp); + id = rsp.id(); +} + uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { INFO_INIT(mm_handler, group_barrier, GroupBarrier); req.set_size(size); @@ -209,6 +233,40 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { return rsp.size(); } +std::shared_ptr BatchSendRecvHelper::get(std::string&& key) { + auto ptr = megray_comm_cache.find(key); + if (ptr != megray_comm_cache.end()) { + return megray_comm_cache[key]; + } else { + return nullptr; + } +} + +std::unordered_map> + BatchSendRecvHelper::megray_comm_cache{}; + +bool BatchSendRecvHelper::init( + int nranks, int rank, std::string ip, int port, int root) { + auto megray_comm = + MegRay::get_communicator(nranks, 
rank, MegRay::Backend::MEGRAY_NCCL); + auto group_client = + std::make_shared(ssprintf("%s:%d", ip.data(), port)); + auto cb = [=](char* nccl_buffer, size_t len) { + std::string id; + id.resize(len); + if (rank == root) { + memcpy(id.data(), nccl_buffer, len); + } + group_client->bcast_nccluniqueid("init_all_cards", id, nranks, rank, root); + if (rank != root) { + memcpy(nccl_buffer, id.data(), len); + } + }; + megray_comm->init(cb); + return megray_comm_cache.insert({std::string("init_all_cards"), megray_comm}) + .second; +} + #undef INFO_INIT #undef SOLVE_REQUEST diff --git a/src/opr-mm/include/megbrain/opr/group_manager.h b/src/opr-mm/include/megbrain/opr/group_manager.h index 29ba9ef4..18b888ef 100644 --- a/src/opr-mm/include/megbrain/opr/group_manager.h +++ b/src/opr-mm/include/megbrain/opr/group_manager.h @@ -77,6 +77,11 @@ public: std::string& master_ip, int& port, const std::string& key, uint32_t size, uint32_t rank, uint32_t root); + //! bcast uid + void bcast_nccluniqueid( + const std::string& key, std::string& id, uint32_t size, uint32_t rank, + uint32_t root); + //! Set output shape of this key void set_output_shape(const std::string& key, const TensorShape& shape); @@ -101,6 +106,12 @@ private: std::mutex m_key2addr_mtx; std::condition_variable m_bcast_cv; + //! key -> ncclid + std::unordered_map m_key2nccl_id; + std::unordered_map m_key2nccl_id_size; + std::unordered_map m_key2nccl_id_flag; + std::mutex m_key2nccl_id_mtx; + //! 
barrier uint32_t m_barrier_size; std::set m_barrier_set; @@ -128,6 +139,10 @@ public: std::string& master_ip, int& port, const std::string& key, uint32_t size, uint32_t rank, uint32_t root) = 0; + virtual void bcast_nccluniqueid( + const std::string& key, std::string& id, uint32_t size, uint32_t rank, + uint32_t root) = 0; + virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; virtual TensorShape get_output_shape(const std::string& key) = 0; diff --git a/src/opr-mm/include/megbrain/opr/megray_helper.h b/src/opr-mm/include/megbrain/opr/megray_helper.h index b9b2e139..ec8ee7db 100644 --- a/src/opr-mm/include/megbrain/opr/megray_helper.h +++ b/src/opr-mm/include/megbrain/opr/megray_helper.h @@ -23,6 +23,7 @@ class MegRayCommBuilder { private: bool find(uint64_t hash, std::shared_ptr& comm); void emplace(uint64_t hash, std::shared_ptr comm); + void remove(uint64_t hash, std::shared_ptr comm); std::unordered_map> m_megray_comms; std::mutex m_map_mtx; diff --git a/src/opr-mm/include/megbrain/opr/mm_handler.h b/src/opr-mm/include/megbrain/opr/mm_handler.h index 97b829d4..c00eacdf 100644 --- a/src/opr-mm/include/megbrain/opr/mm_handler.h +++ b/src/opr-mm/include/megbrain/opr/mm_handler.h @@ -39,6 +39,10 @@ public: std::string& master_ip, int& port, const std::string& key, uint32_t size, uint32_t rank, uint32_t root) override; + void bcast_nccluniqueid( + const std::string& key, std::string& id, uint32_t size, uint32_t rank, + uint32_t root) override; + void set_output_shape(const std::string& key, const TensorShape& shape) override; TensorShape get_output_shape(const std::string& key) override; @@ -52,6 +56,34 @@ private: void* m_stub; }; +template +class ProcessGlobal { // thread safe +public: + template + static std::shared_ptr& getInstance(Args&&... args) { + static auto instance = std::make_shared(std::forward(args)...); + return instance; + } + +protected: + template + ProcessGlobal(Args&&... 
args); + ProcessGlobal() = default; + +public: + ProcessGlobal(ProcessGlobal const&) = delete; + void operator=(ProcessGlobal const&) = delete; +}; + +class BatchSendRecvHelper : public ProcessGlobal { + static std::unordered_map> + megray_comm_cache; + +public: + std::shared_ptr get(std::string&&); + bool init(int nranks, int rank, std::string ip, int port, int root); +}; + /* ======================== ZmqRpcServerMgr ========================== */ int create_zmqrpc_server(const std::string& server_addr, int port); @@ -60,5 +92,3 @@ int create_zmqrpc_server(const std::string& server_addr, int port); } // namespace mgb #endif - -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr-mm/proto/mm_handler.proto b/src/opr-mm/proto/mm_handler.proto index 00f1d662..908e15ef 100644 --- a/src/opr-mm/proto/mm_handler.proto +++ b/src/opr-mm/proto/mm_handler.proto @@ -30,6 +30,18 @@ message BcastAddrResponse { int32 port = 2; } +message BcastNcclUniqueIdRequest{ + string key = 1; + bytes id = 2; + uint32 size =3 ; + uint32 rank = 4; + uint32 root =5; +} + +message BcastNcclUniqueIdResponse{ + bytes id = 1; +} + message SetOutputShapeRequest { string key = 1; TensorShape shape = 2; diff --git a/src/opr-mm/test/mock_client.h b/src/opr-mm/test/mock_client.h index 075fa15d..46b96df6 100644 --- a/src/opr-mm/test/mock_client.h +++ b/src/opr-mm/test/mock_client.h @@ -26,6 +26,12 @@ public: return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); } + void bcast_nccluniqueid( + const std::string& key, std::string& id, uint32_t size, uint32_t rank, + uint32_t root) override { + return m_mgr.bcast_nccluniqueid(key, id, size, rank, root); + } + void set_output_shape(const std::string& key, const TensorShape& shape) override { m_mgr.set_output_shape(key, shape); } -- GitLab