Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2ad8c5e1
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2ad8c5e1
编写于
11月 25, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/io_remote): fix remote send/recv gradient at trace
GitOrigin-RevId: 7886efd0c124b1a6f60046c9f876e457eb683b1d
上级
f470df4f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
37 addition
and
16 deletion
+37
-16
imperative/python/megengine/core/autodiff/grad.py
imperative/python/megengine/core/autodiff/grad.py
+7
-1
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+1
-1
imperative/python/test/unit/autodiff/test_grad_manger.py
imperative/python/test/unit/autodiff/test_grad_manger.py
+12
-9
src/opr-mm/impl/io_remote.cpp
src/opr-mm/impl/io_remote.cpp
+14
-5
src/opr-mm/include/megbrain/opr/io_remote.h
src/opr-mm/include/megbrain/opr/io_remote.h
+3
-0
未找到文件。
imperative/python/megengine/core/autodiff/grad.py
浏览文件 @
2ad8c5e1
...
...
@@ -16,7 +16,7 @@ import numpy as np
import
megengine
as
mge
from
..ops.builtin
import
Elemwise
,
OpDef
from
..ops.builtin
import
Elemwise
,
OpDef
,
RemoteSend
from
..ops.special
import
Const
from
..tensor.core
import
TensorBase
,
TensorWrapperBase
,
apply
from
..tensor.function
import
Function
...
...
@@ -84,6 +84,9 @@ class Grad:
# ops forms the computational graph
self
.
ops
=
[]
# save remote_send output for backward
self
.
remote_send_cache
=
[]
self
.
_attached_tensors
=
weakref
.
WeakSet
()
self
.
_enabled
=
True
...
...
@@ -144,6 +147,7 @@ class Grad:
o
.
clear
()
for
i
in
self
.
_attached_tensors
:
i
.
_extra_data
.
pop
(
self
,
None
)
self
.
remote_send_cache
=
[]
def
__exit__
(
self
,
*
_
):
self
.
_exit
()
...
...
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
return
opnode
,
outputs
=
manager
.
_new_opnode
([
i
and
i
.
node
for
i
in
args
],
ctx
.
outputs
)
if
isinstance
(
op
,
RemoteSend
):
manager
.
remote_send_cache
.
append
(
opnode
)
opnode
.
backward
=
backward
outputs
=
[
x
if
y
else
None
for
(
x
,
y
)
in
zip
(
outputs
,
output_need_grad
)]
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
2ad8c5e1
...
...
@@ -588,7 +588,7 @@ class trace:
graph
.
options
.
graph_opt_level
=
self
.
_graph_opt_level
else
:
graph
.
options
.
graph_opt_level
=
2
graph
.
compile
(
*
readers
)
graph
.
compile
(
*
readers
,
*
links
)
def
_reset_exec_env
(
self
):
for
opnode
in
self
.
_need_reset_nodes
:
...
...
imperative/python/test/unit/autodiff/test_grad_manger.py
浏览文件 @
2ad8c5e1
...
...
@@ -111,7 +111,6 @@ def test_remote_grad():
gm
=
GradManager
().
attach
(
m
.
parameters
())
opt
=
optim
.
SGD
(
m
.
parameters
(),
1e-3
,
momentum
=
0.9
)
@
trace
(
symbolic
=
True
)
def
train_func
(
x
):
with
gm
:
if
rank
!=
0
:
...
...
@@ -120,18 +119,22 @@ def test_remote_grad():
)
y
=
m
(
x
)
if
rank
!=
size
-
1
:
y
=
dist
.
functional
.
remote_send
(
y
,
dest_rank
=
rank
+
1
)
if
rank
==
size
-
1
:
dist
.
functional
.
remote_send
(
y
,
dest_rank
=
rank
+
1
)
gm
.
backward
()
else
:
y
=
y
.
mean
()
gm
.
backward
(
y
)
else
:
gm
.
backward
()
opt
.
step
().
clear_grad
()
for
i
in
range
(
3
):
train_func
(
x
)
train_funcs
=
[
train_func
,
trace
(
symbolic
=
False
)(
train_func
),
trace
(
symbolic
=
True
)(
train_func
),
]
for
param
in
m
.
parameters
():
param
.
numpy
()
for
func
in
train_funcs
:
for
i
in
range
(
3
):
func
(
x
)
sync
()
worker
()
src/opr-mm/impl/io_remote.cpp
浏览文件 @
2ad8c5e1
...
...
@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
auto
&&
opr
=
opr_
.
cast_final_safe
<
RemoteRecv
>
();
return
RemoteRecv
::
make
(
opr
.
key
(),
*
opr
.
owner_graph
(),
opr
.
group_client
(),
config
,
inputs
[
0
]
->
shape
(),
inputs
[
0
]
->
dtype
())
.
node
()
->
owner_opr
();
if
(
inputs
.
size
()
==
1
)
{
return
RemoteRecv
::
make
(
opr
.
key
(),
inputs
[
0
],
*
opr
.
owner_graph
(),
opr
.
group_client
(),
config
,
opr
.
shape
(),
opr
.
dtype
())
.
node
()
->
owner_opr
();
}
else
{
mgb_assert
(
inputs
.
size
()
==
0
,
"recv should have 1 or 0 input"
);
return
RemoteRecv
::
make
(
opr
.
key
(),
*
opr
.
owner_graph
(),
opr
.
group_client
(),
config
,
opr
.
shape
(),
opr
.
dtype
())
.
node
()
->
owner_opr
();
}
}
MGB_REG_OPR_SHALLOW_COPY
(
RemoteRecv
,
opr_shallow_copy_remote_recv
);
...
...
src/opr-mm/include/megbrain/opr/io_remote.h
浏览文件 @
2ad8c5e1
...
...
@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
const
OperatorNodeConfig
&
config
,
const
TensorShape
&
shape
,
DType
dtype
);
const
TensorShape
&
shape
()
const
{
return
m_shape
;
}
const
DType
&
dtype
()
const
{
return
m_dtype
;
}
private
:
const
TensorShape
m_shape
;
const
DType
m_dtype
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录