Commit 1a711299 authored by Megvii Engine Team

feat(opr-mm): add backend argument for remote send/recv

GitOrigin-RevId: 841a0e45ab2188a4a7414ff4a23b76e7b9852db7
Parent 69a146c8
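A minimal sketch (not itself part of the diff hunks below) of how the new backend argument appears at a graph-level call site, mirroring the updated TestOprIORemote cases further down; `client`, `graph`, `host_x`, `x`, and `cn1` are assumed to be set up as in those tests:

```cpp
// Sketch only: build a send/recv pair with an explicit backend string.
// "nccl" is the default; the opr docs in this commit also allow "ucx".
auto xr = opr::RemoteSend::make("x", x, client, /*is_grad=*/false, "nccl");
auto y  = opr::RemoteRecv::make("x", *graph.get(), client, {cn1},
                                host_x->shape(), host_x->dtype(), "nccl");
```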
......@@ -265,6 +265,7 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
op.key = key
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
op.backend = get_backend()
(dummy,) = apply(_RemoteSend(op), inp)
for g in grad_keys.values():
......@@ -313,6 +314,7 @@ def remote_recv(
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank
op.backend = get_backend()
(ret,) = apply(_RemoteRecv(op), inp)
if _isscalar:
......
......@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
OperatorNodeConfig config{send.make_name()};
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
send.key, inputs[0], group_client, true, config));
send.key, inputs[0], group_client, true, send.backend, config));
return opr;
}
......@@ -49,7 +49,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto&& graph = inputs[0]->owner_graph();
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
recv.key, inputs[0], *graph, group_client, config,
recv.shape, recv.dtype));
recv.shape, recv.dtype, recv.backend));
}
OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
......
......@@ -34,7 +34,7 @@ TEST(TestImperative, IORemote) {
auto run_send = [&](std::shared_ptr<HostTensorND> hnd) {
auto def = imperative::RemoteSend::make(
"io_remote_test", server_addr, port, 1);
"io_remote_test", server_addr, port, 1, "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
};
......@@ -43,7 +43,7 @@ TEST(TestImperative, IORemote) {
auto def = imperative::RemoteRecv::make(
"io_remote_test", server_addr, port, 0,
CompNode::load("gpu1"), TensorShape{vector_size},
dtype::Float32());
dtype::Float32(), "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v;
......
......@@ -169,7 +169,8 @@ def RemoteSend : MgbHashableOp<"RemoteSend"> {
MgbStringAttr:$key,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbUI32Attr:$rank_to
MgbUI32Attr:$rank_to,
MgbStringAttr:$backend
);
}
......@@ -181,7 +182,8 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
MgbUI32Attr:$rank_from,
MgbCompNodeAttr:$cn,
MgbTensorShapeAttr:$shape,
MgbDTypeAttr:$dtype
MgbDTypeAttr:$dtype,
MgbStringAttr:$backend
);
}
......
......@@ -24,8 +24,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
RemoteSend::RemoteSend(const std::string& key, VarNode* var,
std::shared_ptr<GroupClient> group_client,
bool is_grad, const OperatorNodeConfig& config) :
bool is_grad, std::string backend, const OperatorNodeConfig& config) :
Super(var->owner_graph(), config, "remote_send", {var}),
m_backend(backend),
m_is_grad(is_grad) {
m_key = key;
m_group_client = group_client;
......@@ -41,9 +42,9 @@ RemoteSend::RemoteSend(const std::string& key, VarNode* var,
SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
std::shared_ptr<GroupClient> group_client,
bool is_grad, const OperatorNodeConfig& config) {
bool is_grad, std::string backend, const OperatorNodeConfig& config) {
return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
is_grad, config);
is_grad, backend, config);
}
void RemoteSend::scn_do_execute() {
......@@ -64,7 +65,7 @@ void RemoteSend::scn_do_execute() {
}
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
reg_info.hash, m_key, 2, 0, get_megray_backend(m_backend), m_group_client);
m_megray_ctx = get_megray_context(output(0)->comp_node());
......@@ -122,7 +123,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
*opr.owner_graph(), opr.group_client(),
OperatorNodeConfig{opr.comp_node()}.name(
opr.name() + ":grad_recv"),
opr.input(0)->shape(), opr.input(0)->dtype())
opr.input(0)->shape(), opr.input(0)->dtype(), opr.backend())
.node();
}
#endif
......@@ -134,9 +135,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
const TensorShape& shape, DType dtype) :
const TensorShape& shape, DType dtype, std::string backend) :
Super(&graph, config, "remote_recv", {}),
m_shape(shape), m_dtype(dtype) {
m_shape(shape), m_dtype(dtype), m_backend(backend) {
m_key = key;
m_group_client = group_client;
......@@ -150,9 +151,9 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
const TensorShape& shape, DType dtype) :
const TensorShape& shape, DType dtype, std::string backend) :
Super(&graph, config, "remote_recv", {}),
m_shape(shape), m_dtype(dtype) {
m_shape(shape), m_dtype(dtype), m_backend(backend) {
m_key = key;
m_group_client = group_client;
......@@ -167,18 +168,18 @@ RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph&
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
const TensorShape& shape, DType dtype) {
const TensorShape& shape, DType dtype, std::string backend) {
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
key, graph, group_client, config, shape, dtype));
key, graph, group_client, config, shape, dtype, backend));
return opr->output(0);
}
SymbolVar RemoteRecv::make(const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
const TensorShape& shape, DType dtype) {
const TensorShape& shape, DType dtype, std::string backend) {
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
key, var.node(), graph, group_client, config, shape, dtype));
key, var.node(), graph, group_client, config, shape, dtype, backend));
return opr->output(0);
}
......@@ -201,7 +202,7 @@ void RemoteRecv::scn_do_execute() {
}
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
reg_info.hash, m_key, 2, 1, get_megray_backend(m_backend), m_group_client);
m_megray_ctx = get_megray_context(output(0)->comp_node());
......@@ -251,7 +252,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
mgb_assert(inputs.size() == 1);
auto&& opr = opr_.cast_final_safe<RemoteSend>();
return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
opr.is_grad(), config)
opr.is_grad(), opr.backend(), config)
.node()
->owner_opr();
}
......@@ -265,14 +266,14 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
if (inputs.size() == 1) {
return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
opr.group_client(), config, opr.shape(),
opr.dtype())
opr.dtype(), opr.backend())
.node()
->owner_opr();
} else {
mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
return RemoteRecv::make(opr.key(), *opr.owner_graph(),
opr.group_client(), config, opr.shape(),
opr.dtype())
opr.dtype(), opr.backend())
.node()
->owner_opr();
}
......
......@@ -9,6 +9,8 @@ decl_raw_opr(
Doc('key', 'key to bind send-recv pair', 'str'),
Doc('var', 'variable to be sent', ':class:`.SymbolVar`'),
Doc('is_grad', 'whether the send', 'bool'),
Doc('backend', 'Backend for collective communication, nccl or ucx',
'str', '\'nccl\''),
]
)
......@@ -24,7 +26,9 @@ decl_raw_opr(
':class:`.CompGraph`'),
Doc('shape', 'output var shape'),
Doc('dtype', 'data type of the output var; must match dtype at sender',
':class:`numpy.dtype` compatible')
':class:`numpy.dtype` compatible'),
Doc('backend', 'Backend for collective communication, nccl or ucx',
'str', '\'nccl\''),
]
)
......
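The declarations above document `backend` as "nccl or ucx", and `scn_do_execute()` now calls `get_megray_backend(m_backend)` where `MegRay::MEGRAY_NCCL` used to be hard-coded. A hypothetical sketch of such a string-to-enum mapping is shown below; the real helper is not part of this diff, and `MegRay::MEGRAY_UCX` is assumed to exist alongside `MEGRAY_NCCL`:

```cpp
// Hypothetical helper body, for illustration only: map the backend string
// documented above onto a MegRay backend enum. The shipped helper may differ.
MegRay::Backend get_megray_backend(const std::string& backend) {
    if (backend == "nccl") {
        return MegRay::MEGRAY_NCCL;
    } else if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;  // assumed enum value
    }
    mgb_throw(MegBrainError, "unsupported collective communication backend: %s",
              backend.c_str());
}
```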
......@@ -48,17 +48,19 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
public:
RemoteSend(const std::string& key, VarNode* var,
std::shared_ptr<GroupClient> group_client,
bool is_grad, const OperatorNodeConfig& config);
bool is_grad, std::string backend, const OperatorNodeConfig& config);
static SymbolVar make(
const std::string& key, SymbolVar var,
std::shared_ptr<GroupClient> group_client,
bool is_grad, const OperatorNodeConfig& config = {});
bool is_grad, std::string backend, const OperatorNodeConfig& config = {});
const std::string& backend() const { return m_backend; }
bool is_grad() const { return m_is_grad; }
private:
HostTensorND m_output_val;
std::string m_backend;
bool m_is_grad;
void scn_do_execute() override;
......@@ -75,31 +77,33 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
DType dtype, std::string backend);
RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
DType dtype, std::string backend);
static SymbolVar make(
const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
DType dtype, std::string backend);
static SymbolVar make(
const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
DType dtype, std::string backend);
const TensorShape& shape() const { return m_shape; }
const DType& dtype() const { return m_dtype; }
const std::string& backend() const { return m_backend; }
private:
const TensorShape m_shape;
const DType m_dtype;
const std::string m_backend;
const CompNode m_comp_node;
DeviceTensorND m_dev_buffer;
......
......@@ -33,10 +33,10 @@ TEST(TestOprIORemote, Identity) {
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
auto xr = opr::RemoteSend::make("x", x, client, false);
auto xr = opr::RemoteSend::make("x", x, client, false, "nccl");
auto y = opr::RemoteRecv::make("x", *graph.get(),
client, {cn1}, host_x->shape(),
host_x->dtype());
host_x->dtype(), "nccl");
auto func = graph->compile({{xr, {}}, make_callback_copy(y, host_y)});
......@@ -57,7 +57,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto graph = ComputingGraph::make();
sys::set_thread_name("sender");
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
xr = opr::RemoteSend::make("x", x, client, false);
xr = opr::RemoteSend::make("x", x, client, false, "nccl");
auto func = graph->compile({{xr, {}}});
func->execute();
};
......@@ -67,7 +67,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[0]}, host_x->shape(),
host_x->dtype());
host_x->dtype(), "nccl");
auto func = graph->compile({make_callback_copy(x, host_x_get)});
func->execute();
};
......@@ -91,7 +91,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
sys::set_thread_name("sender");
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
xr = opr::RemoteSend::make("x", x, client, false);
xr = opr::RemoteSend::make("x", x, client, false, "nccl");
auto func = graph->compile({{xr, {}}});
func->execute();
};
......@@ -101,7 +101,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[0]}, host_x->shape(),
host_x->dtype());
host_x->dtype(), "nccl");
auto func =
graph->compile({make_callback_copy((x - 1) / 2, host_x_get)});
func->execute();
......@@ -126,12 +126,12 @@ TEST(TestOprIORemote, APlusB) {
auto graph = ComputingGraph::make();
auto z = opr::RemoteRecv::make("z", *graph.get(),
client, {cns[0]}, host_x->shape(),
host_x->dtype());
host_x->dtype(), "nccl");
auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
xr = opr::RemoteSend::make("x", x, client, false)
xr = opr::RemoteSend::make("x", x, client, false, "nccl")
.rename("xr"),
yr = opr::RemoteSend::make("y", y, client, false)
yr = opr::RemoteSend::make("y", y, client, false, "nccl")
.rename("yr");
auto func = graph->compile(
{{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
......@@ -144,12 +144,12 @@ TEST(TestOprIORemote, APlusB) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[1]}, host_x->shape(),
host_x->dtype()),
host_x->dtype(), "nccl"),
y = opr::RemoteRecv::make("y", *graph.get(),
client, {cns[1]}, host_y->shape(),
host_y->dtype()),
host_y->dtype(), "nccl"),
z = x + y,
zr = opr::RemoteSend::make("z", z, client, false);
zr = opr::RemoteSend::make("z", z, client, false, "nccl");
auto func = graph->compile({{zr, {}}});
func->execute();
};
......@@ -178,10 +178,10 @@ TEST(TestOprIORemote, SendGrad) {
sys::set_thread_name("sender");
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
loss = opr::RemoteSend::make("loss", x, client, false);
loss = opr::RemoteSend::make("loss", x, client, false, "nccl");
ASSERT_TRUE(!loss.shape().ndim &&
loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
loss = opr::RemoteSend::make("loss", x, client, true);
loss = opr::RemoteSend::make("loss", x, client, true, "nccl");
auto gx = cg::grad(loss, x);
set_priority(loss, 0);
set_priority(gx, 1);
......@@ -200,8 +200,8 @@ TEST(TestOprIORemote, SendGrad) {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make("loss", *graph.get(),
client, {cns[1]}, host_x->shape(),
host_x->dtype());
auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
host_x->dtype(), "nccl");
auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false, "nccl");
auto func = graph->compile({{y, {}}});
func->execute();
};
......