MegEngine 天元 / MegEngine
Commit ac51f780
Authored on Jun 01, 2022 by Megvii Engine Team
feat(mge/distributed): add support for batch send recv op
GitOrigin-RevId: eb3d712704f7a1d0abc6c611cec7c93ad3f5e8bf
Parent: 013bb14f

Showing 21 changed files with 531 additions and 18 deletions (+531, -18)
imperative/python/megengine/distributed/__init__.py                      +1   -0
imperative/python/megengine/distributed/functional.py                    +31  -7
imperative/python/megengine/distributed/group.py                         +7   -0
imperative/python/src/common.cpp                                         +17  -0
imperative/python/src/common.h                                           +1   -0
imperative/python/src/tensor.cpp                                         +8   -0
imperative/python/src/transformation.h                                   +2   -1
imperative/python/test/unit/distributed/test_distributed.py              +29  -0
imperative/src/impl/ops/io_remote.cpp                                    +157 -3
imperative/src/impl/transformations/group_comm.cpp                       +67  -0
imperative/src/include/megbrain/imperative/ops/io_remote.h               +11  -0
imperative/src/include/megbrain/imperative/transformations/group_comm.h  +44  -0
imperative/src/test/io_remote.cpp                                        +1   -3
src/opr-mm/impl/group_manager.cpp                                        +22  -0
src/opr-mm/impl/megray_helper.cpp                                        +9   -2
src/opr-mm/impl/mm_handler.cpp                                           +58  -0
src/opr-mm/include/megbrain/opr/group_manager.h                          +15  -0
src/opr-mm/include/megbrain/opr/megray_helper.h                          +1   -0
src/opr-mm/include/megbrain/opr/mm_handler.h                             +32  -2
src/opr-mm/proto/mm_handler.proto                                        +12  -0
src/opr-mm/test/mock_client.h                                            +6   -0
imperative/python/megengine/distributed/__init__.py

# -*- coding: utf-8 -*-
from mprop import mproperty

from ..core._imperative_rt.core2 import group_end, group_start
from . import group
from .group import (
    WORLD,
...
imperative/python/megengine/distributed/functional.py

# -*- coding: utf-8 -*-
from typing import Optional, Tuple
from typing import Optional

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar
from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
...
@@ -843,16 +842,13 @@ def remote_send(inp: Tensor, dest_rank: int):
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)

    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    out = _RemoteSend(op)(inp)
    _save_output_for_autodiff(inp, out)
...
@@ -900,6 +896,34 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    ret = _RemoteRecv(op)(inp)
    return ret


def _remote_send_nobackward(inp: Tensor, dest_rank: int):
    op = RemoteSend()
    op.key = "b{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    apply(op, inp)


def _remote_recv_nobackward(
    src_rank: int,
    device: Optional[str] = None,
    inp=None,
    shape=None,
    dtype=None,
):
    op = RemoteRecv()
    op.key = "b{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()
    op.cn = device
    if inp is None:
        inp = Tensor(0, device=device)
    assert shape is not None and dtype is not None
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    ret = apply(op, inp)[0]
    return ret
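
The two helpers above only record intent; they become useful when wrapped in the group_start / group_end bindings that this commit also adds. A minimal usage sketch, not part of the diff itself: it condenses the test_batch_send_recv case shown further down, so the 3-GPU ring, the tensor shape and the launcher decorator are taken from that test rather than from any documented public API.

# Condensed from test_batch_send_recv below; needs 3 GPUs to run.
import numpy as np
import megengine as mge
import megengine.distributed as dist
import megengine.distributed.functional as DF


@dist.launcher(n_gpus=3)
def worker():
    rank = dist.get_rank()
    dist.group_start()  # start recording batched send/recv ops
    t = mge.tensor(np.ones(10000)) * rank
    DF._remote_send_nobackward(t, (rank + 1) % 3)
    recv = DF._remote_recv_nobackward(
        src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,)
    )
    dist.group_end()  # flush the recorded ops as one BatchSendRecvOp
    # each rank receives the tensor produced by its left neighbour
    np.testing.assert_equal(recv.numpy(), (rank - 1) % 3 * np.ones(10000))


if __name__ == "__main__":
    worker()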
imperative/python/megengine/distributed/group.py

...
@@ -160,6 +160,13 @@ def init_process_group(
    set_default_device("{}{}".format(device_type, device))
    seed(int(time.time()) + rank)

    if backend == "nccl":
        # init nccl env
        from ..core._imperative_rt.common import init_nccl_env

        group_barrier()
        init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0)


def _set_machine_ranks(ranks) -> None:
    global _sd
...
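
The branch above only runs when init_process_group is called with backend="nccl". A hedged sketch of such a call follows; the address, port, world size and device index are placeholder values, and the keyword names assume the usual init_process_group(master_ip, port, world_size, rank, device, backend) signature rather than anything shown in this diff.

# Placeholder values for illustration only; each process makes one such call.
import megengine.distributed as dist

dist.init_process_group(
    master_ip="127.0.0.1",  # host running the rank-0 group server
    port=23333,             # free TCP port for the group server
    world_size=2,
    rank=0,                 # each process passes its own rank
    device=0,               # local GPU index
    backend="nccl",         # takes the new init_nccl_env path above
)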
imperative/python/src/common.cpp

...
@@ -8,6 +8,9 @@
#include "megbrain/comp_node.h"
#include "megbrain/graph.h"
#include "megbrain/imperative/physical_tensor.h"
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/mm_handler.h"
#endif

#if MEGDNN_WITH_CUDA
#include "cuda_sm_gen.h"
...
@@ -46,6 +49,18 @@ void set_default_device(const std::string& device) {
    default_device = device;
}

void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) {
#if MGB_ENABLE_OPR_MM
    auto&& help = mgb::opr::BatchSendRecvHelper::getInstance();
    bool res = help->init(nranks, rank, ip, port, root);
    auto p = help->get(std::string("init_all_cards"));
#else
    mgb_throw(
            MegBrainError,
            "MegEngine compiled without MM opr, doesn't support init_nccl_env");
#endif
}

std::string get_default_device() {
    return default_device;
}
...
@@ -252,6 +267,8 @@ void init_common(py::module m) {
    m.def("what_is_xpu",
          [] { return CompNode::Locator::parse("xpux").to_physical().type; });

    m.def("init_nccl_env", &init_nccl_env);

    init_npy_num_bfloat16(m);
    init_npy_num_intbx(m);
    init_dtypes(m);
...
imperative/python/src/common.h

...
@@ -8,3 +8,4 @@ void set_default_device(const std::string& device);
std::string get_default_device();
extern pybind11::handle py_comp_node_type;

void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root);
imperative/python/src/tensor.cpp

...
@@ -9,6 +9,7 @@
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/group_comm.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
...
@@ -947,6 +948,13 @@ void init_tensor(py::module m) {
    m.def("enable_cupti", &cupti::enable);
    m.def("disable_cupti", &cupti::disable);
    m.def("cupti_available", &cupti::available);

    static std::unique_ptr<CleanupGuard<>> group_comm_guard;
    m.def("group_start", []() {
        auto commtrans = std::make_shared<GroupCommTransformation>();
        group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans);
    });
    m.def("group_end", []() { group_comm_guard.reset(); });

    m.def("sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
...
imperative/python/src/transformation.h

...
@@ -16,6 +16,7 @@ struct TransformationManager {
public:
    enum Segment {
        ModuleTrace,
        GroupComm,
        DTypePromote,
        DimExpansion,
        Format,
...
@@ -26,7 +27,7 @@
        Eval,
    };

    std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments;
    std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments;

private:
    template <Segment segment>
...
imperative/python/test/unit/distributed/test_distributed.py

...
@@ -237,3 +237,32 @@ def test_get_cuda_compute_capability():
        assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0

    worker()


@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_batch_send_recv():
    import megengine.distributed.functional as DF

    @dist.launcher(n_gpus=3)
    def worker():
        rank = dist.get_rank()
        dist.group_start()
        for i in range(3):
            tensor = mge.tensor(np.ones(10000)) * rank
            if i == 2:
                tensor *= i
            DF._remote_send_nobackward(tensor, (rank + 1) % 3)
            DF._remote_recv_nobackward(
                src_rank=(rank + 1) % 3, dtype="float32", shape=(10000,)
            )
            DF._remote_send_nobackward(tensor, (rank - 1) % 3)
            recv = DF._remote_recv_nobackward(
                src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,)
            )
            if i == 2:
                recv2 = recv
        dist.group_end()
        np.testing.assert_equal(recv2.numpy(), (rank - 1) % 3 * 2 * np.ones(10000))

    worker()
imperative/src/impl/ops/io_remote.cpp

#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain_build_config.h"
#if MGB_ENABLE_OPR_MM
#include <algorithm>
#include <functional>
#include <numeric>
#include "../blob_manager_impl.h"
#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/io_remote.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/opr/mm_handler.h"
#endif  // MGB_ENABLE_OPR_MM
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"

namespace mgb {
namespace imperative {
...
@@ -46,15 +51,164 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
            recv.backend));
}

TensorPtr megray_recv_tensor(
        std::shared_ptr<MegRay::Communicator> megray_comm, TensorLayout& layout,
        CompNode cn, uint32_t rank_from) {
    DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(cn, layout);
    auto megray_ctx = mgb::opr::get_megray_context(cn);
    size_t data_size = layout.total_nr_elems();
    auto status = megray_comm->recv(
            out.raw_ptr(), data_size, mgb::opr::get_megray_dtype(layout.dtype),
            rank_from, megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
    return Tensor::make(out);
}

void megray_send_tensor(
        std::shared_ptr<MegRay::Communicator> megray_comm, const TensorPtr& src,
        uint32_t rank_to) {
    auto&& tensor = src->dev_tensor();
    auto&& ishp = src->shape();
    size_t data_size = ishp.total_nr_elems();
    auto megray_ctx = mgb::opr::get_megray_context(src->comp_node());
    auto status = megray_comm->send(
            src->dev_tensor().raw_ptr(), data_size,
            mgb::opr::get_megray_dtype(src->layout().dtype), rank_to, megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");
}

TensorLayout create_layout(const std::vector<int32_t>& shape, DType dtype) {
    TensorShape tshape;
    tshape.ndim = shape.size();
    mgb_assert(tshape.ndim <= TensorLayout::MAX_NDIM);
    std::copy(shape.begin(), shape.end(), tshape.shape);
    return TensorLayout(tshape, dtype);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_send(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto&& dtype = input_descs[0].layout.dtype;
    auto&& cn = input_descs[0].comp_node;
    return {{{TensorLayout({0}, dtype), cn}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor_remote_send(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<RemoteSend>();
    auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
            std::string("init_all_cards"));
    if (!megray_comm) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }
    mgb_assert(megray_comm != nullptr);
    megray_send_tensor(megray_comm, inputs[0], op.rank_to);
    TensorLayout layout({0}, inputs[0]->dtype());
    DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
            inputs[0]->comp_node(), layout);
    return {Tensor::make(out)};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible_remote_recv(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto& op = def.cast_final_safe<RemoteRecv>();
    return {{{create_layout(op.shape, op.dtype), op.cn}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor_remote_recv(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<RemoteRecv>();
    auto layout = create_layout(op.shape, op.dtype);
    auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
            std::string("init_all_cards"));
    if (!megray_comm) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }
    auto&& out = megray_recv_tensor(megray_comm, layout, op.cn, op.rank_from);
    return {out};
}

SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    for (size_t i = 0; i < inputs.size(); i++) {
        layout_checker[i] = [](const TensorLayout& layout) {
            return layout.is_contiguous();
        };
    }
    return layout_checker;
}

OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
        .apply_on_var_node(apply_on_var_node_remote_send)
        .apply_on_physical_tensor(apply_on_physical_tensor_remote_send)
        .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_send)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();

OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
        .apply_on_var_node(apply_on_var_node_remote_recv)
        .apply_on_physical_tensor(apply_on_physical_tensor_remote_recv)
        .infer_output_attrs_fallible(infer_output_attrs_fallible_remote_recv)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // anonymous namespace

SmallVector<TensorPtr> apply_on_physical_tensor_batch_send_recv(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<BatchSendRecvOp>();
    auto megray_comm = mgb::opr::BatchSendRecvHelper::getInstance()->get(
            std::string("init_all_cards"));
    mgb_assert(megray_comm != nullptr);
    megray_comm->group_start();
    SmallVector<TensorPtr> outputs;
    size_t ind = 0;
    for (auto&& op_ : op.op_list) {
        if (op_->same_type<RemoteSend>()) {
            auto&& send_op = op_->cast_final_safe<RemoteSend>();
            auto&& tensor = inputs[ind];
            megray_send_tensor(megray_comm, tensor, send_op.rank_to);
            ind++;
        } else {
            mgb_assert(op_->same_type<RemoteRecv>());
            auto&& recv_op = op_->cast_final_safe<RemoteRecv>();
            auto layout = create_layout(recv_op.shape, recv_op.dtype);
            auto&& out = megray_recv_tensor(
                    megray_comm, layout, recv_op.cn, recv_op.rank_from);
            outputs.push_back(out);
        }
    }
    megray_comm->group_end();
    return outputs;
}

std::tuple<SmallVector<LogicalTensorDesc>, bool>
infer_output_attrs_fallible_batch_send_recv(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto& op = def.cast_final_safe<BatchSendRecvOp>();
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto&& op_ : op.op_list) {
        if (op_->same_type<RemoteRecv>()) {
            auto&& recv_op = op_->cast_final_safe<RemoteRecv>();
            output_descs.push_back(
                    {create_layout(recv_op.shape, recv_op.dtype), recv_op.cn});
        }
    }
    return {output_descs, true};
}

OP_TRAIT_REG(BatchSendRecvOp, BatchSendRecvOp)
        .apply_on_physical_tensor(apply_on_physical_tensor_batch_send_recv)
        .infer_output_attrs_fallible(infer_output_attrs_fallible_batch_send_recv)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace

#endif  // MGB_ENABLE_OPR_MM

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchSendRecvOp);

}  // namespace imperative
}  // namespace mgb
imperative/src/impl/transformations/group_comm.cpp (new file)

#include "megbrain/imperative/transformations/group_comm.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/io_remote.h"

namespace mgb {
namespace imperative {

ValueRefList GroupCommTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    for (auto inp : inputs) {
        mgb_assert(
                !inp.is(m_value_type), "Can not use PlaceholderValue as apply input");
    }
    if (auto* apply_op = op.as<ApplyOp>()) {
        if (apply_op->op().same_type<RemoteSend>()) {
            auto&& send_op = apply_op->op().cast_final_safe<RemoteSend>();
            if (send_op.key[0] == 'b') {
                send_inputs.push_back(inputs[0]);
                record_ops.push_back(send_op.shared_from_this());
                return {};
            }
        }
        if (apply_op->op().same_type<RemoteRecv>()) {
            auto&& recv_op = apply_op->op().cast_final_safe<RemoteRecv>();
            if (recv_op.key[0] == 'b') {
                record_ops.push_back(recv_op.shared_from_this());
                auto rst = m_value_type.make();
                recv_tensors.push_back(rst);
                auto outputs = ValueRefList(1);
                outputs[0] = rst;
                return outputs;
            }
        }
        return imperative::apply(op, inputs);
    } else {
        return imperative::apply(op, inputs);
    }
}

ValueRefList GroupCommTransformation::execute_batch_op() {
    auto batch_op = BatchSendRecvOp::make(record_ops);
    auto outputs = imperative::apply(*batch_op, send_inputs);
    return outputs;
}

void GroupCommTransformation::on_unregister() noexcept {
    auto rst = execute_batch_op();
    mgb_assert(rst.size() == recv_tensors.size());
    for (size_t i = 0; i < rst.size(); i++) {
        auto v = recv_tensors[i].lock();
        if (v != ValueRef::nil) {
            v.reset(rst[i]);
        }
    }
}

GroupCommTransformation::~GroupCommTransformation() {
    for (auto&& recv : recv_tensors) {
        mgb_assert(
                recv.lock() == ValueRef::nil,
                "Some PlaceholderValues are not reset after GroupCommTransformation "
                "destroyed!");
    };
}

}  // namespace imperative
}  // namespace mgb
\ No newline at end of file
imperative/src/include/megbrain/imperative/ops/io_remote.h (new file)

#pragma once

#include "megbrain/imperative/op_def.h"

namespace mgb::imperative {

struct BatchSendRecvOp final : OpDefImplBase<BatchSendRecvOp> {
    SmallVector<std::shared_ptr<OpDef>> op_list;
    BatchSendRecvOp() = default;
    BatchSendRecvOp(SmallVector<std::shared_ptr<OpDef>> op_list) : op_list{op_list} {}

    MGB_DYN_TYPE_OBJ_FINAL_DECL;
};

}  // namespace mgb::imperative
\ No newline at end of file
imperative/src/include/megbrain/imperative/transformations/group_comm.h (new file)

/**
 * \file imperative/src/include/megbrain/imperative/transformations/group_comm.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative {

class PlaceholderValue final : public ObjectValue<PlaceholderValue> {
public:
    std::string to_string() const override { return ssprintf("PlaceholderValue"); }

    void clear() override {}
};

class GroupCommTransformation final : public Transformation {
private:
    SmallVector<ValueRef> send_inputs;
    std::vector<PlaceholderValue::weak_ref_t> recv_tensors;
    SmallVector<std::shared_ptr<OpDef>> record_ops;
    ObjectType<PlaceholderValue> m_value_type{"PlaceholderValue"};

public:
    GroupCommTransformation() = default;

    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override;
    ValueRefList execute_batch_op();

    ValueRef unwrap(ValueRef value) override { return value; }

    std::string name() const override { return "GroupCommTransformation"; }

    void on_unregister() noexcept override;

    ~GroupCommTransformation();
};

}  // namespace mgb::imperative
imperative/src/test/io_remote.cpp

#include "./helper.h"

#include "megbrain/comp_node_env.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/mm_handler.h"
...
@@ -47,7 +48,4 @@ TEST(TestImperative, IORemote) {
    t0.join();
    t1.join();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// ./imperative_test --gtest_filter TestIORemote
src/opr-mm/impl/group_manager.cpp

...
@@ -151,6 +151,28 @@ void GroupManager::bcast_addr(
    }
}

void GroupManager::bcast_nccluniqueid(
        const std::string& key, std::string& id, uint32_t size, uint32_t rank,
        uint32_t root) {
    std::unique_lock<std::mutex> lk{m_key2nccl_id_mtx};
    if (rank == root) {
        m_key2nccl_id[key] = id;
    }
    m_key2nccl_id_size[key]++;
    if (m_key2nccl_id_size[key] == size) {
        m_key2nccl_id_flag[key] = true;
        m_bcast_cv.notify_all();
    } else {
        m_bcast_cv.wait(lk, [&] { return m_key2nccl_id_flag.count(key) > 0; });
    }
    id = m_key2nccl_id[key];
    m_key2nccl_id_size[key]--;
    if (m_key2nccl_id_size[key] == 0) {
        m_key2nccl_id.erase(key);
        m_key2nccl_id_flag.erase(key);
    }
}

void GroupManager::set_output_shape(const std::string& key, const TensorShape& shape) {
    auto&& group = get_group(key);
    group.set_output_shape(key, shape);
...
src/opr-mm/impl/megray_helper.cpp

...
@@ -67,6 +67,15 @@ void MegRayCommBuilder::emplace(
    m_megray_comms.emplace(hash, comm);
}

void MegRayCommBuilder::remove(
        uint64_t hash, std::shared_ptr<MegRay::Communicator> comm) {
    std::unique_lock<std::mutex> lk(m_map_mtx);
    auto it = m_megray_comms.find(hash);
    if (it != m_megray_comms.end()) {
        m_megray_comms.erase(hash);
    }
}

std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
        uint64_t hash, std::string key, uint32_t size, uint32_t rank,
        MegRay::Backend backend, std::shared_ptr<mgb::opr::GroupClient> group_client) {
...
@@ -104,5 +113,3 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
MegRayCommBuilder* MegRayCommBuilder::sm_instance = nullptr;

std::mutex MegRayCommBuilder::sm_instance_mtx;

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr-mm/impl/mm_handler.cpp

...
@@ -45,6 +45,7 @@ public:
        RUNSERVER(get_output_shape);
        RUNSERVER(bcast_addr);
        RUNSERVER(group_barrier);
        RUNSERVER(bcast_nccluniqueid);
        mgb_assert(false, "invalid rpc request");
    }
...
@@ -53,6 +54,7 @@ private:
    void set_output_shape(void* input_ptr, size_t input_len, std::string* output);
    void get_output_shape(void* input_ptr, size_t input_len, std::string* output);
    void bcast_addr(void* input_ptr, size_t input_len, std::string* output);
    void bcast_nccluniqueid(void* input_ptr, size_t input_len, std::string* output);
    void group_barrier(void* input_ptr, size_t input_len, std::string* output);

private:
...
@@ -116,6 +118,15 @@ void GroupServerProxy::bcast_addr(
    rsp.SerializeToString(output);
}

void GroupServerProxy::bcast_nccluniqueid(
        void* input_ptr, size_t input_len, std::string* output) {
    INFO_INIT(mm_handler, BcastNcclUniqueId);
    std::string id = req.id();
    m_mgr.bcast_nccluniqueid(req.key(), id, req.size(), req.rank(), req.root());
    rsp.set_id(id);
    rsp.SerializeToString(output);
}

void GroupServerProxy::group_barrier(
        void* input_ptr, size_t input_len, std::string* output) {
    INFO_INIT(mm_handler, GroupBarrier);
...
@@ -201,6 +212,19 @@ void GroupClientProxy::bcast_addr(
    port = rsp.port();
}

void GroupClientProxy::bcast_nccluniqueid(
        const std::string& key, std::string& id, uint32_t size, uint32_t rank,
        uint32_t root) {
    INFO_INIT(mm_handler, bcast_nccluniqueid, BcastNcclUniqueId);
    req.set_id(id.data(), id.size());
    req.set_key(key.data(), key.size());
    req.set_size(size);
    req.set_rank(rank);
    req.set_root(root);
    SOLVE_REQUEST(func_name, req, rsp);
    id = rsp.id();
}

uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
    INFO_INIT(mm_handler, group_barrier, GroupBarrier);
    req.set_size(size);
...
@@ -209,6 +233,40 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
    return rsp.size();
}

std::shared_ptr<MegRay::Communicator> BatchSendRecvHelper::get(std::string&& key) {
    auto ptr = megray_comm_cache.find(key);
    if (ptr != megray_comm_cache.end()) {
        return megray_comm_cache[key];
    } else {
        return nullptr;
    }
}

std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>>
        BatchSendRecvHelper::megray_comm_cache{};

bool BatchSendRecvHelper::init(int nranks, int rank, std::string ip, int port, int root) {
    auto megray_comm =
            MegRay::get_communicator(nranks, rank, MegRay::Backend::MEGRAY_NCCL);
    auto group_client =
            std::make_shared<opr::GroupClientProxy>(ssprintf("%s:%d", ip.data(), port));
    auto cb = [=](char* nccl_buffer, size_t len) {
        std::string id;
        id.resize(128);
        if (rank == root) {
            memcpy(id.data(), nccl_buffer, len);
        }
        group_client->bcast_nccluniqueid("init_all_cards", id, nranks, rank, root);
        if (rank != root) {
            memcpy(nccl_buffer, id.data(), len);
        }
    };
    megray_comm->init(cb);
    return megray_comm_cache.insert({std::string("init_all_cards"), megray_comm})
            .second;
}

#undef INFO_INIT
#undef SOLVE_REQUEST
...
src/opr-mm/include/megbrain/opr/group_manager.h

...
@@ -77,6 +77,11 @@ public:
            std::string& master_ip, int& port, const std::string& key, uint32_t size,
            uint32_t rank, uint32_t root);

    //! bcast uid
    void bcast_nccluniqueid(
            const std::string& key, std::string& id, uint32_t size, uint32_t rank,
            uint32_t root);

    //! Set output shape of this key
    void set_output_shape(const std::string& key, const TensorShape& shape);
...
@@ -101,6 +106,12 @@ private:
    std::mutex m_key2addr_mtx;
    std::condition_variable m_bcast_cv;

    //! key -> ncclid
    std::unordered_map<std::string, std::string> m_key2nccl_id;
    std::unordered_map<std::string, uint32_t> m_key2nccl_id_size;
    std::unordered_map<std::string, bool> m_key2nccl_id_flag;
    std::mutex m_key2nccl_id_mtx;

    //! barrier
    uint32_t m_barrier_size;
    std::set<uint32_t> m_barrier_set;
...
@@ -128,6 +139,10 @@ public:
            std::string& master_ip, int& port, const std::string& key, uint32_t size,
            uint32_t rank, uint32_t root) = 0;

    virtual void bcast_nccluniqueid(
            const std::string& key, std::string& id, uint32_t size, uint32_t rank,
            uint32_t root) = 0;

    virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0;

    virtual TensorShape get_output_shape(const std::string& key) = 0;
...
src/opr-mm/include/megbrain/opr/megray_helper.h

...
@@ -23,6 +23,7 @@ class MegRayCommBuilder {
private:
    bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);
    void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
    void remove(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);

    std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms;
    std::mutex m_map_mtx;
...
src/opr-mm/include/megbrain/opr/mm_handler.h

...
@@ -39,6 +39,10 @@ public:
            std::string& master_ip, int& port, const std::string& key, uint32_t size,
            uint32_t rank, uint32_t root) override;

    void bcast_nccluniqueid(
            const std::string& key, std::string& id, uint32_t size, uint32_t rank,
            uint32_t root) override;

    void set_output_shape(const std::string& key, const TensorShape& shape) override;

    TensorShape get_output_shape(const std::string& key) override;
...
@@ -52,6 +56,34 @@ private:
    void* m_stub;
};

template <typename T>
class ProcessGlobal {  // thread safe
public:
    template <class... Args>
    static std::shared_ptr<T>& getInstance(Args&&... args) {
        static auto instance = std::make_shared<T>(std::forward<Args>(args)...);
        return instance;
    }

protected:
    template <class... Args>
    ProcessGlobal(Args&&... args);
    ProcessGlobal() = default;

public:
    ProcessGlobal(ProcessGlobal const&) = delete;
    void operator=(ProcessGlobal const&) = delete;
};

class BatchSendRecvHelper : public ProcessGlobal<BatchSendRecvHelper> {
    static std::unordered_map<std::string, std::shared_ptr<MegRay::Communicator>>
            megray_comm_cache;

public:
    std::shared_ptr<MegRay::Communicator> get(std::string&&);
    bool init(int nranks, int rank, std::string ip, int port, int root);
};

/* ======================== ZmqRpcServerMgr ========================== */

int create_zmqrpc_server(const std::string& server_addr, int port);
...
@@ -60,5 +92,3 @@ int create_zmqrpc_server(const std::string& server_addr, int port);
}  // namespace mgb

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr-mm/proto/mm_handler.proto

...
@@ -30,6 +30,18 @@ message BcastAddrResponse {
    int32 port = 2;
}

message BcastNcclUniqueIdRequest {
    string key = 1;
    bytes id = 2;
    uint32 size = 3;
    uint32 rank = 4;
    uint32 root = 5;
}

message BcastNcclUniqueIdResponse {
    bytes id = 1;
}

message SetOutputShapeRequest {
    string key = 1;
    TensorShape shape = 2;
...
src/opr-mm/test/mock_client.h

...
@@ -26,6 +26,12 @@ public:
        return m_mgr.bcast_addr(master_ip, port, key, size, rank, root);
    }

    void bcast_nccluniqueid(
            const std::string& key, std::string& id, uint32_t size, uint32_t rank,
            uint32_t root) override {
        return m_mgr.bcast_nccluniqueid(key, id, size, rank, root);
    }

    void set_output_shape(const std::string& key, const TensorShape& shape) override {
        m_mgr.set_output_shape(key, shape);
    }
...