Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1a711299
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
1a711299
编写于
3月 30, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(opr-mm): add backend argument for remote send/recv
GitOrigin-RevId: 841a0e45ab2188a4a7414ff4a23b76e7b9852db7
上级
69a146c8
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
59 addition
and
46 deletion
+59
-46
imperative/python/megengine/distributed/functional.py
imperative/python/megengine/distributed/functional.py
+2
-0
imperative/src/impl/ops/io_remote.cpp
imperative/src/impl/ops/io_remote.cpp
+2
-2
imperative/src/test/io_remote.cpp
imperative/src/test/io_remote.cpp
+2
-2
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+4
-2
src/opr-mm/impl/io_remote.cpp
src/opr-mm/impl/io_remote.cpp
+18
-17
src/opr-mm/impl/io_remote.oprdecl
src/opr-mm/impl/io_remote.oprdecl
+5
-1
src/opr-mm/include/megbrain/opr/io_remote.h
src/opr-mm/include/megbrain/opr/io_remote.h
+10
-6
src/opr-mm/test/io_remote.cpp
src/opr-mm/test/io_remote.cpp
+16
-16
未找到文件。
imperative/python/megengine/distributed/functional.py
浏览文件 @
1a711299
...
...
@@ -265,6 +265,7 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
op
.
key
=
key
op
.
addr
,
op
.
port
=
get_mm_server_addr
()
op
.
rank_to
=
dest_rank
op
.
backend
=
get_backend
()
(
dummy
,)
=
apply
(
_RemoteSend
(
op
),
inp
)
for
g
in
grad_keys
.
values
():
...
...
@@ -313,6 +314,7 @@ def remote_recv(
op
.
dtype
=
dtype
op
.
addr
,
op
.
port
=
get_mm_server_addr
()
op
.
rank_from
=
src_rank
op
.
backend
=
get_backend
()
(
ret
,)
=
apply
(
_RemoteRecv
(
op
),
inp
)
if
_isscalar
:
...
...
imperative/src/impl/ops/io_remote.cpp
浏览文件 @
1a711299
...
...
@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
OperatorNodeConfig
config
{
send
.
make_name
()};
cg
::
OperatorNodeBase
*
opr
=
graph
->
insert_opr
(
std
::
make_unique
<
mgb
::
opr
::
RemoteSend
>
(
send
.
key
,
inputs
[
0
],
group_client
,
true
,
config
));
send
.
key
,
inputs
[
0
],
group_client
,
true
,
send
.
backend
,
config
));
return
opr
;
}
...
...
@@ -49,7 +49,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto
&&
graph
=
inputs
[
0
]
->
owner_graph
();
return
graph
->
insert_opr
(
std
::
make_unique
<
mgb
::
opr
::
RemoteRecv
>
(
recv
.
key
,
inputs
[
0
],
*
graph
,
group_client
,
config
,
recv
.
shape
,
recv
.
dtype
));
recv
.
shape
,
recv
.
dtype
,
recv
.
backend
));
}
OP_TRAIT_REG
(
RemoteSend
,
RemoteSend
,
mgb
::
opr
::
RemoteSend
)
...
...
imperative/src/test/io_remote.cpp
浏览文件 @
1a711299
...
...
@@ -34,7 +34,7 @@ TEST(TestImperative, IORemote) {
auto
run_send
=
[
&
](
std
::
shared_ptr
<
HostTensorND
>
hnd
)
{
auto
def
=
imperative
::
RemoteSend
::
make
(
"io_remote_test"
,
server_addr
,
port
,
1
);
"io_remote_test"
,
server_addr
,
port
,
1
,
"nccl"
);
auto
inp
=
Tensor
::
make
(
*
hnd
);
auto
oup
=
OpDef
::
apply_on_physical_tensor
(
*
def
,
{
inp
});
};
...
...
@@ -43,7 +43,7 @@ TEST(TestImperative, IORemote) {
auto
def
=
imperative
::
RemoteRecv
::
make
(
"io_remote_test"
,
server_addr
,
port
,
0
,
CompNode
::
load
(
"gpu1"
),
TensorShape
{
vector_size
},
dtype
::
Float32
());
dtype
::
Float32
()
,
"nccl"
);
auto
inp
=
Tensor
::
make
(
*
hnd
);
auto
oup
=
OpDef
::
apply_on_physical_tensor
(
*
def
,
{
inp
});
HostTensorND
host_v
;
...
...
src/core/include/megbrain/ir/ops.td
浏览文件 @
1a711299
...
...
@@ -169,7 +169,8 @@ def RemoteSend : MgbHashableOp<"RemoteSend"> {
MgbStringAttr:$key,
MgbStringAttr:$addr,
MgbUI32Attr:$port,
MgbUI32Attr:$rank_to
MgbUI32Attr:$rank_to,
MgbStringAttr:$backend
);
}
...
...
@@ -181,7 +182,8 @@ def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
MgbUI32Attr:$rank_from,
MgbCompNodeAttr:$cn,
MgbTensorShapeAttr:$shape,
MgbDTypeAttr:$dtype
MgbDTypeAttr:$dtype,
MgbStringAttr:$backend
);
}
...
...
src/opr-mm/impl/io_remote.cpp
浏览文件 @
1a711299
...
...
@@ -24,8 +24,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
RemoteSend
::
RemoteSend
(
const
std
::
string
&
key
,
VarNode
*
var
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
bool
is_grad
,
const
OperatorNodeConfig
&
config
)
:
bool
is_grad
,
std
::
string
backend
,
const
OperatorNodeConfig
&
config
)
:
Super
(
var
->
owner_graph
(),
config
,
"remote_send"
,
{
var
}),
m_backend
(
backend
),
m_is_grad
(
is_grad
)
{
m_key
=
key
;
m_group_client
=
group_client
;
...
...
@@ -41,9 +42,9 @@ RemoteSend::RemoteSend(const std::string& key, VarNode* var,
SymbolVar
RemoteSend
::
make
(
const
std
::
string
&
key
,
SymbolVar
var
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
bool
is_grad
,
const
OperatorNodeConfig
&
config
)
{
bool
is_grad
,
std
::
string
backend
,
const
OperatorNodeConfig
&
config
)
{
return
var
.
insert_single_output_opr
<
RemoteSend
>
(
key
,
var
.
node
(),
group_client
,
is_grad
,
config
);
is_grad
,
backend
,
config
);
}
void
RemoteSend
::
scn_do_execute
()
{
...
...
@@ -64,7 +65,7 @@ void RemoteSend::scn_do_execute() {
}
m_megray_comm
=
MegRayCommBuilder
::
get_megray_comm
(
reg_info
.
hash
,
m_key
,
2
,
0
,
MegRay
::
MEGRAY_NCCL
,
m_group_client
);
reg_info
.
hash
,
m_key
,
2
,
0
,
get_megray_backend
(
m_backend
)
,
m_group_client
);
m_megray_ctx
=
get_megray_context
(
output
(
0
)
->
comp_node
());
...
...
@@ -122,7 +123,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
*
opr
.
owner_graph
(),
opr
.
group_client
(),
OperatorNodeConfig
{
opr
.
comp_node
()}.
name
(
opr
.
name
()
+
":grad_recv"
),
opr
.
input
(
0
)
->
shape
(),
opr
.
input
(
0
)
->
dtype
())
opr
.
input
(
0
)
->
shape
(),
opr
.
input
(
0
)
->
dtype
()
,
opr
.
backend
()
)
.
node
();
}
#endif
...
...
@@ -134,9 +135,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
RemoteRecv
::
RemoteRecv
(
const
std
::
string
&
key
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
)
:
const
TensorShape
&
shape
,
DType
dtype
,
std
::
string
backend
)
:
Super
(
&
graph
,
config
,
"remote_recv"
,
{}),
m_shape
(
shape
),
m_dtype
(
dtype
)
{
m_shape
(
shape
),
m_dtype
(
dtype
)
,
m_backend
(
backend
)
{
m_key
=
key
;
m_group_client
=
group_client
;
...
...
@@ -150,9 +151,9 @@ RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
RemoteRecv
::
RemoteRecv
(
const
std
::
string
&
key
,
VarNode
*
var
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
)
:
const
TensorShape
&
shape
,
DType
dtype
,
std
::
string
backend
)
:
Super
(
&
graph
,
config
,
"remote_recv"
,
{}),
m_shape
(
shape
),
m_dtype
(
dtype
)
{
m_shape
(
shape
),
m_dtype
(
dtype
)
,
m_backend
(
backend
)
{
m_key
=
key
;
m_group_client
=
group_client
;
...
...
@@ -167,18 +168,18 @@ RemoteRecv::RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph&
SymbolVar
RemoteRecv
::
make
(
const
std
::
string
&
key
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
)
{
const
TensorShape
&
shape
,
DType
dtype
,
std
::
string
backend
)
{
auto
opr
=
graph
.
insert_opr
(
std
::
make_unique
<
RemoteRecv
>
(
key
,
graph
,
group_client
,
config
,
shape
,
dtype
));
key
,
graph
,
group_client
,
config
,
shape
,
dtype
,
backend
));
return
opr
->
output
(
0
);
}
SymbolVar
RemoteRecv
::
make
(
const
std
::
string
&
key
,
SymbolVar
var
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
)
{
const
TensorShape
&
shape
,
DType
dtype
,
std
::
string
backend
)
{
auto
opr
=
graph
.
insert_opr
(
std
::
make_unique
<
RemoteRecv
>
(
key
,
var
.
node
(),
graph
,
group_client
,
config
,
shape
,
dtype
));
key
,
var
.
node
(),
graph
,
group_client
,
config
,
shape
,
dtype
,
backend
));
return
opr
->
output
(
0
);
}
...
...
@@ -201,7 +202,7 @@ void RemoteRecv::scn_do_execute() {
}
m_megray_comm
=
MegRayCommBuilder
::
get_megray_comm
(
reg_info
.
hash
,
m_key
,
2
,
1
,
MegRay
::
MEGRAY_NCCL
,
m_group_client
);
reg_info
.
hash
,
m_key
,
2
,
1
,
get_megray_backend
(
m_backend
)
,
m_group_client
);
m_megray_ctx
=
get_megray_context
(
output
(
0
)
->
comp_node
());
...
...
@@ -251,7 +252,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
mgb_assert
(
inputs
.
size
()
==
1
);
auto
&&
opr
=
opr_
.
cast_final_safe
<
RemoteSend
>
();
return
RemoteSend
::
make
(
opr
.
key
(),
inputs
[
0
],
opr
.
group_client
(),
opr
.
is_grad
(),
config
)
opr
.
is_grad
(),
opr
.
backend
(),
config
)
.
node
()
->
owner_opr
();
}
...
...
@@ -265,14 +266,14 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
if
(
inputs
.
size
()
==
1
)
{
return
RemoteRecv
::
make
(
opr
.
key
(),
inputs
[
0
],
*
opr
.
owner_graph
(),
opr
.
group_client
(),
config
,
opr
.
shape
(),
opr
.
dtype
())
opr
.
dtype
()
,
opr
.
backend
()
)
.
node
()
->
owner_opr
();
}
else
{
mgb_assert
(
inputs
.
size
()
==
0
,
"recv should have 1 or 0 input"
);
return
RemoteRecv
::
make
(
opr
.
key
(),
*
opr
.
owner_graph
(),
opr
.
group_client
(),
config
,
opr
.
shape
(),
opr
.
dtype
())
opr
.
dtype
()
,
opr
.
backend
()
)
.
node
()
->
owner_opr
();
}
...
...
src/opr-mm/impl/io_remote.oprdecl
浏览文件 @
1a711299
...
...
@@ -9,6 +9,8 @@ decl_raw_opr(
Doc
(
'key'
,
'key to bind send-recv pair'
,
'str'
),
Doc
(
'var'
,
'variable to be sent'
,
':class:`.SymbolVar`'
),
Doc
(
'is_grad'
,
'whether the send'
,
'bool'
),
Doc
(
'backend'
,
'Backend for collective communication, nccl or ucx'
,
'str'
,
'
\'
nccl
\'
'
),
]
)
...
...
@@ -24,7 +26,9 @@ decl_raw_opr(
':class:`.CompGraph`'
),
Doc
(
'shape'
,
'output var shape'
),
Doc
(
'dtype'
,
'data type of the output var; must match dtype at sender'
,
':class:`numpy.dtype` compatible'
)
':class:`numpy.dtype` compatible'
),
Doc
(
'backend'
,
'Backend for collective communication, nccl or ucx'
,
'str'
,
'
\'
nccl
\'
'
),
]
)
...
...
src/opr-mm/include/megbrain/opr/io_remote.h
浏览文件 @
1a711299
...
...
@@ -48,17 +48,19 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
public:
RemoteSend
(
const
std
::
string
&
key
,
VarNode
*
var
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
bool
is_grad
,
const
OperatorNodeConfig
&
config
);
bool
is_grad
,
std
::
string
backend
,
const
OperatorNodeConfig
&
config
);
static
SymbolVar
make
(
const
std
::
string
&
key
,
SymbolVar
var
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
bool
is_grad
,
const
OperatorNodeConfig
&
config
=
{});
bool
is_grad
,
std
::
string
backend
,
const
OperatorNodeConfig
&
config
=
{});
const
std
::
string
&
backend
()
const
{
return
m_backend
;
}
bool
is_grad
()
const
{
return
m_is_grad
;
}
private:
HostTensorND
m_output_val
;
std
::
string
m_backend
;
bool
m_is_grad
;
void
scn_do_execute
()
override
;
...
...
@@ -75,31 +77,33 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
RemoteRecv
(
const
std
::
string
&
key
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
);
DType
dtype
,
std
::
string
backend
);
RemoteRecv
(
const
std
::
string
&
key
,
VarNode
*
var
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
);
DType
dtype
,
std
::
string
backend
);
static
SymbolVar
make
(
const
std
::
string
&
key
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
);
DType
dtype
,
std
::
string
backend
);
static
SymbolVar
make
(
const
std
::
string
&
key
,
SymbolVar
var
,
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
GroupClient
>
group_client
,
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
);
DType
dtype
,
std
::
string
backend
);
const
TensorShape
&
shape
()
const
{
return
m_shape
;
}
const
DType
&
dtype
()
const
{
return
m_dtype
;
}
const
std
::
string
&
backend
()
const
{
return
m_backend
;
}
private
:
const
TensorShape
m_shape
;
const
DType
m_dtype
;
const
std
::
string
m_backend
;
const
CompNode
m_comp_node
;
DeviceTensorND
m_dev_buffer
;
...
...
src/opr-mm/test/io_remote.cpp
浏览文件 @
1a711299
...
...
@@ -33,10 +33,10 @@ TEST(TestOprIORemote, Identity) {
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
,
cn0
);
auto
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
);
auto
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
,
"nccl"
);
auto
y
=
opr
::
RemoteRecv
::
make
(
"x"
,
*
graph
.
get
(),
client
,
{
cn1
},
host_x
->
shape
(),
host_x
->
dtype
());
host_x
->
dtype
()
,
"nccl"
);
auto
func
=
graph
->
compile
({{
xr
,
{}},
make_callback_copy
(
y
,
host_y
)});
...
...
@@ -57,7 +57,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto
graph
=
ComputingGraph
::
make
();
sys
::
set_thread_name
(
"sender"
);
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
);
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
,
"nccl"
);
auto
func
=
graph
->
compile
({{
xr
,
{}}});
func
->
execute
();
};
...
...
@@ -67,7 +67,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
RemoteRecv
::
make
(
"x"
,
*
graph
.
get
(),
client
,
{
cns
[
0
]},
host_x
->
shape
(),
host_x
->
dtype
());
host_x
->
dtype
()
,
"nccl"
);
auto
func
=
graph
->
compile
({
make_callback_copy
(
x
,
host_x_get
)});
func
->
execute
();
};
...
...
@@ -91,7 +91,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
sys
::
set_thread_name
(
"sender"
);
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
)
*
2
+
1
,
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
);
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
,
"nccl"
);
auto
func
=
graph
->
compile
({{
xr
,
{}}});
func
->
execute
();
};
...
...
@@ -101,7 +101,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
RemoteRecv
::
make
(
"x"
,
*
graph
.
get
(),
client
,
{
cns
[
0
]},
host_x
->
shape
(),
host_x
->
dtype
());
host_x
->
dtype
()
,
"nccl"
);
auto
func
=
graph
->
compile
({
make_callback_copy
((
x
-
1
)
/
2
,
host_x_get
)});
func
->
execute
();
...
...
@@ -126,12 +126,12 @@ TEST(TestOprIORemote, APlusB) {
auto
graph
=
ComputingGraph
::
make
();
auto
z
=
opr
::
RemoteRecv
::
make
(
"z"
,
*
graph
.
get
(),
client
,
{
cns
[
0
]},
host_x
->
shape
(),
host_x
->
dtype
());
host_x
->
dtype
()
,
"nccl"
);
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
).
rename
(
"x"
),
y
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_y
).
rename
(
"y"
),
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
)
xr
=
opr
::
RemoteSend
::
make
(
"x"
,
x
,
client
,
false
,
"nccl"
)
.
rename
(
"xr"
),
yr
=
opr
::
RemoteSend
::
make
(
"y"
,
y
,
client
,
false
)
yr
=
opr
::
RemoteSend
::
make
(
"y"
,
y
,
client
,
false
,
"nccl"
)
.
rename
(
"yr"
);
auto
func
=
graph
->
compile
(
{{
xr
,
{}},
{
yr
,
{}},
make_callback_copy
(
z
,
host_z
)});
...
...
@@ -144,12 +144,12 @@ TEST(TestOprIORemote, APlusB) {
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
RemoteRecv
::
make
(
"x"
,
*
graph
.
get
(),
client
,
{
cns
[
1
]},
host_x
->
shape
(),
host_x
->
dtype
()),
host_x
->
dtype
()
,
"nccl"
),
y
=
opr
::
RemoteRecv
::
make
(
"y"
,
*
graph
.
get
(),
client
,
{
cns
[
1
]},
host_y
->
shape
(),
host_y
->
dtype
()),
host_y
->
dtype
()
,
"nccl"
),
z
=
x
+
y
,
zr
=
opr
::
RemoteSend
::
make
(
"z"
,
z
,
client
,
false
);
zr
=
opr
::
RemoteSend
::
make
(
"z"
,
z
,
client
,
false
,
"nccl"
);
auto
func
=
graph
->
compile
({{
zr
,
{}}});
func
->
execute
();
};
...
...
@@ -178,10 +178,10 @@ TEST(TestOprIORemote, SendGrad) {
sys
::
set_thread_name
(
"sender"
);
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
loss
=
opr
::
RemoteSend
::
make
(
"loss"
,
x
,
client
,
false
);
loss
=
opr
::
RemoteSend
::
make
(
"loss"
,
x
,
client
,
false
,
"nccl"
);
ASSERT_TRUE
(
!
loss
.
shape
().
ndim
&&
loss
.
node
()
->
contain_flag
(
VarNode
::
Flag
::
VOLATILE_CONTENT
));
loss
=
opr
::
RemoteSend
::
make
(
"loss"
,
x
,
client
,
true
);
loss
=
opr
::
RemoteSend
::
make
(
"loss"
,
x
,
client
,
true
,
"nccl"
);
auto
gx
=
cg
::
grad
(
loss
,
x
);
set_priority
(
loss
,
0
);
set_priority
(
gx
,
1
);
...
...
@@ -200,8 +200,8 @@ TEST(TestOprIORemote, SendGrad) {
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
RemoteRecv
::
make
(
"loss"
,
*
graph
.
get
(),
client
,
{
cns
[
1
]},
host_x
->
shape
(),
host_x
->
dtype
());
auto
y
=
opr
::
RemoteSend
::
make
(
"loss:grad"
,
x
+
1
,
client
,
false
);
host_x
->
dtype
()
,
"nccl"
);
auto
y
=
opr
::
RemoteSend
::
make
(
"loss:grad"
,
x
+
1
,
client
,
false
,
"nccl"
);
auto
func
=
graph
->
compile
({{
y
,
{}}});
func
->
execute
();
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录