提交 328fb36f 编写于 作者: M Megvii Engine Team

feat(mgb/opr-mm): add Scatter, Gather, AllToAll oprs

GitOrigin-RevId: f75169ecd6ca43c19589408a5bc2ef68e6655fd9
上级 3f51a6a0
......@@ -40,6 +40,32 @@ def reduce_sum(
)
def gather(
tensor: Tensor,
key: str,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
rank: Optional[int] = None,
) -> Tensor:
"""Create gather operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param is_root: whether this is a root node
:param rank: rank of this node
"""
return _collective_comm(
tensor,
key,
CollParam.Mode.GATHER,
nr_ranks,
is_root,
rank,
device=tensor.device,
)
def broadcast(
tensor: Tensor,
key: str,
......@@ -74,6 +100,56 @@ def broadcast(
)
def scatter(
tensor: Tensor,
key: str,
nr_ranks: Optional[int] = None,
is_root: Optional[bool] = None,
rank: Optional[int] = None,
) -> Tensor:
"""Create scatter operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param is_root: whether this is a root node
:param rank: rank of this node
"""
if key is None:
key = tensor._symvar.name
if is_root is None:
is_root = get_rank() == 0
if is_root:
inp = tensor
else:
inp = tensor._symvar.owner_graph
return _collective_comm(
inp,
key,
CollParam.Mode.SCATTER,
nr_ranks,
is_root,
rank,
dtype=tensor.dtype,
device=tensor.device,
)
def all_to_all(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
"""Create all_to_all operator for collective communication
:param tensor: input tensor
:param key: unique identifier for collective communication
:param nr_ranks: number of ranks, use util.get_world_size() as default
:param rank: rank of this node
"""
return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank)
def all_gather(
tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
......
......@@ -61,6 +61,42 @@ def test_reduce_sum():
check(shape, backend)
@pytest.mark.isolated_distributed
def test_gather():
world_size = 2
def worker(rank, data, backend, expect, port_queue):
if not mge.is_cuda_available():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank)
if rank == 0:
assert np.allclose(output.numpy(), expect)
else:
assert np.allclose(output.numpy(), 0)
def check(shape, backend):
port_queue = mp.Queue()
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = np.concatenate((x, y))
p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
p1 = mp.Process(target=worker, args=(1, y, backend, None, port_queue))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
for shape in [(2, 3), (8, 10), (99, 77)]:
for backend in ["nccl", "ucx"]:
check(shape, backend)
@pytest.mark.isolated_distributed
def test_broadcast():
world_size = 2
......@@ -93,6 +129,76 @@ def test_broadcast():
check(shape, backend)
@pytest.mark.isolated_distributed
def test_scatter():
world_size = 2
def worker(rank, data, backend, expect, port_queue):
if not mge.is_cuda_available():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
port_queue = mp.Queue()
x = np.random.rand(*shape).astype("float32")
y = x + 1
p0 = mp.Process(
target=worker, args=(0, x, backend, x[: shape[0] // 2], port_queue)
)
p1 = mp.Process(
target=worker, args=(1, y, backend, x[shape[0] // 2 :], port_queue)
)
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
for shape in [(2, 3), (8, 10), (100, 77)]:
for backend in ["nccl", "ucx"]:
check(shape, backend)
@pytest.mark.isolated_distributed
def test_all_to_all():
world_size = 2
def worker(rank, data, backend, expect, port_queue):
if not mge.is_cuda_available():
return
_init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
inp = tensor(data)
output = dist.functional.all_to_all(inp, "x", rank=rank)
assert np.allclose(output.numpy(), expect)
def check(shape, backend):
port_queue = mp.Queue()
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
p0 = mp.Process(target=worker, args=(0, x, backend, a, port_queue))
p1 = mp.Process(target=worker, args=(1, y, backend, b, port_queue))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
for shape in [(2, 3), (8, 10), (100, 77)]:
for backend in ["nccl", "ucx"]:
check(shape, backend)
@pytest.mark.isolated_distributed
def test_all_gather():
world_size = 2
......
......@@ -25,9 +25,10 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);
#define FOREACH_MODE(cb) \
cb(ALL_REDUCE_SUM) cb(ALL_REDUCE_MAX) cb(ALL_REDUCE_MIN) cb(BROADCAST) \
cb(REDUCE_SUM) cb(ALL_GATHER) cb(REDUCE_SCATTER_SUM)
#define FOREACH_MODE(cb) \
cb(ALL_REDUCE_SUM) cb(ALL_REDUCE_MAX) cb(ALL_REDUCE_MIN) cb(BROADCAST) \
cb(REDUCE_SUM) cb(ALL_GATHER) cb(REDUCE_SCATTER_SUM) cb(GATHER) \
cb(SCATTER) cb(ALL_TO_ALL)
namespace {
......@@ -84,6 +85,9 @@ class CollectiveComm::ModeTrait {
class ALL_REDUCE_SUM;
class ALL_REDUCE_MAX;
class ALL_REDUCE_MIN;
class GATHER;
class SCATTER;
class ALL_TO_ALL;
class ReducedBasedTrait;
class AllReduceBase;
......@@ -350,6 +354,102 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
Mode grad_mode() override { return Mode::REDUCE_SUM; }
};
class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
MGB_MARK_USED_VAR(opr);
chk_shape_equal(ishp);
if (opr->is_root()) {
oshp[0] = ishp[0];
oshp[0][0] *= opr->nr_devices();
} else {
oshp[0] = TensorShape{1};
}
}
void exec(CollectiveComm* opr) override {
auto&& iv = opr->input(0)->dev_tensor();
void* recvbuf = nullptr;
if (opr->is_root()) {
recvbuf = opr->output(0)->dev_tensor().raw_ptr();
}
auto status = opr->m_megray_comm->gather(
(void*)iv.raw_ptr(), recvbuf, iv.shape().total_nr_elems(),
get_megray_dtype(iv.dtype()), opr->m_root, opr->megray_ctx());
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed");
}
Mode grad_mode() override { return Mode::SCATTER; }
};
class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
if (opr->input().size() > 0) {
add_output_var_all2all(opr);
return;
}
const auto& cns = opr->config().comp_node();
mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu", cns.size());
auto pname = get_param_name(opr->param());
opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))->comp_node(cns[0]);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
mgb_throw(MegBrainError, "SCATTER should not use get_output_var_shape");
}
void exec(CollectiveComm* opr) override {
auto&& ov = opr->output(0)->dev_tensor();
void* sendbuf = nullptr;
void* recvbuf = ov.raw_ptr();
if (opr->is_root()) {
sendbuf = opr->input(0)->dev_tensor().raw_ptr();
}
auto status = opr->m_megray_comm->scatter(
sendbuf, recvbuf, ov.shape().total_nr_elems(),
get_megray_dtype(ov.dtype()), opr->m_root, opr->megray_ctx());
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay scatter failed");
}
Mode grad_mode() override { return Mode::GATHER; }
};
class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
void add_output_var(CollectiveComm* opr,
const CompNode::UnorderedSet&) override {
add_output_var_all2all(opr);
}
void get_output_var_shape(const CollectiveComm* opr,
const TensorShapeArray& ishp,
TensorShapeArray& oshp) override {
chk_shape_equal(ishp);
oshp = ishp;
}
void exec(CollectiveComm* opr) override {
auto&& iv = opr->input(0)->dev_tensor();
auto&& ov = opr->output(0)->dev_tensor();
auto status = opr->m_megray_comm->all_to_all(
(void*)iv.raw_ptr(), (void*)ov.raw_ptr(),
iv.shape().total_nr_elems() / opr->nr_devices(),
get_megray_dtype(iv.dtype()), opr->megray_ctx());
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay all_to_all failed");
}
Mode grad_mode() override { return Mode::ALL_TO_ALL; }
};
CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
switch (mode) {
#define c(_m) \
......@@ -651,41 +751,20 @@ void CollectiveComm::init_output_dtype() {
}
void CollectiveComm::init_output_static_infer_desc() {
if (m_param.mode == Param::Mode::REDUCE_SUM) {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
auto infer_shape_from_input = [](TensorShape& dest, const InpVal& inp_val) {
dest = inp_val.val[0].shape();
return true;
};
auto infer_shape_constant = [](TensorShape& dest, const InpVal&) {
dest = TensorShape{1};
return true;
};
mgb_assert(input().size() == 1);
mgb_assert(output().size() == 1);
if (is_root()) {
mgr.register_shape_infer(output(0),
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input});
} else {
mgr.register_shape_infer(output(0),
{SourceType::CONSTANT, {}, infer_shape_constant});
}
} else if (m_param.mode == Param::Mode::BROADCAST) {
if (m_param.mode == Param::Mode::BROADCAST ||
m_param.mode == Param::Mode::SCATTER) {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
auto infer_shape_from_input = [this](TensorShape& dest, const InpVal& inp_val) {
if (!m_broadcast_output_shape.valid()) {
m_broadcast_output_shape = inp_val.val[0].shape();
m_group_client->set_output_shape(m_key, m_broadcast_output_shape.val());
}
dest = inp_val.val[0].shape();
if (m_param.mode == Param::Mode::SCATTER) {
dest[0] /= nr_devices();
}
if (!m_output_shape.valid()) {
m_output_shape = dest;
m_group_client->set_output_shape(m_key, dest);
}
return true;
};
......@@ -694,10 +773,11 @@ void CollectiveComm::init_output_static_infer_desc() {
return false;
}
if (!m_broadcast_output_shape.valid()) {
m_broadcast_output_shape = m_group_client->get_output_shape(m_key);
if (!m_output_shape.valid()) {
m_output_shape = m_group_client->get_output_shape(m_key);
}
dest = m_broadcast_output_shape.val();
dest = m_output_shape.val();
return true;
};
......
......@@ -18,6 +18,10 @@
using namespace mgb;
using namespace opr;
cudaStream_t get_stream(VarNode* var) {
return CompNodeEnv::from_comp_node(var->comp_node()).cuda_env().stream;
}
/* ===================== RemoteSend ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
......@@ -35,7 +39,6 @@ RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var,
ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::VOLATILE_CONTENT);
}
m_megray_ctx = MegRay::Context::make();
add_equivalence_component<ScalarHash<void*>>(this);
}
......@@ -56,6 +59,9 @@ void RemoteSend::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_init = true;
}
......@@ -130,7 +136,6 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
->dtype(dtype)
.add_flag(VarNode::Flag::NO_MEM_RECLAIM)
.add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
m_megray_ctx = MegRay::Context::make();
add_equivalence_component<ScalarHash<void*>>(this);
}
......@@ -154,6 +159,9 @@ void RemoteRecv::scn_do_execute() {
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
m_init = true;
}
......
......@@ -122,8 +122,9 @@ private:
//! root of BROADCAST and REDUCE operation
int m_root;
//! rank of root of BROADCAST and REDUCE operation
Maybe<TensorShape> m_broadcast_output_shape = None;
// Whether shape infer is enabled. This is only used by BROADCAST operation,
Maybe<TensorShape> m_output_shape = None;
// Whether shape infer is enabled.
// This is only used by BROADCAST and SCATTER operation,
// whose shape infer should be disabled *during* static infer phase.
bool m_enable_shape_infer = false;
......
......@@ -719,6 +719,164 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) {
MGB_ASSERT_TENSOR_EQ(*host_grad, host_out_grad1);
}
TEST(TestOprCollectiveComm, Gather) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
auto host_x0 = gen({28, 28});
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1, host_y_expect;
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0);
auto x1c = opr::Copy::make(x1, cn1);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "gather",
2, true, 0, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "gather",
2, false, 1, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0];
auto y_expect = opr::Concat::make({x0, x1}, 0);
auto func = graph->compile({make_callback_copy(y0, host_y0),
make_callback_copy(y1, host_y1),
make_callback_copy(y_expect, host_y_expect)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y0);
}
TEST(TestOprCollectiveComm, GatherMultiThread) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
auto host_x0 = gen({28, 28});
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y_expect;
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "gather", 2, true, 0, client,
{Mode::GATHER}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
};
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "gather", 2, false, 1, client,
{Mode::GATHER}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile({{y1, nullptr}});
func1->execute();
};
auto run_2 = [&]() { // check
auto graph2 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0);
auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0);
auto y_expect = opr::Concat::make({x0, x1}, 0);
auto func2 = graph2->compile({make_callback_copy(y_expect, host_y_expect)});
func2->execute();
};
std::thread t0(run_0);
std::thread t1(run_1);
std::thread t2(run_2);
t0.join();
t1.join();
t2.join();
MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y0);
}
TEST(TestOprCollectiveComm, GatherWithGrad) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
TensorShape shape({28, 28});
auto host_x0 = gen(shape);
auto host_x1 = gen(shape);
auto host_grad0 = gen(shape);
auto host_grad1 = gen(shape);
HostTensorND host_y0, host_y0_expect, host_out_grad0, host_out_grad1;
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
graph0->options().graph_opt_level = 0;
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "gather", 2, true, 0, client,
{Mode::GATHER}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0);
auto grad1 = opr::Host2DeviceCopy::make(*graph0, host_grad1, cn0);
auto grad = opr::Concat::make({grad0, grad1}, 0);
auto loss = opr::Dot::make(y0, grad);
auto g = opr::VirtualGrad::make(loss, x0);
auto func0 = graph0->compile(
{make_callback_copy(y0, host_y0),
make_callback_copy(g, host_out_grad0)});
func0->execute();
};
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
graph1->options().graph_opt_level = 0;
auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "gather", 2, false, 1, client,
{Mode::GATHER}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;
auto grad = opr::Host2DeviceCopy::make(*graph1, gen({1}), cn1);
auto loss = opr::Dot::make(y1, grad);
auto g = opr::VirtualGrad::make(loss, x1);
auto func1 = graph1->compile({{y1, nullptr}, make_callback_copy(g, host_out_grad1)});
func1->execute();
};
auto run_2 = [&]() { // check
auto graph2 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0);
auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0);
auto y0_expect = opr::Concat::make({x0, x1}, 0);
auto func2 = graph2->compile({
make_callback_copy(y0_expect, host_y0_expect)});
func2->execute();
};
std::thread t0(run_0);
std::thread t1(run_1);
std::thread t2(run_2);
t0.join();
t1.join();
t2.join();
MGB_ASSERT_TENSOR_EQ(host_y0_expect, host_y0);
MGB_ASSERT_TENSOR_EQ(*host_grad0, host_out_grad0);
MGB_ASSERT_TENSOR_EQ(*host_grad1, host_out_grad1);
}
TEST(TestOprCollectiveComm, Broadcast) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
......@@ -863,3 +1021,349 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) {
MGB_ASSERT_TENSOR_EQ(*host_x0, host_y1);
MGB_ASSERT_TENSOR_EQ(host_out_grad_expect, host_out_grad);
}
TEST(TestOprCollectiveComm, Scatter) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
auto host_x0 = gen({28, 28});
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1;
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0);
auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0);
auto x = opr::Concat::make({x0, x1}, 0);
auto y0 = opr::CollectiveComm::make({x}, graph.get(), "scatter",
2, true, 0, client, {Mode::SCATTER}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({}, graph.get(), "scatter", 2, false, 1,
client, {Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0];
auto func = graph->compile({make_callback_copy(y0, host_y0),
make_callback_copy(y1, host_y1)});
func->execute();
MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
}
TEST(TestOprCollectiveComm, ScatterMultiThread) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
auto host_x0 = gen({28, 28});
auto host_x1 = gen({28, 28});
HostTensorND host_y0, host_y1;
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto x1 = opr::Host2DeviceCopy::make(*graph0, host_x1, cn0);
auto x = opr::Concat::make({x0, x1}, 0);
auto y0 = opr::CollectiveComm::make({x}, graph0.get(), "scatter", 2, true, 0, client,
{Mode::SCATTER}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile({make_callback_copy(y0, host_y0)});
func0->execute();
};
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "scatter", 2, false, 1, client,
{Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0];
auto func1 = graph1->compile({make_callback_copy(y1, host_y1)});
func1->execute();
};
std::thread t0(run_0);
std::thread t1(run_1);
t0.join();
t1.join();
MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
}
TEST(TestOprCollectiveComm, ScatterWithGrad) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
TensorShape shape({28, 28});
auto host_x0 = gen(shape);
auto host_x1 = gen(shape);
auto host_grad0 = gen(shape);
auto host_grad1 = gen(shape);
HostTensorND host_y0, host_y1, host_out_grad, host_out_grad_expect;
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
graph0->options().graph_opt_level = 0;
auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0);
auto x1 = opr::Host2DeviceCopy::make(*graph0, host_x1, cn0);
auto x = opr::Concat::make({x0, x1}, 0);
auto y0 = opr::CollectiveComm::make({x}, graph0.get(), "scatter", 2, true, 0, client,
{Mode::SCATTER}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0);
auto loss = opr::Dot::make(y0, grad0);
auto g = opr::VirtualGrad::make(loss, x);
auto func0 = graph0->compile(
{make_callback_copy(y0, host_y0),
make_callback_copy(g, host_out_grad)});
func0->execute();
};
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
graph1->options().graph_opt_level = 0;
auto y1 = opr::CollectiveComm::make({}, graph1.get(), "scatter", 2, false, 1, client,
{Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0];
auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1);
auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "scatter:grad", 2, false, 1, client,
Mode::GATHER, dtype::Float32(), "nccl")[0];
g.node()->owner_opr()->node_prop().attribute().priority = 1;
auto func1 = graph1->compile({make_callback_copy(y1, host_y1), {g, nullptr}});
func1->execute();
};
auto run_2 = [&]() { // check
auto graph2 = ComputingGraph::make();
auto grad0 = opr::Host2DeviceCopy::make(*graph2, host_grad0, cn0);
auto grad1 = opr::Host2DeviceCopy::make(*graph2, host_grad1, cn0);
auto out_grad_expect = opr::Concat::make({grad0, grad1}, 0);
auto func2 = graph2->compile({
make_callback_copy(out_grad_expect, host_out_grad_expect)});
func2->execute();
};
std::thread t0(run_0);
std::thread t1(run_1);
std::thread t2(run_2);
t0.join();
t1.join();
t2.join();
MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0);
MGB_ASSERT_TENSOR_EQ(*host_x1, host_y1);
MGB_ASSERT_TENSOR_EQ(host_out_grad_expect, host_out_grad);
}
TEST(TestOprCollectiveComm, AllToAll) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
TensorShape shape({10});
auto host_x00 = gen(shape);
auto host_x01 = gen(shape);
auto host_x10 = gen(shape);
auto host_x11 = gen(shape);
HostTensorND host_y0, host_y1, host_expect_y0, host_expect_y1;
auto client = std::make_shared<test::MockGroupClient>();
auto graph = ComputingGraph::make();
auto x00 = opr::Host2DeviceCopy::make(*graph, host_x00, cn0);
auto x01 = opr::Host2DeviceCopy::make(*graph, host_x01, cn0);
auto x0 = opr::Concat::make({x00, x01}, 0);
auto x10 = opr::Host2DeviceCopy::make(*graph, host_x10, cn1);
auto x11 = opr::Host2DeviceCopy::make(*graph, host_x11, cn1);
auto x1 = opr::Concat::make({x10, x11}, 0);
auto x01c = opr::Copy::make(x01, {cn1});
auto x10c = opr::Copy::make(x10, {cn0});
auto expect_y0 = opr::Concat::make({x00, x10c}, 0);
auto expect_y1 = opr::Concat::make({x01c, x11}, 0);
auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "alltoall",
2, false, 0, client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
auto y1 = opr::CollectiveComm::make({x1}, graph.get(), "alltoall", 2, false, 1,
client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
auto func = graph->compile({make_callback_copy(y0, host_y0),
make_callback_copy(y1, host_y1),
make_callback_copy(expect_y0, host_expect_y0),
make_callback_copy(expect_y1, host_expect_y1)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_expect_y0, host_y0);
MGB_ASSERT_TENSOR_EQ(host_expect_y1, host_y1);
}
TEST(TestOprCollectiveComm, AllToAllMultiThread) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
TensorShape shape({10});
auto host_x00 = gen(shape);
auto host_x01 = gen(shape);
auto host_x10 = gen(shape);
auto host_x11 = gen(shape);
HostTensorND host_y0, host_y1, host_expect_y0, host_expect_y1;
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
auto x00 = opr::Host2DeviceCopy::make(*graph0, host_x00, cn0);
auto x01 = opr::Host2DeviceCopy::make(*graph0, host_x01, cn0);
auto x10 = opr::Host2DeviceCopy::make(*graph0, host_x10, cn0);
auto x0 = opr::Concat::make({x00, x01}, 0);
auto expect_y0 = opr::Concat::make({x00, x10}, 0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "alltoall", 2, false, 0, client,
{Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
auto func0 = graph0->compile(
{make_callback_copy(y0, host_y0),
make_callback_copy(expect_y0, host_expect_y0)});
func0->execute();
};
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
auto x10 = opr::Host2DeviceCopy::make(*graph1, host_x10, cn1);
auto x11 = opr::Host2DeviceCopy::make(*graph1, host_x11, cn1);
auto x01 = opr::Host2DeviceCopy::make(*graph1, host_x01, cn1);
auto x1 = opr::Concat::make({x10, x11}, 0);
auto expect_y1 = opr::Concat::make({x01, x11}, 0);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "alltoall", 2, false, 1, client,
{Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
auto func1 = graph1->compile(
{make_callback_copy(y1, host_y1),
make_callback_copy(expect_y1, host_expect_y1)});
func1->execute();
};
std::thread t0(run_0);
std::thread t1(run_1);
t0.join();
t1.join();
MGB_ASSERT_TENSOR_EQ(host_expect_y0, host_y0);
MGB_ASSERT_TENSOR_EQ(host_expect_y1, host_y1);
}
TEST(TestOprCollectiveComm, AllToAllWithGrad) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
auto cn1 = CompNode::load("gpu1");
HostTensorGenerator<> gen;
TensorShape shape({10});
auto host_x00 = gen(shape);
auto host_x01 = gen(shape);
auto host_x10 = gen(shape);
auto host_x11 = gen(shape);
auto host_grad00 = gen(shape);
auto host_grad01 = gen(shape);
auto host_grad10 = gen(shape);
auto host_grad11 = gen(shape);
HostTensorND host_y0, host_y1, host_expect_y0, host_expect_y1, host_grad0,
host_grad1, host_expect_grad0, host_expect_grad1;
auto client = std::make_shared<test::MockGroupClient>();
auto run_0 = [&]() { // rank 0
auto graph0 = ComputingGraph::make();
graph0->options().graph_opt_level = 0;
auto x00 = opr::Host2DeviceCopy::make(*graph0, host_x00, cn0);
auto x01 = opr::Host2DeviceCopy::make(*graph0, host_x01, cn0);
auto x10 = opr::Host2DeviceCopy::make(*graph0, host_x10, cn0);
auto x0 = opr::Concat::make({x00, x01}, 0);
auto expect_y0 = opr::Concat::make({x00, x10}, 0);
auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "alltoall", 2, false, 0, client,
{Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
y0.node()->owner_opr()->node_prop().attribute().priority = -1;
auto grad00 = opr::Host2DeviceCopy::make(*graph0, host_grad00, cn0);
auto grad10 = opr::Host2DeviceCopy::make(*graph0, host_grad10, cn0);
auto grad_y0 = opr::Concat::make({grad00, grad10}, 0);
auto loss = opr::Dot::make(y0, grad_y0);
auto g = opr::VirtualGrad::make(loss, x0);
auto func0 = graph0->compile(
{make_callback_copy(y0, host_y0),
make_callback_copy(g, host_grad0),
make_callback_copy(expect_y0, host_expect_y0)});
func0->execute();
};
auto run_1 = [&]() { // rank 1
auto graph1 = ComputingGraph::make();
graph1->options().graph_opt_level = 0;
auto x10 = opr::Host2DeviceCopy::make(*graph1, host_x10, cn1);
auto x11 = opr::Host2DeviceCopy::make(*graph1, host_x11, cn1);
auto x01 = opr::Host2DeviceCopy::make(*graph1, host_x01, cn1);
auto x1 = opr::Concat::make({x10, x11}, 0);
auto expect_y1 = opr::Concat::make({x01, x11}, 0);
auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "alltoall", 2, false, 1, client,
{Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0];
y1.node()->owner_opr()->node_prop().attribute().priority = -1;
auto grad01 = opr::Host2DeviceCopy::make(*graph1, host_grad01, cn1);
auto grad11 = opr::Host2DeviceCopy::make(*graph1, host_grad11, cn1);
auto grad_y1 = opr::Concat::make({grad01, grad11}, 0);
auto loss = opr::Dot::make(y1, grad_y1);
auto g = opr::VirtualGrad::make(loss, x1);
auto func0 = graph1->compile(
{make_callback_copy(y1, host_y1),
make_callback_copy(g, host_grad1),
make_callback_copy(expect_y1, host_expect_y1)});
func0->execute();
};
auto run_2 = [&]() { // check
auto graph2 = ComputingGraph::make();
auto grad00 = opr::Host2DeviceCopy::make(*graph2, host_grad00, cn0);
auto grad01 = opr::Host2DeviceCopy::make(*graph2, host_grad01, cn0);
auto grad10 = opr::Host2DeviceCopy::make(*graph2, host_grad10, cn0);
auto grad11 = opr::Host2DeviceCopy::make(*graph2, host_grad11, cn0);
auto out_grad0_expect = opr::Concat::make({grad00, grad01}, 0);
auto out_grad1_expect = opr::Concat::make({grad10, grad11}, 0);
auto func2 = graph2->compile({
make_callback_copy(out_grad0_expect, host_expect_grad0),
make_callback_copy(out_grad1_expect, host_expect_grad1)});
func2->execute();
};
std::thread t0(run_0);
std::thread t1(run_1);
std::thread t2(run_2);
t0.join();
t1.join();
t2.join();
MGB_ASSERT_TENSOR_EQ(host_expect_y0, host_y0);
MGB_ASSERT_TENSOR_EQ(host_expect_y1, host_y1);
MGB_ASSERT_TENSOR_EQ(host_expect_grad0, host_grad0);
MGB_ASSERT_TENSOR_EQ(host_expect_grad1, host_grad1);
}
......@@ -56,7 +56,10 @@ pdef('PersistentOutputStorage').add_fields(
Doc('ALL_REDUCE_SUM', 'every output gets the sum of all inputs'),
Doc('ALL_REDUCE_MAX', 'every output gets the max of all inputs'),
Doc('ALL_REDUCE_MIN', 'every output gets the min of all inputs'),
Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs')))
Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'),
Doc('GATHER', 'concat inputs to one node'),
Doc('SCATTER', 'scatter input to each output computing node'),
Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node')))
(pdef('FakeSerializedDType',
'HACK: The tag of this param def is actually used for another '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册