Commit ac51f780 authored by Megvii Engine Team

feat(mge/distributed): add support for batch send recv op

GitOrigin-RevId: eb3d712704f7a1d0abc6c611cec7c93ad3f5e8bf
Parent 013bb14f
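A minimal usage sketch of the batched send/recv path, adapted from the test added at the end of this commit (requires 3 GPUs and the NCCL backend; the ring pattern and tensor size are illustrative, not part of the API):

# Ops recorded between group_start() and group_end() are launched together
# as a single BatchSendRecvOp, so the ring of sends/recvs cannot deadlock.
import numpy as np
import megengine as mge
import megengine.distributed as dist
import megengine.distributed.functional as DF


@dist.launcher(n_gpus=3)
def worker():
    rank = dist.get_rank()
    dist.group_start()
    DF._remote_send_nobackward(mge.tensor(np.ones(10000)) * rank, (rank + 1) % 3)
    recv = DF._remote_recv_nobackward(
        src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,)
    )
    dist.group_end()  # the recv placeholder is resolved here
    np.testing.assert_equal(recv.numpy(), (rank - 1) % 3 * np.ones(10000))


worker()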
# -*- coding: utf-8 -*-
from mprop import mproperty
from ..core._imperative_rt.core2 import group_end, group_start
from . import group
from .group import (
WORLD,
......
# -*- coding: utf-8 -*-
from typing import Optional, Tuple
from typing import Optional
import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar
from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
......@@ -843,16 +842,13 @@ def remote_send(inp: Tensor, dest_rank: int):
"""
group = _SendRecvGroup(get_rank(), dest_rank)
_bcast_shape_dtype(group, inp)
_bcast_tracer_state(group, inp)
op = RemoteSend()
op.key = group.key
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
op.backend = _backend()
out = _RemoteSend(op)(inp)
_save_output_for_autodiff(inp, out)
......@@ -900,6 +896,34 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank
op.backend = _backend()
ret = _RemoteRecv(op)(inp)
return ret
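# Helpers for the batched send/recv path: unlike remote_send/remote_recv they do
# no autodiff bookkeeping, and the "b{src}->{dst}" key marks the op so that
# GroupCommTransformation (registered by group_start()) records it for batching
# instead of executing it immediately.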
def _remote_send_nobackward(inp: Tensor, dest_rank: int):
op = RemoteSend()
op.key = "b{}->{}".format(get_rank(), dest_rank)
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
op.backend = _backend()
apply(op, inp)
def _remote_recv_nobackward(
src_rank: int, device: Optional[str] = None, inp=None, shape=None, dtype=None,
):
op = RemoteRecv()
op.key = "b{}->{}".format(src_rank, get_rank())
if device is None:
device = get_default_device()
op.cn = device
if inp is None:
inp = Tensor(0, device=device)
assert shape is not None and dtype is not None
op.shape = shape
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank
op.backend = _backend()
ret = apply(op, inp)[0]
return ret
......@@ -160,6 +160,13 @@ def init_process_group(
set_default_device("{}{}".format(device_type, device))
seed(int(time.time()) + rank)
if backend == "nccl":
# init nccl env
from ..core._imperative_rt.common import init_nccl_env
group_barrier()
init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0)
def _set_machine_ranks(ranks) -> None:
global _sd
......
......@@ -8,6 +8,9 @@
#include "megbrain/comp_node.h"
#include "megbrain/graph.h"
#include "megbrain/imperative/physical_tensor.h"
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/mm_handler.h"
#endif
#if MEGDNN_WITH_CUDA
#include "cuda_sm_gen.h"
......@@ -46,6 +49,18 @@ void set_default_device(const std::string& device) {
default_device = device;
}
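// Build the process-global NCCL communicator (cached as "init_all_cards") used by
// the batched send/recv path; exposed to Python and called from
// init_process_group() when backend == "nccl".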
void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) {
#if MGB_ENABLE_OPR_MM
auto&& help = mgb::opr::BatchSendRecvHelper::getInstance();
bool res = help->init(nranks, rank, ip, port, root);
auto p = help->get(std::string("init_all_cards"));
#else
mgb_throw(
MegBrainError,
"MegEngine compiled without MM opr, doesn't support init_nccl_env");
#endif
}
std::string get_default_device() {
return default_device;
}
......@@ -252,6 +267,8 @@ void init_common(py::module m) {
m.def("what_is_xpu",
[] { return CompNode::Locator::parse("xpux").to_physical().type; });
m.def("init_nccl_env", &init_nccl_env);
init_npy_num_bfloat16(m);
init_npy_num_intbx(m);
init_dtypes(m);
......
......@@ -8,3 +8,4 @@ void set_default_device(const std::string& device);
std::string get_default_device();
extern pybind11::handle py_comp_node_type;
void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root);
......@@ -9,6 +9,7 @@
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/group_comm.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
......@@ -947,6 +948,13 @@ void init_tensor(py::module m) {
m.def("enable_cupti", &cupti::enable);
m.def("disable_cupti", &cupti::disable);
m.def("cupti_available", &cupti::available);
static std::unique_ptr<CleanupGuard<>> group_comm_guard;
m.def("group_start", []() {
auto commtrans = std::make_shared<GroupCommTransformation>();
group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans);
});
m.def("group_end", []() { group_comm_guard.reset(); });
m.def("sync", [channel]() {
if (channel->check_available()) {
channel->sync();
......
......@@ -16,6 +16,7 @@ struct TransformationManager {
public:
enum Segment {
ModuleTrace,
GroupComm,
DTypePromote,
DimExpansion,
Format,
......@@ -26,7 +27,7 @@ public:
Eval,
};
std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments;
std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments;
private:
template <Segment segment>
......
......@@ -237,3 +237,32 @@ def test_get_cuda_compute_capability():
assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0
worker()
@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_batch_send_recv():
import megengine.distributed.functional as DF
@dist.launcher(n_gpus=3)
def worker():
rank = dist.get_rank()
dist.group_start()
for i in range(3):
tensor = mge.tensor(np.ones(10000)) * rank
if i == 2:
tensor *= i
DF._remote_send_nobackward(tensor, (rank + 1) % 3)
DF._remote_recv_nobackward(
src_rank=(rank + 1) % 3, dtype="float32", shape=(10000,)
)
DF._remote_send_nobackward(tensor, (rank - 1) % 3)
recv = DF._remote_recv_nobackward(
src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,)
)
if i == 2:
recv2 = recv
dist.group_end()
np.testing.assert_equal(recv2.numpy(), (rank - 1) % 3 * 2 * np.ones(10000))
worker()
#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain_build_config.h"
#if MGB_ENABLE_OPR_MM
#include <algorithm>
#include <functional>
#include <numeric>
#include "../blob_manager_impl.h"
#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/io_remote.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/opr/mm_handler.h"
#endif // MGB_ENABLE_OPR_MM
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
namespace mgb {
namespace imperative {
......@@ -46,15 +51,164 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
recv.backend));
}
TensorPtr megray_recv_tensor(
std::shared_ptr<MegRay::Communicator> megray_comm, TensorLayout& layout,
CompNode cn, uint32_t rank_from) {
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, layout);
auto megray_ctx = mgb::opr::get_megray_context(cn);
size_t data_size = layout.total_nr_elems();
auto status = megray_comm->recv(
out.raw_ptr(), data_size, mgb::opr::get_megray_dtype(layout.dtype),
rank_from, megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
return Tensor::make(out);
}
void megray_send_tensor(
std::shared_ptr<MegRay::Communicator> megray_comm, const TensorPtr& src,
uint32_t rank_to) {
auto&& tensor = src->dev_tensor();
auto&& ishp = src->shape();
size_t data_size = ishp.total_nr_elems();
auto megray_ctx = mgb::opr::get_megray_context(src->comp_node());
auto status = megray_comm->send(
src->dev_tensor().raw_ptr(), data_size,
mgb::opr::get_megray_dtype(src->layout().dtype), rank_to, megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");
}
TensorLayout create_layout(const std::vector<int32_t>& shape, DType dtype) {
TensorShape tshape;
tshape.ndim = shape.size();
mgb_assert(tshape.ndim <= TensorLayout::MAX_NDIM);
std::copy(shape.begin(), shape.end(), tshape.shape);
return TensorLayout(tshape, dtype);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_send(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
auto&& dtype = input_descs[0].layout.dtype;
auto&& cn = input_descs[0].comp_node;
return {{{TensorLayout({0}, dtype), cn}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor_remote_send(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<RemoteSend>();
auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
std::string("init_all_cards"));
if (!megray_comm) {
return proxy_graph_detail::apply_on_physical_tensor(
def, inputs, output_descs, validated);
}
mgb_assert(megray_comm != nullptr);
megray_send_tensor(megray_comm, inputs[0], op.rank_to);
TensorLayout layout({0}, inputs[0]->dtype());
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
inputs[0]->comp_node(), layout);
return {Tensor::make(out)};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_recv(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
auto& op = def.cast_final_safe<RemoteRecv>();
return {{{create_layout(op.shape, op.dtype), op.cn}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor_remote_recv(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<RemoteRecv>();
auto layout = create_layout(op.shape, op.dtype);
auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
std::string("init_all_cards"));
if (!megray_comm) {
return proxy_graph_detail::apply_on_physical_tensor(
def, inputs, output_descs, validated);
}
auto&& out = megray_recv_tensor(megray_comm, layout, op.cn, op.rank_from);
return {out};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
layout_checker[i] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
}
return layout_checker;
}
OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
.apply_on_var_node(apply_on_var_node_remote_send)
.apply_on_physical_tensor(apply_on_physical_tensor_remote_send)
.infer_output_attrs_fallible(infer_output_attrs_fallible_remote_send)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
.apply_on_var_node(apply_on_var_node_remote_recv)
.apply_on_physical_tensor(apply_on_physical_tensor_remote_recv)
.infer_output_attrs_fallible(infer_output_attrs_fallible_remote_recv)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // anonymous namespace
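// Run all recorded sends/recvs inside one NCCL group (megray group_start/group_end),
// so the individual point-to-point calls are issued together and cannot deadlock;
// only the recv ops contribute outputs.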
SmallVector<TensorPtr> apply_on_physical_tensor_batch_send_recv(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op = def.cast_final_safe<BatchSendRecvOp>();
auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
std::string("init_all_cards"));
mgb_assert(megray_comm != nullptr);
megray_comm->group_start();
SmallVector<TensorPtr> outputs;
size_t ind = 0;
for (auto&& op_ : op.op_list) {
if (op_->same_type<RemoteSend>()) {
auto&& send_op = op_->cast_final_safe<RemoteSend>();
auto&& tensor = inputs[ind];
megray_send_tensor(megray_comm, tensor, send_op.rank_to);
ind++;
} else {
mgb_assert(op_->same_type<RemoteRecv>());
auto&& recv_op = op_->cast_final_safe<RemoteRecv>();
auto layout = create_layout(recv_op.shape, recv_op.dtype);
auto&& out = megray_recv_tensor(
megray_comm, layout, recv_op.cn, recv_op.rank_from);
outputs.push_back(out);
}
}
megray_comm->group_end();
return outputs;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool>
infer_output_attrs_fallible_batch_send_recv(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
auto& op = def.cast_final_safe<BatchSendRecvOp>();
SmallVector<LogicalTensorDesc> output_descs;
for (auto&& op_ : op.op_list) {
if (op_->same_type<RemoteRecv>()) {
auto&& recv_op = op_->cast_final_safe<RemoteRecv>();
output_descs.push_back(
{create_layout(recv_op.shape, recv_op.dtype), recv_op.cn});
}
}
return {output_descs, true};
}
OP_TRAIT_REG(BatchSendRecvOp, BatchSendRecvOp)
.apply_on_physical_tensor(apply_on_physical_tensor_batch_send_recv)
.infer_output_attrs_fallible(infer_output_attrs_fallible_batch_send_recv)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace
#endif // MGB_ENABLE_OPR_MM
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchSendRecvOp);
} // namespace imperative
} // namespace mgb
#include "megbrain/imperative/transformations/group_comm.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/io_remote.h"
namespace mgb {
namespace imperative {
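// Intercept RemoteSend/RemoteRecv whose key starts with 'b' (the batched,
// no-backward variants): record them instead of dispatching, hand back a
// PlaceholderValue for each recv, and resolve the placeholders in on_unregister()
// once the whole batch has executed.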
ValueRefList GroupCommTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
for (auto inp : inputs) {
mgb_assert(
!inp.is(m_value_type), "Can not use PlaceholderValue as apply input");
}
if (auto* apply_op = op.as<ApplyOp>()) {
if (apply_op->op().same_type<RemoteSend>()) {
auto&& send_op = apply_op->op().cast_final_safe<RemoteSend>();
if (send_op.key[0] == 'b') {
send_inputs.push_back(inputs[0]);
record_ops.push_back(send_op.shared_from_this());
return {};
}
}
if (apply_op->op().same_type<RemoteRecv>()) {
auto&& recv_op = apply_op->op().cast_final_safe<RemoteRecv>();
if (recv_op.key[0] == 'b') {
record_ops.push_back(recv_op.shared_from_this());
auto rst = m_value_type.make();
recv_tensors.push_back(rst);
auto outputs = ValueRefList(1);
outputs[0] = rst;
return outputs;
}
}
return imperative::apply(op, inputs);
} else {
return imperative::apply(op, inputs);
}
}
ValueRefList GroupCommTransformation::execute_batch_op() {
auto batch_op = BatchSendRecvOp::make(record_ops);
auto outputs = imperative::apply(*batch_op, send_inputs);
return outputs;
}
void GroupCommTransformation::on_unregister() noexcept {
auto rst = execute_batch_op();
mgb_assert(rst.size() == recv_tensors.size());
for (size_t i = 0; i < rst.size(); i++) {
auto v = recv_tensors[i].lock();
if (v != ValueRef::nil) {
v.reset(rst[i]);
}
}
}
GroupCommTransformation::~GroupCommTransformation() {
for (auto&& recv : recv_tensors) {
mgb_assert(
recv.lock() == ValueRef::nil,
"Some PlaceholderValues are not reset after GroupCommTransformation "
"destroyed!");
};
}
} // namespace imperative
} // namespace mgb
\ No newline at end of file
#pragma once
#include "megbrain/imperative/op_def.h"
namespace mgb::imperative {
struct BatchSendRecvOp final : OpDefImplBase<BatchSendRecvOp> {
SmallVector<std::shared_ptr<OpDef>> op_list;
BatchSendRecvOp() = default;
BatchSendRecvOp(SmallVector<std::shared_ptr<OpDef>> op_list) : op_list{op_list} {}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
} // namespace mgb::imperative
\ No newline at end of file
/**
* \file imperative/src/include/megbrain/imperative/transformations/group_comm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb::imperative {
class PlaceholderValue final : public ObjectValue<PlaceholderValue> {
public:
std::string to_string() const override { return ssprintf("PlaceholderValue"); }
void clear() override {}
};
class GroupCommTransformation final : public Transformation {
private:
SmallVector<ValueRef> send_inputs;
std::vector<PlaceholderValue::weak_ref_t> recv_tensors;
SmallVector<std::shared_ptr<OpDef>> record_ops;
ObjectType<PlaceholderValue> m_value_type{"PlaceholderValue"};
public:
GroupCommTransformation() = default;
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRefList execute_batch_op();
ValueRef unwrap(ValueRef value) override { return value; }
std::string name() const override { return "GroupCommTransformation"; }
void on_unregister() noexcept override;
~GroupCommTransformation();
};
} // namespace mgb::imperative
#include "./helper.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/mm_handler.h"
......@@ -47,7 +48,4 @@ TEST(TestImperative, IORemote) {
t0.join();
t1.join();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// ./imperative_test --gtest_filter TestIORemote
......@@ -151,6 +151,28 @@ void GroupManager::bcast_addr(
}
}
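// Rendezvous for the NCCL unique id: the root rank publishes the id under `key`,
// every rank blocks until all `size` ranks have arrived, then reads it back; the
// last rank to leave clears the per-key bookkeeping.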
void GroupManager::bcast_nccluniqueid(
const std::string& key, std::string& id, uint32_t size, uint32_t rank,
uint32_t root) {
std::unique_lock<std::mutex> lk{m_key2nccl_id_mtx};
if (rank == root) {
m_key2nccl_id[key] = id;
}
m_key2nccl_id_size[key]++;
if (m_key2nccl_id_size[key] == size) {
m_key2nccl_id_flag[key] = true;
m_bcast_cv.notify_all();
} else {
m_bcast_cv.wait(lk, [&] { return m_key2nccl_id_flag.count(key) > 0; });
}
id = m_key2nccl_id[key];
m_key2nccl_id_size[key]--;
if (m_key2nccl_id_size[key] == 0) {
m_key2nccl_id.erase(key);
m_key2nccl_id_flag.erase(key);
}
}
void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) {
auto&& group = get_group(key);
group.set_output_shape(key, shape);
......
......@@ -67,6 +67,15 @@ void MegRayCommBuilder::emplace(
m_megray_comms.emplace(hash, comm);
}
void MegRayCommBuilder::remove(
uint64_t hash, std::shared_ptr<MegRay::Communicator> comm) {
std::unique_lock<std::mutex> lk(m_map_mtx);
auto it = m_megray_comms.find(hash);
if (it != m_megray_comms.end()) {
m_megray_comms.erase(hash);
}
}
std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
uint64_t hash, std::string key, uint32_t size, uint32_t rank,
MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) {
......@@ -104,5 +113,3 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr;
std::mutex MegRayCommBuilder::sm_instance_mtx;
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -45,6 +45,7 @@ public:
RUNSERVER(get_output_shape);
RUNSERVER(bcast_addr);
RUNSERVER(group_barrier);
RUNSERVER(bcast_nccluniqueid);
mgb_assert(false, "invalid rpc request");
}
......@@ -53,6 +54,7 @@ private:
void set_output_shape(void* input_ptr, size_t input_len, std::string* output);
void get_output_shape(void* input_ptr, size_t input_len, std::string* output);
void bcast_addr(void* input_ptr, size_t input_len, std::string* output);
void bcast_nccluniqueid(void* input_ptr, size_t input_len, std::string* output);
void group_barrier(void* input_ptr, size_t input_len, std::string* output);
private:
......@@ -116,6 +118,15 @@ void GroupServerProxy::bcast_addr(
rsp.SerializeToString(output);
}
void GroupServerProxy::bcast_nccluniqueid(
void* input_ptr, size_t input_len, std::string* output) {
INFO_INIT(mm_handler, BcastNcclUniqueId);
std::string id = req.id();
m_mgr.bcast_nccluniqueid(req.key(), id, req.size(), req.rank(), req.root());
rsp.set_id(id);
rsp.SerializeToString(output);
}
void GroupServerProxy::group_barrier(
void* input_ptr, size_t input_len, std::string* output) {
INFO_INIT(mm_handler, GroupBarrier);
......@@ -201,6 +212,19 @@ void GroupClientProxy::bcast_addr(
port = rsp.port();
}
void GroupClientProxy::bcast_nccluniqueid(
const std::string& key, std::string& id, uint32_t size, uint32_t rank,
uint32_t root) {
INFO_INIT(mm_handler, bcast_nccluniqueid, BcastNcclUniqueId);
req.set_id(id.data(), id.size());
req.set_key(key.data(), key.size());
req.set_size(size);
req.set_rank(rank);
req.set_root(root);
SOLVE_REQUEST(func_name, req, rsp);
id = rsp.id();
}
uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
INFO_INIT(mm_handler, group_barrier, GroupBarrier);
req.set_size(size);
......@@ -209,6 +233,40 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
return rsp.size();
}
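// Process-wide cache of the batch send/recv communicator, keyed by
// "init_all_cards". init() creates a MegRay NCCL communicator and distributes the
// NCCL unique id to all ranks through GroupClientProxy::bcast_nccluniqueid().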
std::shared_ptr<MegRay::Communicator> BatchSendRecvHelper::get(std::string&& key) {
    auto it = megray_comm_cache.find(key);
    if (it != megray_comm_cache.end()) {
        return it->second;
    } else {
        return nullptr;
    }
}
std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>>
BatchSendRecvHelper::megray_comm_cache{};
bool BatchSendRecvHelper::init(
int nranks, int rank, std::string ip, int port, int root) {
auto megray_comm =
MegRay::get_communicator(nranks, rank, MegRay::Backend::MEGRAY_NCCL);
auto group_client =
std::make_shared<opr::GroupClientProxy>(ssprintf("%s:%d", ip.data(), port));
auto cb = [=](char* nccl_buffer, size_t len) {
std::string id;
id.resize(128);
if (rank == root) {
memcpy(id.data(), nccl_buffer, len);
}
group_client->bcast_nccluniqueid("init_all_cards", id, nranks, rank, root);
if (rank != root) {
memcpy(nccl_buffer, id.data(), len);
}
};
megray_comm->init(cb);
return megray_comm_cache.insert({std::string("init_all_cards"), megray_comm})
.second;
}
#undef INFO_INIT
#undef SOLVE_REQUEST
......
......@@ -77,6 +77,11 @@ public:
std::string& master_ip, int& port, const std::string& key, uint32_t size,
uint32_t rank, uint32_t root);
//! bcast uid
void bcast_nccluniqueid(
const std::string& key, std::string& id, uint32_t size, uint32_t rank,
uint32_t root);
//! Set output shape of this key
void set_output_shape(const std::string& key, const TensorShape& shape);
......@@ -101,6 +106,12 @@ private:
std::mutex m_key2addr_mtx;
std::condition_variable m_bcast_cv;
//! key -> ncclid
std::unordered_map<std::string, std::string> m_key2nccl_id;
std::unordered_map<std::string, uint32_t> m_key2nccl_id_size;
std::unordered_map<std::string, bool> m_key2nccl_id_flag;
std::mutex m_key2nccl_id_mtx;
//! barrier
uint32_t m_barrier_size;
std::set<uint32_t> m_barrier_set;
......@@ -128,6 +139,10 @@ public:
std::string& master_ip, int& port, const std::string& key, uint32_t size,
uint32_t rank, uint32_t root) = 0;
virtual void bcast_nccluniqueid(
const std::string& key, std::string& id, uint32_t size, uint32_t rank,
uint32_t root) = 0;
virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0;
virtual TensorShape get_output_shape(const std::string& key) = 0;
......
......@@ -23,6 +23,7 @@ class MegRayCommBuilder {
private:
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
void remove(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms;
std::mutex m_map_mtx;
......
......@@ -39,6 +39,10 @@ public:
std::string& master_ip, int& port, const std::string& key, uint32_t size,
uint32_t rank, uint32_t root) override;
void bcast_nccluniqueid(
const std::string& key, std::string& id, uint32_t size, uint32_t rank,
uint32_t root) override;
void set_output_shape(const std::string& key, const TensorShape& shape) override;
TensorShape get_output_shape(const std::string& key) override;
......@@ -52,6 +56,34 @@ private:
void* m_stub;
};
template <typename T>
class ProcessGlobal { // thread safe
public:
template <class... Args>
static std::shared_ptr<T>& getInstance(Args&&... args) {
static auto instance = std::make_shared<T>(std::forward<Args>(args)...);
return instance;
}
protected:
template <class... Args>
ProcessGlobal(Args&&... args);
ProcessGlobal() = default;
public:
ProcessGlobal(ProcessGlobal const&) = delete;
void operator=(ProcessGlobal const&) = delete;
};
class BatchSendRecvHelper : public ProcessGlobal<BatchSendRecvHelper> {
static std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>>
megray_comm_cache;
public:
std::shared_ptr<MegRay::Communicator> get(std::string&&);
bool init(int nranks, int rank, std::string ip, int port, int root);
};
/* ======================== ZmqRpcServerMgr ========================== */
int create_zmqrpc_server(const std::string& server_addr, int port);
......@@ -60,5 +92,3 @@ int create_zmqrpc_server(const std::string& server_addr, int port);
} // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -30,6 +30,18 @@ message BcastAddrResponse {
int32 port = 2;
}
message BcastNcclUniqueIdRequest {
    string key = 1;
    bytes id = 2;
    uint32 size = 3;
    uint32 rank = 4;
    uint32 root = 5;
}

message BcastNcclUniqueIdResponse {
    bytes id = 1;
}
message SetOutputShapeRequest {
string key = 1;
TensorShape shape = 2;
......
......@@ -26,6 +26,12 @@ public:
return m_mgr.bcast_addr(master_ip, port, key, size, rank, root);
}
void bcast_nccluniqueid(
const std::string& key, std::string& id, uint32_t size, uint32_t rank,
uint32_t root) override {
return m_mgr.bcast_nccluniqueid(key, id, size, rank, root);
}
void set_output_shape(const std::string& key, const TensorShape& shape) override {
m_mgr.set_output_shape(key, shape);
}
......