Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
60702667
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
60702667
编写于
9月 04, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/grad_manager): refactor gradmanager, add allreduce callback
GitOrigin-RevId: 086e2871e8141bc2d6c4067b7c42eff85330ebca
上级
3f2eac2f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
80 additions
and
3 deletions
+80
-3
imperative/python/megengine/autodiff/grad_manager.py
imperative/python/megengine/autodiff/grad_manager.py
+52
-0
imperative/python/megengine/core/autodiff/grad.py
imperative/python/megengine/core/autodiff/grad.py
+5
-1
imperative/python/megengine/distributed/__init__.py
imperative/python/megengine/distributed/__init__.py
+1
-1
imperative/python/megengine/distributed/helper.py
imperative/python/megengine/distributed/helper.py
+22
-1
未找到文件。
imperative/python/megengine/autodiff/grad_manager.py
0 → 100644
浏览文件 @
60702667
from
contextlib
import
contextmanager
from
..core.autodiff.grad
import
Grad
from
..tensor
import
tensor
class GradManager:
    """Record a computation and back-propagate through it on request.

    Typical flow: attach parameters with :meth:`register`, run the forward
    computation inside the context returned by :meth:`record`, then call
    :meth:`backward` while that context is still active.
    """

    def __init__(self):
        # Each entry is ``[params, callback]`` as collected by ``register``.
        self._call_back_pair = []
        # True while a ``record()`` context is active.
        self._recording = False
        # The ``Grad`` of the active recording; ``None`` when not recording
        # and also once ``backward`` has consumed the history.
        self._grad = None

    def register(self, params, callback=None):
        """Remember ``params`` so every later recording tracks them.

        Args:
            params: iterable of objects to differentiate with respect to;
                it is unpacked into ``Grad.wrt`` when recording starts.
            callback: forwarded to ``Grad.wrt`` as ``callback=``.
        """
        # Materialise the iterable: a generator would be exhausted by the
        # first ``record()`` and silently yield nothing on later ones.
        self._call_back_pair.append([tuple(params), callback])

    def backward(self, ys, dys=None):
        """Propagate gradients from ``ys`` through the recorded history.

        Args:
            ys: a single output or a tuple/list of outputs.
            dys: gradient(s) for ``ys``. When omitted, ``tensor(1)`` broadcast
                to each output's shape is used.

        Raises:
            RuntimeError: if no recording is active, or the history of the
                active recording was already consumed by an earlier call.
        """
        # ``_grad`` is cleared after every backward pass, so a second call in
        # the same recording has no history left. Raise the descriptive
        # error for that case too instead of relying on ``assert``, which is
        # stripped under ``python -O``.
        if not self._recording or self._grad is None:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            dys = [tensor(1).broadcast(y.shape) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
        finally:
            # The history is dropped whether or not the pass succeeded.
            self._grad = None

    def record(self):
        """Return a context manager that records computation for ``backward``.

        Raises (on entering the context):
            RuntimeError: if this manager is already recording.
        """

        @contextmanager
        def recorder():
            # Check before building a ``Grad`` so a rejected nested call
            # does not construct one for nothing.
            if self._recording:
                raise RuntimeError("already recording!")
            grad = Grad()
            try:
                self._recording = True
                self._grad = grad
                for params, callbacks in self._call_back_pair:
                    grad.wrt(*params, callback=callbacks)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None

        return recorder()
imperative/python/megengine/core/autodiff/grad.py
浏览文件 @
60702667
...
...
@@ -260,9 +260,13 @@ class Grad:
cache
[
v
]
=
g
if
last_written_to
[
v
]
==
(
seqno
,
i
):
if
v
.
callback
:
v
.
callback
(
grad
=
v
.
callback
(
v
.
owner
(),
Wrapper
(
cache
[
v
])
if
Wrapper
else
cache
[
v
]
)
if
getattr
(
v
.
owner
(),
"grad"
,
None
)
is
None
:
v
.
owner
().
grad
=
grad
else
:
v
.
owner
().
grad
+=
grad
if
v
.
opnode
is
None
:
# won't read by backward, mark consumed
cache
[
v
]
=
None
...
...
imperative/python/megengine/distributed/__init__.py
浏览文件 @
60702667
...
...
@@ -19,7 +19,7 @@ from .group import (
is_distributed
,
new_group
,
)
from
.helper
import
synchronized
from
.helper
import
bcast_params_
,
make_allreduce_cb
,
synchronized
from
.launcher
import
launcher
from
.server
import
Client
,
Server
from
.util
import
get_free_ports
imperative/python/megengine/distributed/helper.py
浏览文件 @
60702667
...
...
@@ -12,7 +12,8 @@ from typing import Callable
from
megengine.device
import
get_device_count
from
.group
import
group_barrier
,
is_distributed
from
.functional
import
all_reduce_sum
,
broadcast
from
.group
import
WORLD
,
group_barrier
,
is_distributed
def
synchronized
(
func
:
Callable
):
...
...
@@ -42,3 +43,23 @@ def get_device_count_by_fork(device_type: str):
p
.
start
()
p
.
join
()
return
q
.
get
()
def bcast_params_(params, group):
    """Replace every parameter in place with its broadcast over ``group``.

    Args:
        params: iterable of parameters; each must provide ``_reset``.
        group: passed through to ``broadcast`` unchanged.

    Returns:
        None. Each parameter is updated through its ``_reset`` method.
    """
    for param in params:
        synced = broadcast(param, group)
        param._reset(synced)
class AllreduceCallback:
    """Gradient callback that all-reduces a gradient over a group.

    Instances are callable as ``cb(param, grad)`` and return the reduced
    gradient.

    Args:
        reduce_method: ``"MEAN"`` (matched case-insensitively) divides the
            summed gradient by ``group.size``; any other value leaves the
            plain sum.
        group: group handed to ``all_reduce_sum``; defaults to ``WORLD``.
    """

    def __init__(self, reduce_method, group=WORLD):
        self._reduce_method = reduce_method
        self._group = group

    def __call__(self, param, grad):
        """Return ``grad`` summed over the group, averaged for ``"MEAN"``.

        ``param`` is accepted to match the callback signature and is unused.
        """
        ret = all_reduce_sum(grad, self._group)
        method = self._reduce_method
        # Compare case-insensitively so "mean" / "Mean" are not silently
        # treated as a plain sum. Non-string values keep the sum behaviour.
        if isinstance(method, str) and method.upper() == "MEAN":
            ret = ret / self._group.size
        return ret


# Factory-style alias for the callback class.
make_allreduce_cb = AllreduceCallback
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录