Commit c7e6c658 authored by Megvii Engine Team, committed by Xu Xinran

refactor(mge/distribute): use is_root (and rank) instead of rank and root in collective comm

GitOrigin-RevId: dccdb715533576db97bbc21fe61e640141c1e8b6
Parent ff308e3b
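In short: collective operators now take an is_root flag (plus an optional rank where one is still needed) instead of explicit rank and root arguments. A minimal usage sketch of the change, assuming the process group has already been initialized as in the tests further down; inp and nr_ranks are placeholders:

import megengine.distributed as dist

rank = dist.get_rank()
# before this commit: dist.functional.reduce_sum(inp, "sum_key", nr_ranks, rank=rank, root=0)
# after this commit the caller only states whether it is the root:
out = dist.functional.reduce_sum(inp, "sum_key", nr_ranks, is_root=(rank == 0))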
......@@ -26,25 +26,17 @@ def reduce_sum(
tensor: Tensor,
key: str,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
is_root: Optional[bool] = None,
) -> Tensor:
"""Create reduce_sum operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this is a root node
"""
return _collective_comm(
tensor,
key,
CollParam.Mode.REDUCE_SUM,
nr_ranks,
rank,
root,
device=tensor.device,
tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
)
......@@ -52,24 +44,21 @@ def broadcast(
tensor: Tensor,
key: str,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
is_root: Optional[bool] = None,
) -> Tensor:
"""Create broadcast operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this is a root node
"""
if key is None:
key = tensor._symvar.name
if is_root is None:
is_root = get_rank() == 0
if rank is None:
rank = get_rank()
if rank == root:
if is_root:
inp = tensor
else:
inp = tensor._symvar.owner_graph
......@@ -79,8 +68,7 @@ def broadcast(
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
is_root,
dtype=tensor.dtype,
device=tensor.device,
)
......@@ -94,9 +82,9 @@ def all_gather(
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param rank: rank of this node
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank, 0)
return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)
def reduce_scatter_sum(
......@@ -107,69 +95,58 @@ def reduce_scatter_sum(
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param rank: rank of this node
"""
return _collective_comm(
tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank
tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
)
def all_reduce_sum(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
"""Create all_reduce_sum operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks, rank)
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)
def all_reduce_max(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
"""Create all_reduce_max operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks, rank)
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)
def all_reduce_min(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
"""Create all_reduce_min operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks, rank)
return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)
def bcast_param(
inp: Union[Buffer, Parameter],
key: str,
nr_ranks: Optional[int] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
is_root: Optional[bool] = None,
) -> None:
"""Broadcast parameters among devices
:param inp: input Buffer or Parameter to be synchronized
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this is a root node
"""
if not is_distributed():
return
assert isinstance(inp, (Buffer, Parameter))
bcast_res = broadcast(inp, key, nr_ranks, rank, root)
bcast_res = broadcast(inp, key, nr_ranks, is_root)
add_update(inp, bcast_res, alpha=0)
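A hedged usage sketch of the updated bcast_param (not part of the diff), mirroring what the optimizer change below does; param stands for any Buffer or Parameter and the process group is assumed to be initialized:

from megengine.distributed import bcast_param, get_rank
from megengine.distributed.util import get_group_id

bcast_param(param, "bcast_param_" + str(get_group_id()), is_root=(get_rank() == 0))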
......@@ -19,8 +19,8 @@ def collective_comm_symvar(
key: str,
op: CollParam.Mode,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
rank: Optional[int] = None,
root: Optional[int] = 0,
dtype: Optional[type] = None,
device: Optional[mgb.CompNode] = None,
comp_graph: Optional[mgb.CompGraph] = None,
......@@ -31,8 +31,7 @@ def collective_comm_symvar(
:param key: unique identifier for collective communication
:param op: mode of collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of the current process, use util.get_rank() as default
:param root: rank of root node, use 0 as default
:param is_root: whether this node is the root node
:param dtype: output data type, use dtype of inp as default
:param device: output comp node, use comp node of inp as default
:param comp_graph: output comp graph, use comp graph of inp as default
......@@ -41,8 +40,8 @@ def collective_comm_symvar(
inp,
key=str(key),
nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
rank=rank if rank is not None else get_rank(),
root=root,
is_root=is_root if is_root is not None else (get_rank() == 0),
rank=rank if rank is not None else -1,
server_addr=get_master_ip(),
port=get_master_port(),
param=CollParam(mode=op),
......
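The wrapper above resolves defaults as follows: is_root falls back to "am I rank 0", and an unspecified rank is forwarded as -1 so that the group server assigns one. A small illustrative restatement (assumption only; get_rank and get_world_size are the helpers imported elsewhere in this diff):

from megengine.distributed import get_rank, get_world_size

def resolve_defaults(nr_ranks=None, is_root=None, rank=None):
    nr_devices = nr_ranks if nr_ranks is not None else get_world_size()
    is_root = is_root if is_root is not None else (get_rank() == 0)
    rank = rank if rank is not None else -1
    return nr_devices, is_root, rank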
......@@ -17,7 +17,13 @@ import numpy as np
from .._internal.config import opr_priority_scope
from ..core import Buffer, Parameter, Tensor, TensorDict
from ..core.graph import get_default_graph
from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed
from ..distributed import (
all_reduce_sum,
bcast_param,
get_rank,
get_world_size,
is_distributed,
)
from ..distributed.util import get_group_id
from ..functional import add_update
from ..functional import grad as grad_func
......@@ -222,7 +228,11 @@ class Optimizer(metaclass=ABCMeta):
def bcast_param(self):
for group in self.param_groups:
for param in group["params"]:
bcast_param(param, "bcast_param_" + str(get_group_id()))
bcast_param(
param,
"bcast_param_" + str(get_group_id()),
is_root=(get_rank() == 0),
)
def state_dict(self) -> Dict:
r"""Export the optimizer state.
......
......@@ -93,10 +93,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
}
SymbolVar _Opr::collective_comm_with_input(
SymbolVar inpvar, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype,
SymbolVar inpvar, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
SymbolVarArray inputs(1, inpvar);
......@@ -111,15 +110,15 @@ SymbolVar _Opr::collective_comm_with_input(
if (dtype != Py_None) {
_dtype = npy::dtype_np2mgb(dtype);
}
return CollectiveComm::make(inputs, graph, key, nr_devices, rank, root, group_mgr,
dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0];
return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank,
group_mgr, dev_buffer_arr, param, _dtype,
backend, config, disable.get_val())[0];
}
SymbolVar _Opr::collective_comm_without_input(
CompGraph& cg, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype,
CompGraph& cg, const std::string& key, const size_t nr_devices,
const bool is_root, const int rank, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
SymbolVarArray inputs;
......@@ -134,8 +133,9 @@ SymbolVar _Opr::collective_comm_without_input(
if (dtype != Py_None) {
_dtype = npy::dtype_np2mgb(dtype);
}
return CollectiveComm::make(inputs, &graph, key, nr_devices, rank, root, group_mgr,
dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0];
return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank,
group_mgr, dev_buffer_arr, param, _dtype,
backend, config, disable.get_val())[0];
}
#else
......@@ -172,7 +172,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
SymbolVar _Opr::collective_comm_with_input(
SymbolVar inpvar, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const size_t nr_devices, const bool is_root, const int rank,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
......@@ -181,7 +181,7 @@ SymbolVar _Opr::collective_comm_with_input(
SymbolVar _Opr::collective_comm_without_input(
CompGraph& cg, const std::string& key,
const size_t nr_devices, const uint32_t rank, const uint32_t root,
const size_t nr_devices, const bool is_root, const int rank,
const std::string& server_addr, const int port, PyObject* params,
PyObject* dtype, const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable) {
......
......@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port,
static SymbolVar collective_comm_with_input(
SymbolVar inpvar, const std::string& key, const size_t nr_devices,
const uint32_t rank, const uint32_t root, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable);
const bool is_root, const int rank, const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype, const std::string& backend,
SharedND* output_buf, const OperatorNodeConfig& config,
const SharedScalar& disable);
static SymbolVar collective_comm_without_input(
CompGraph& graph, const std::string& key, const size_t nr_devices,
const uint32_t rank, const uint32_t root, const std::string& server_addr,
const int port, PyObject* params, PyObject* dtype,
const std::string& backend, SharedND* output_buf,
const OperatorNodeConfig& config, const SharedScalar& disable);
const bool is_root, const int rank, const std::string& server_addr, const int port,
PyObject* params, PyObject* dtype, const std::string& backend,
SharedND* output_buf, const OperatorNodeConfig& config,
const SharedScalar& disable);
// misc
static SymbolVarArray extern_c_opr_placeholder(
......
......@@ -102,7 +102,7 @@ def test_all_gather():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_gather(inp, "x")
output = dist.functional.all_gather(inp, "x", rank=rank)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......@@ -135,7 +135,7 @@ def test_reduce_scatter_sum():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.reduce_scatter_sum(inp, "x")
output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
......
......@@ -368,8 +368,8 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
CollectiveComm::CollectiveComm(
VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
......@@ -380,9 +380,9 @@ CollectiveComm::CollectiveComm(
m_backend(backend),
m_group_client{std::move(group_client)},
m_nr_devices(nr_devices),
m_is_root(is_root),
m_rank(rank),
m_key(key),
m_root(root),
m_dev_buffers(dev_buffer_arr),
m_disable{disable} {
for (auto i : inputs) {
......@@ -422,28 +422,28 @@ CollectiveComm::CollectiveComm(
SymbolVarArray CollectiveComm::make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable) {
SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
nullptr);
return make(inputs, graph, key, nr_devices, rank, root, group_client,
return make(inputs, graph, key, nr_devices, is_root, rank, group_client,
dev_buffer_arr, param, dtype, backend, config);
}
SymbolVarArray CollectiveComm::make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const Param& param, const DType& dtype, const std::string& backend,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable) {
auto inpvars = cg::to_var_node_array(inputs);
auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
inpvars, graph, key, nr_devices, rank, root, std::move(group_client),
inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client),
param, dtype, backend, dev_buffer_arr, config, disable));
mgb_assert(!opr->output().empty());
return cg::to_symbol_var_array(opr->output());
......@@ -452,11 +452,14 @@ SymbolVarArray CollectiveComm::make(
void CollectiveComm::opr_register() {
if (m_init)
return;
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
.cuda_env();
auto&& comp_node = output(0)->comp_node();
auto reg_info = m_group_client->opr_register(
m_key, m_nr_devices, m_is_root, m_rank,
comp_node.get_uid());
auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank,
reinterpret_cast<uintptr_t>(cuda_env.stream));
m_rank = reg_info.rank;
m_root = reg_info.root_rank;
MegRayCommunicatorBuilder* builder;
......@@ -468,7 +471,7 @@ void CollectiveComm::opr_register() {
}
m_megray_comm = builder->get_megray_comm(
hash, m_key, m_nr_devices, m_rank,
reg_info.hash, m_key, m_nr_devices, m_rank,
get_megray_backend(m_backend), m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
......@@ -606,8 +609,8 @@ VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const {
}
auto gvar = CollectiveComm::make(
og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_rank, m_root,
m_group_client, mode, m_dtype, m_backend,
og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root,
m_rank, m_group_client, mode, m_dtype, m_backend,
OperatorNodeConfig{}.comp_node_arr(cn_arr));
if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) {
......@@ -733,11 +736,11 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
return opr::CollectiveComm::make(to_symbol_var_array(inputs),
ctx.owner_graph(opr_, inputs), opr.key(),
opr.nr_devices(), opr.rank(), opr.root(),
opr.group_client(), opr.dev_buffers(),
opr.param(), opr.dtype(), opr.backend(), config)[0]
return opr::CollectiveComm::make(
to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
opr.group_client(), opr.dev_buffers(), opr.param(),
opr.dtype(), opr.backend(), config)[0]
.node()
->owner_opr();
}
......
......@@ -6,8 +6,8 @@ decl_raw_opr(
'to the same NCCL operation.', 'str'),
Doc('nr_devices', 'Total number of devices involved in the NCCL '
'operation to which this operator belongs.', 'int'),
Doc('rank', 'Rank of this operator', 'int'),
Doc('root', 'root rank of broadcast or reduce operation'),
Doc('is_root', 'whether this node is the root node', 'bool'),
Doc('rank', 'rank of this node; if it is -1, one will be generated', 'int'),
Doc('server_addr', 'rpc server ip address'),
Doc('port', 'server rpc listening port'),
Doc('param', 'The only component of *param* is *mode*, which refers to '
......@@ -28,12 +28,12 @@ decl_raw_opr(
body = [
'if isinstance(input, _mgb.SymbolVar):',
(' output = _mgb._Opr.collective_comm_with_input(input, key, '
'nr_devices, rank, root, server_addr, port, '
'nr_devices, is_root, rank, server_addr, port, '
'[param.serialize()], dtype, backend, output_buffer, config, disable)'),
'else:',
' assert isinstance(input, _mgb.CompGraph)',
(' output = _mgb._Opr.collective_comm_without_input(input, key, '
'nr_devices, rank, root, server_addr, port, '
'nr_devices, is_root, rank, server_addr, port, '
'[param.serialize()], dtype, backend, output_buffer, config, disable)')
],
desc = ('collective communication between multiple CompNodes on multiple '
......
......@@ -16,16 +16,60 @@ using namespace opr;
/* ================= GroupInfo ================= */
void GroupInfo::sort_opr_infos() {
auto cmp = [](const GroupInfo::OprInfo& a, const GroupInfo::OprInfo& b) {
return a.comp_node_hash < b.comp_node_hash;
};
std::sort(m_opr_infos.begin(), m_opr_infos.end(), cmp);
}
void GroupInfo::gen_infos_from_opr_infos() {
// generate rank
bool rank_assigned = true;
for (auto& opr_info:m_opr_infos) {
    if(opr_info.rank < 0) {
        rank_assigned = false;
        break;
    }
}
if (!rank_assigned) {
for (size_t i = 0; i < m_opr_infos.size(); i++) {
m_opr_infos[i].rank = i;
m_rank_map.insert({m_opr_infos[i].comp_node_hash, i});
}
} else {
for (size_t i = 0; i < m_opr_infos.size(); i++) {
m_rank_map.insert(
{m_opr_infos[i].comp_node_hash, m_opr_infos[i].rank});
}
}
// generate root rank
for (auto& opr_info:m_opr_infos) {
if (opr_info.is_root) {
m_root_rank = opr_info.rank;
break;
}
}
// generate group hash
auto xxhash = XXHash{};
for (auto&& opr_info : m_opr_infos) {
xxhash.update(&opr_info.comp_node_hash, sizeof(uint64_t))
.update(&opr_info.rank, sizeof(int));
}
m_hash = xxhash.digest();
}
void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices,
uint32_t rank, uintptr_t stream) {
bool is_root, int rank, uint64_t comp_node_hash) {
std::unique_lock<std::mutex> lk{m_group_mtx};
if (m_nr_expected_devs == 0) {
m_nr_expected_devs = nr_expected_devices;
} else {
mgb_assert(m_nr_expected_devs == nr_expected_devices);
}
OprInfo opr_info = {rank, stream};
m_opr_infos.push_back(std::move(opr_info));
m_opr_infos.push_back({comp_node_hash, is_root, rank});
m_nr_registered_devs++;
m_count++;
if (m_nr_registered_devs > nr_expected_devices) {
......@@ -38,6 +82,8 @@ void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices,
key.c_str(), nr_expected_devices, m_nr_registered_devs);
}
if (m_nr_expected_devs == m_nr_registered_devs) {
sort_opr_infos();
gen_infos_from_opr_infos();
m_register_cv.notify_all();
} else {
m_register_cv.wait(lk,
......@@ -66,6 +112,8 @@ void GroupInfo::clear() {
m_count--;
if (m_count == 0) {
m_opr_infos.clear();
m_rank_map.clear();
m_root_rank = -1;
m_nr_expected_devs = 0;
m_nr_registered_devs = 0;
m_output_shape.invalidate();
......@@ -77,14 +125,18 @@ void GroupInfo::clear() {
/* ================= GroupManager ================= */
uint64_t GroupManager::opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) {
GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) {
GroupManager::RegisterInfo ret{0, 0, 0};
auto&& group = get_group(key);
group.add_opr(key, nr_devices, rank, stream);
auto&& opr_infos = group.opr_infos();
uint64_t hash = get_hash_key(opr_infos, rank);
group.add_opr(key, nr_devices, is_root, rank, comp_node_hash);
ret.rank = group.get_rank(comp_node_hash);
ret.root_rank = group.get_root_rank();
ret.hash = group.get_group_hash() + ret.rank;
group.clear();
return hash;
return ret;
}
std::vector<std::string> GroupManager::gather_uid(const std::string& uid,
......@@ -126,22 +178,6 @@ GroupInfo& GroupManager::get_group(const std::string& key) {
return m_key2group_info[key];
}
uint64_t GroupManager::get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos,
uint32_t rank) {
auto cmp = [](const GroupInfo::OprInfo& lhs, const GroupInfo::OprInfo& rhs) {
return lhs.rank < rhs.rank;
};
auto infos = _infos;
std::sort(infos.begin(), infos.end(), cmp);
auto xxhash = XXHash{};
for (auto&& opr_info : infos) {
xxhash.update(&opr_info.rank, sizeof(uint32_t))
.update(&opr_info.stream, sizeof(uintptr_t));
}
xxhash.update(&rank, sizeof(uint32_t));
return xxhash.digest();
};
uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) {
std::unique_lock<std::mutex> lk{m_barrier_mtx};
if (m_barrier_set.empty()) {
......
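For reference, the rank-assignment rule added to GroupInfo above, restated as a standalone Python sketch (illustration only, not part of the codebase): operator infos are sorted by comp-node hash; if any registered rank is negative, ranks are reassigned in sort order; the root rank is the rank of the entry flagged is_root (or -1 if none is).

def assign_ranks(opr_infos):
    # opr_infos: iterable of (comp_node_hash, is_root, rank) tuples
    infos = sorted(opr_infos, key=lambda info: info[0])
    if any(rank < 0 for _, _, rank in infos):
        infos = [(h, r, i) for i, (h, r, _) in enumerate(infos)]
    root_rank = next((rank for _, is_root, rank in infos if is_root), -1)
    return infos, root_rank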
......@@ -48,12 +48,11 @@ SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var,
void RemoteSend::scn_do_execute() {
if (!m_init) {
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
.cuda_env();
auto&& comp_node = output(0)->comp_node();
// rank 0 for RemoteSend
auto hash = m_group_client->opr_register(m_peer.key, 2, 0,
reinterpret_cast<uintptr_t>(cuda_env.stream));
auto reg_info = m_group_client->opr_register(m_peer.key, 2, false, 0,
                                             comp_node.get_uid());
auto megray_comm_builder =
owner_graph()
......@@ -62,7 +61,7 @@ void RemoteSend::scn_do_execute() {
.get_user_data_or_create<MegRayCommunicatorBuilder>();
m_megray_comm = megray_comm_builder->get_megray_comm(
hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
m_init = true;
}
......@@ -152,12 +151,12 @@ SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph,
void RemoteRecv::scn_do_execute() {
if (!m_init) {
auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node())
.cuda_env();
auto&& comp_node = output(0)->comp_node();
// rank 1 for RemoteRecv
auto hash = m_group_client->opr_register(m_peer.key, 2, 1,
reinterpret_cast<uintptr_t>(cuda_env.stream));
auto reg_info = m_group_client->opr_register(
m_peer.key, 2, false, 1,
comp_node.get_uid());
auto megray_comm_builder =
owner_graph()
......@@ -166,7 +165,7 @@ void RemoteRecv::scn_do_execute() {
.get_user_data_or_create<MegRayCommunicatorBuilder>();
m_megray_comm = megray_comm_builder->get_megray_comm(
hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
m_init = true;
}
......
......@@ -68,9 +68,11 @@ private:
void GroupServerProxy::opr_register(void* input_ptr, size_t input_len,
std::string *output) {
INFO_INIT(mm_handler, OprRegister);
uint64_t hash = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
req.rank(), req.stream());
rsp.set_hash(hash);
auto ret = m_mgr.opr_register(req.key(), req.nr_expected_devices(),
req.is_root(), req.rank(), req.comp_node_hash());
rsp.set_hash(ret.hash);
rsp.set_rank(ret.rank);
rsp.set_root_rank(ret.root_rank);
rsp.SerializeToString(output);
}
......@@ -122,11 +124,11 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len,
/* ======================== GroupClientProxy ========================== */
#define INFO_INIT(space, f_name, name) \
#define INFO_INIT(space, f_name, name) \
using Request = space::name##Request; \
using Response = space::name##Response; \
std::string func_name = #f_name; \
Request req; \
std::string func_name = #f_name; \
Request req; \
Response rsp;
#define SOLVE_REQUEST(name, req, rsp) \
......@@ -145,15 +147,18 @@ GroupClientProxy::GroupClientProxy(const std::string& server_addr)
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
}
uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) {
GroupManager::RegisterInfo GroupClientProxy::opr_register(
const std::string& key, size_t nr_devices, bool is_root, int rank,
uint64_t comp_node_hash) {
INFO_INIT(mm_handler, opr_register, OprRegister)
req.set_key(key);
req.set_is_root(is_root);
req.set_rank(rank);
req.set_stream(stream);
req.set_comp_node_hash(comp_node_hash);
req.set_nr_expected_devices(nr_devices);
SOLVE_REQUEST(func_name, req, rsp);
return rsp.hash();
GroupManager::RegisterInfo ret{rsp.hash(), rsp.rank(), rsp.root_rank()};
return ret;
}
void GroupClientProxy::set_output_shape(const std::string& key,
......
......@@ -26,18 +26,19 @@ public:
using Param = megdnn::param::CollectiveComm;
CollectiveComm(VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable);
CollectiveComm(
VarNodeArray inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype, const std::string& backend,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const OperatorNodeConfig& config,
const std::shared_ptr<DTypeScalar>& disable);
static SymbolVarArray make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const std::string& key, const size_t nr_devices, const bool is_root,
const int rank, std::shared_ptr<GroupClient> group_client,
const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
......@@ -45,15 +46,16 @@ public:
const std::shared_ptr<DTypeScalar>& disable =
std::make_shared<DTypeScalar>(0));
static SymbolVarArray make(
const SymbolVarArray& inputs, ComputingGraph* const graph,
const std::string& key, const size_t nr_devices, const uint32_t rank,
const uint32_t root, std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
const OperatorNodeConfig& config = {},
const std::shared_ptr<DTypeScalar>& disable =
std::make_shared<DTypeScalar>(0));
static SymbolVarArray make(const SymbolVarArray& inputs,
ComputingGraph* const graph,
const std::string& key, const size_t nr_devices,
const bool is_root, const int rank,
std::shared_ptr<GroupClient> group_client,
const Param& param, const DType& dtype = {},
const std::string& backend = "nccl",
const OperatorNodeConfig& config = {},
const std::shared_ptr<DTypeScalar>& disable =
std::make_shared<DTypeScalar>(0));
const Param& param() const { return m_param; }
const DType& dtype() const { return m_dtype; }
......@@ -67,9 +69,9 @@ public:
return m_dev_buffers;
}
uint32_t rank() const { return m_rank; }
uint32_t root() const { return m_root; }
bool is_root() const { return m_rank == m_root; }
int rank() const { return m_rank; }
int root() const { return m_root; }
bool is_root() const { return m_is_root; }
//! The key that identifies an NCCL clique.
//! Operators with same keys belong to the same clique.
......@@ -108,12 +110,13 @@ private:
std::shared_ptr<GroupClient> m_group_client;
size_t m_nr_devices = 0;
uint32_t m_rank;
bool m_is_root;
int m_rank;
std::string m_key;
//! XXHash generated from m_key
size_t m_hash;
//! root of BROADCAST and REDUCE operation
uint32_t m_root;
int m_root;
//! rank of root of BROADCAST and REDUCE operation
Maybe<TensorShape> m_broadcast_output_shape = None;
// Whether shape infer is enabled. This is only used by BROADCAST operation,
......
......@@ -24,12 +24,13 @@ namespace opr {
class GroupInfo {
public:
struct OprInfo {
uint32_t rank;
uintptr_t stream;
uint64_t comp_node_hash;
bool is_root;
int rank;
};
void add_opr(const std::string& key, size_t nr_expected_devices,
uint32_t graph_id, uintptr_t stream);
bool is_root, int rank, uint64_t comp_node_hash);
void set_output_shape(const std::string& key, const TensorShape& shape);
......@@ -37,15 +38,25 @@ class GroupInfo {
void clear();
const std::vector<OprInfo>& opr_infos() const {return m_opr_infos; }
const std::vector<OprInfo>& opr_infos() const { return m_opr_infos; }
int get_root_rank() const { return m_root_rank; }
int get_rank(uint64_t hash) const { return m_rank_map.at(hash); }
uint64_t get_group_hash() const { return m_hash; }
private:
void sort_opr_infos();
void gen_infos_from_opr_infos();
std::vector<OprInfo> m_opr_infos;
std::unordered_map<uint64_t, int> m_rank_map;
uint64_t m_hash;
uint32_t m_nr_registered_devs;
uint32_t m_nr_expected_devs;
Maybe<TensorShape> m_output_shape;
uint32_t m_count = 0;
int m_root_rank = -1;
std::mutex m_group_mtx;
std::condition_variable m_register_cv;
std::condition_variable m_clear_cv;
......@@ -61,10 +72,16 @@ class GroupManager {
public:
~GroupManager() = default;
struct RegisterInfo
{
uint64_t hash;
int rank, root_rank;
};
//! register opr info with the server; returns the deduplicated hash together with the assigned rank and root rank
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream);
RegisterInfo opr_register(const std::string& key, size_t nr_devices,
bool is_root, int rank, uint64_t comp_node_hash);
//! gather uids from all ranks
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank);
......@@ -80,9 +97,6 @@ class GroupManager {
private:
GroupInfo& get_group(const std::string& key);
uint64_t get_hash_key(const std::vector<GroupInfo::OprInfo>& _infos,
uint32_t rank);
//! key -> group info.
std::unordered_map<std::string, GroupInfo> m_key2group_info;
......@@ -112,9 +126,11 @@ class GroupClient {
virtual ~GroupClient() = default;
public:
virtual uint64_t opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) = 0;
virtual GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) = 0;
virtual std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) = 0;
......
......@@ -14,6 +14,7 @@
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
#include "megbrain/opr/group_manager.h"
using namespace mgb;
using namespace opr;
......@@ -31,8 +32,10 @@ public:
GroupClientProxy(const std::string& server_addr);
//! graph registration, assign graph_id to worker.
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream) override;
GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices, bool is_root,
int rank,
uint64_t comp_node_hash) override;
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) override;
......
......@@ -4,13 +4,16 @@ package mm_handler;
message OprRegisterRequest {
string key = 1;
uint32 rank = 2;
uint64 stream = 3;
uint32 nr_expected_devices = 4;
bool is_root = 2;
int32 rank = 3;
uint64 comp_node_hash = 4;
uint32 nr_expected_devices = 5;
}
message OprRegisterResponse {
uint64 hash = 1;
uint64 hash = 1;
int32 rank = 2;
int32 root_rank = 3;
}
message GatherUidRequest {
......
......@@ -45,9 +45,11 @@ class MockGroupClient final : public opr::GroupClient {
public:
~MockGroupClient() override = default;
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream) {
return m_mgr.opr_register(key, nr_devices, rank, stream);
opr::GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uintptr_t stream) {
return m_mgr.opr_register(key, nr_devices, is_root, rank, stream);
}
std::vector<std::string> gather_uid(const std::string& uid,
......@@ -94,9 +96,9 @@ TEST(TestOprCollectiveComm, AllReduce) {
auto x1c = opr::Copy::make(x1, cn1);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_reduce",
2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_reduce",
2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0];
auto y_expect = make_all_reduce_output(mode, {x0, x1});
auto func = graph->compile({make_callback_copy(y0, host_y0),
......@@ -130,7 +132,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) {
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce",
2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
};
......@@ -139,7 +141,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) {
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce",
2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
};
......@@ -192,7 +194,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) {
graph0->options().graph_opt_level = 0;
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, false, 0, client,
{Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
......@@ -211,7 +213,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) {
graph1->options().graph_opt_level = 0;
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, false, 1, client,
{Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;
......@@ -274,9 +276,9 @@ TEST(TestOprCollectiveComm, AllGather) {
auto x1c = opr::Copy::make(x1, cn1);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_gather",
2, 0, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_gather",
2, 1, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto y_expect = opr::Concat::make({x0, x1}, 0);
auto func = graph->compile({make_callback_copy(y0, host_y0),
......@@ -303,7 +305,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) {
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
......@@ -312,7 +314,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) {
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
......@@ -361,7 +363,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) {
graph0->options().graph_opt_level = 0;
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
......@@ -380,7 +382,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) {
graph1->options().graph_opt_level = 0;
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client,
{Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;
......@@ -444,9 +446,9 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) {
auto x1c = opr::Copy::make(x1, cn1);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_scatter_sum",
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_scatter_sum",
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto y_expect = make_reduce_scatter_sum_output({x0, x1});
auto func = graph->compile({make_callback_copy(y0, host_y0),
......@@ -475,7 +477,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) {
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum",
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
};
......@@ -484,7 +486,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) {
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum",
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
};
......@@ -534,7 +536,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) {
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum",
2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0);
......@@ -553,7 +555,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) {
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum",
2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;
auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1);
......@@ -616,9 +618,9 @@ TEST(TestOprCollectiveComm, ReduceSum) {
auto x1c = opr::Copy::make(x1, cn1);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_sum",
2, 0, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_sum",
2, 1, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto y_expect = x0 + x1;
auto func = graph->compile({make_callback_copy(y0, host_y0),
......@@ -644,7 +646,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) {
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
......@@ -653,7 +655,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) {
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({{y1, nullptr}});
func1->execute();
......@@ -699,7 +701,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) {
graph0->options().graph_opt_level = 0;
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
......@@ -718,7 +720,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) {
graph1->options().graph_opt_level = 0;
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client,
{Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;
......@@ -767,12 +769,12 @@ TEST(TestOprCollectiveComm, Broadcast) {
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast",
2, 0, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
auto y_dev = std::make_shared<DeviceTensorND>(DeviceTensorND()
.comp_node(cn1)
.dtype(dtype::Float32())
.resize(host_x0->shape()));
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, 1, 0,
auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, false, 1,
client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0];
auto func = graph->compile({make_callback_copy(y0, host_y0),
......@@ -797,7 +799,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) {
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client,
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
......@@ -809,7 +811,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) {
.comp_node(cn1)
.dtype(dtype::Float32())
.resize(host_x0->shape()));
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client,
{y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
......@@ -845,7 +847,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) {
graph0->options().graph_opt_level = 0;
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client,
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client,
{Mode::BROADCAST}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
......@@ -863,11 +865,11 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) {
auto graph1 = ComputingGraph::make();
graph1->options().graph_opt_level = 0;
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client,
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client,
{Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0];
auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1);
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, 1, 0, client,
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, false, 1, client,
Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0];
g.node()->owner_opr()->node_prop().attribute().priority = 1;
......
......@@ -26,11 +26,13 @@ class MockGroupClient final : public opr::GroupClient {
public:
~MockGroupClient() override = default;
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
uintptr_t stream) {
return m_mgr.opr_register(key, nr_devices, rank, stream);
opr::GroupManager::RegisterInfo opr_register(const std::string& key,
size_t nr_devices,
bool is_root, int rank,
uint64_t comp_node_hash) {
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
}
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
return m_mgr.gather_uid(uid, key, size, rank);
......