Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
13e8f00a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
13e8f00a
编写于
8月 13, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
8月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/module): add forward hook support
GitOrigin-RevId: c0db58df13ce12ee293026aad30b2de93e9c6f80
上级
ab9fa48e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
139 addition
and
28 deletion
+139
-28
python_module/megengine/module/module.py
python_module/megengine/module/module.py
+63
-27
python_module/megengine/utils/hook.py
python_module/megengine/utils/hook.py
+23
-0
python_module/test/unit/module/test_module.py
python_module/test/unit/module/test_module.py
+53
-1
未找到文件。
python_module/megengine/module/module.py
浏览文件 @
13e8f00a
...
@@ -14,6 +14,7 @@ import numpy as np
...
@@ -14,6 +14,7 @@ import numpy as np
from
.._internal.dtype
import
is_quantize
from
.._internal.dtype
import
is_quantize
from
..core
import
Buffer
,
Parameter
,
Tensor
from
..core
import
Buffer
,
Parameter
,
Tensor
from
..logger
import
get_logger
from
..logger
import
get_logger
from
..utils.hook
import
HookHandler
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
...
@@ -57,19 +58,51 @@ class Module(metaclass=ABCMeta):
...
@@ -57,19 +58,51 @@ class Module(metaclass=ABCMeta):
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
# runtime attributes
self
.
training
=
True
self
.
training
=
True
self
.
quantize_diabled
=
False
self
.
quantize_diabled
=
False
# hooks
self
.
_forward_pre_hooks
=
OrderedDict
()
self
.
_forward_hooks
=
OrderedDict
()
@
abstractmethod
@
abstractmethod
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
pass
pass
def
register_forward_pre_hook
(
self
,
hook
:
Callable
)
->
HookHandler
:
"""Register a hook to handle forward inputs. `hook` should be a function
Note that `inputs` keyword inputs
:param hook: a function that receive `module` and `inputs`, then return
a modified `inputs` or `None`.
:return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
"""
return
HookHandler
(
self
.
_forward_pre_hooks
,
hook
)
def
register_forward_hook
(
self
,
hook
:
Callable
)
->
HookHandler
:
"""Register a hook to handle forward results. `hook` should be a function that
receive `module`, `inputs` and `outputs`, then return a modified `outputs` or `None`.
This method return a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
"""
return
HookHandler
(
self
.
_forward_hooks
,
hook
)
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
# ToDo: Convert numpy or scalar
for
hook
in
self
.
_forward_pre_hooks
.
values
():
# Maybe ToDo: set training phase
modified_inputs
=
hook
(
self
,
inputs
)
# Maybe ToDo: set computing graph
if
modified_inputs
is
not
None
:
if
not
isinstance
(
modified_inputs
,
tuple
):
modified_inputs
=
(
modified_inputs
,)
inputs
=
modified_inputs
outputs
=
self
.
forward
(
*
inputs
,
**
kwargs
)
outputs
=
self
.
forward
(
*
inputs
,
**
kwargs
)
# Maybe ToDo: set connectivity metadata
for
hook
in
self
.
_forward_hooks
.
values
():
modified_outputs
=
hook
(
self
,
inputs
,
outputs
)
if
modified_outputs
is
not
None
:
outputs
=
modified_outputs
return
outputs
return
outputs
def
_flatten
(
def
_flatten
(
...
@@ -191,29 +224,6 @@ class Module(metaclass=ABCMeta):
...
@@ -191,29 +224,6 @@ class Module(metaclass=ABCMeta):
with_key
=
False
,
predicate
=
_is_buffer
,
recursive
=
recursive
,
**
kwargs
with_key
=
False
,
predicate
=
_is_buffer
,
recursive
=
recursive
,
**
kwargs
)
)
def
replace_param
(
self
,
params
:
dict
,
start_pos
:
int
,
seen
:
Optional
[
Set
[
int
]]
=
None
):
offset
=
0
if
seen
is
None
:
seen
=
set
([
id
(
self
)])
module_dict
=
vars
(
self
)
for
key
in
sorted
(
module_dict
):
hash_id
=
id
(
module_dict
[
key
])
if
hash_id
in
seen
:
continue
seen
.
add
(
hash_id
)
if
isinstance
(
module_dict
[
key
],
Parameter
):
if
start_pos
+
offset
in
params
:
assert
module_dict
[
key
].
shape
==
params
[
start_pos
+
offset
].
shape
module_dict
[
key
]
=
params
[
start_pos
+
offset
]
offset
+=
1
if
isinstance
(
module_dict
[
key
],
Module
):
offset
+=
module_dict
[
key
].
replace_param
(
params
,
start_pos
+
offset
,
seen
)
return
offset
def
named_buffers
(
def
named_buffers
(
self
,
prefix
:
Optional
[
str
]
=
None
,
recursive
:
bool
=
True
,
**
kwargs
self
,
prefix
:
Optional
[
str
]
=
None
,
recursive
:
bool
=
True
,
**
kwargs
)
->
Iterable
[
Tuple
[
str
,
Buffer
]]:
)
->
Iterable
[
Tuple
[
str
,
Buffer
]]:
...
@@ -327,6 +337,32 @@ class Module(metaclass=ABCMeta):
...
@@ -327,6 +337,32 @@ class Module(metaclass=ABCMeta):
self
.
apply
(
fn
)
self
.
apply
(
fn
)
def
replace_param
(
self
,
params
:
dict
,
start_pos
:
int
,
seen
:
Optional
[
Set
[
int
]]
=
None
):
"""Replace module's parameters with `params`, used by :class:`~.ParamPack` to
speedup multimachine training.
"""
offset
=
0
if
seen
is
None
:
seen
=
set
([
id
(
self
)])
module_dict
=
vars
(
self
)
for
key
in
sorted
(
module_dict
):
hash_id
=
id
(
module_dict
[
key
])
if
hash_id
in
seen
:
continue
seen
.
add
(
hash_id
)
if
isinstance
(
module_dict
[
key
],
Parameter
):
if
start_pos
+
offset
in
params
:
assert
module_dict
[
key
].
shape
==
params
[
start_pos
+
offset
].
shape
module_dict
[
key
]
=
params
[
start_pos
+
offset
]
offset
+=
1
if
isinstance
(
module_dict
[
key
],
Module
):
offset
+=
module_dict
[
key
].
replace_param
(
params
,
start_pos
+
offset
,
seen
)
return
offset
def
state_dict
(
self
,
rst
=
None
,
prefix
=
""
,
keep_var
=
False
):
def
state_dict
(
self
,
rst
=
None
,
prefix
=
""
,
keep_var
=
False
):
r
"""Returns a dictionary containing whole states of the module.
r
"""Returns a dictionary containing whole states of the module.
"""
"""
...
...
python_module/megengine/utils/hook.py
0 → 100644
浏览文件 @
13e8f00a
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
weakref
class
HookHandler
:
hook_num
=
0
def
__init__
(
self
,
source_dict
,
hook
):
self
.
id
=
HookHandler
.
hook_num
HookHandler
.
hook_num
+=
1
source_dict
[
self
.
id
]
=
hook
self
.
source_ref
=
weakref
.
ref
(
source_dict
)
def
remove
(
self
):
source_dict
=
self
.
source_ref
()
if
source_dict
is
not
None
and
self
.
id
in
source_dict
:
del
source_dict
[
self
.
id
]
python_module/test/unit/module/test_module.py
浏览文件 @
13e8f00a
...
@@ -17,6 +17,7 @@ from helpers import MLP
...
@@ -17,6 +17,7 @@ from helpers import MLP
import
megengine
as
mge
import
megengine
as
mge
import
megengine._internal
as
mgb
import
megengine._internal
as
mgb
import
megengine.functional
as
F
from
megengine.core
import
Buffer
,
Parameter
,
Tensor
,
tensor
from
megengine.core
import
Buffer
,
Parameter
,
Tensor
,
tensor
from
megengine.module
import
(
from
megengine.module
import
(
BatchNorm1d
,
BatchNorm1d
,
...
@@ -37,7 +38,7 @@ class MyModule(Module):
...
@@ -37,7 +38,7 @@ class MyModule(Module):
self
.
bn
=
BatchNorm2d
(
4
)
self
.
bn
=
BatchNorm2d
(
4
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
bn
(
x
)
return
self
.
bn
(
x
)
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
@@ -145,6 +146,57 @@ def test_module_api_iterable_stability():
...
@@ -145,6 +146,57 @@ def test_module_api_iterable_stability():
assert
list
(
m
.
modules
())
==
l
assert
list
(
m
.
modules
())
==
l
def
test_module_api_hooks
():
net
=
MyModule
()
pre_hook_num
=
0
post_hook_num
=
0
hooks
=
[]
def
pre_hook
(
module
,
inputs
):
nonlocal
pre_hook_num
pre_hook_num
+=
1
modified_inputs
=
tuple
(
inp
+
1
for
inp
in
inputs
)
return
modified_inputs
def
post_hook
(
module
,
inputs
,
outputs
):
nonlocal
post_hook_num
post_hook_num
+=
1
outputs
+=
1
return
outputs
net
.
apply
(
lambda
module
:
hooks
.
append
(
module
.
register_forward_pre_hook
(
pre_hook
)))
net
.
apply
(
lambda
module
:
hooks
.
append
(
module
.
register_forward_hook
(
post_hook
)))
shape
=
(
1
,
4
,
1
,
1
)
x
=
tensor
(
np
.
zeros
(
shape
,
dtype
=
np
.
float32
))
y
=
net
(
x
)
assert
pre_hook_num
==
4
assert
post_hook_num
==
4
mean1
=
Parameter
(
np
.
zeros
(
shape
),
dtype
=
np
.
float32
)
bn1
=
F
.
batch_norm2d
(
x
+
3
,
mean1
,
Parameter
(
np
.
ones
(
shape
),
dtype
=
np
.
float32
),
training
=
True
)
assertTensorClose
(
net
.
i
.
bn
.
running_mean
,
mean1
,
)
mean2
=
Parameter
(
np
.
zeros
(
shape
),
dtype
=
np
.
float32
)
bn2
=
F
.
batch_norm2d
(
bn1
+
3
,
mean2
,
Parameter
(
np
.
ones
(
shape
),
dtype
=
np
.
float32
),
training
=
True
)
assertTensorClose
(
net
.
bn
.
running_mean
,
mean2
,
)
assertTensorClose
(
bn2
+
2
,
y
)
assert
len
(
hooks
)
==
8
for
handler
in
hooks
:
handler
.
remove
()
y
=
net
(
x
)
assert
pre_hook_num
==
4
assert
post_hook_num
==
4
class
MyModule2
(
Module
):
class
MyModule2
(
Module
):
class
InnerModule
(
Module
):
class
InnerModule
(
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录