提交 13e8f00a 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/module): add forward hook support

GitOrigin-RevId: c0db58df13ce12ee293026aad30b2de93e9c6f80
上级 ab9fa48e
......@@ -14,6 +14,7 @@ import numpy as np
from .._internal.dtype import is_quantize
from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger
from ..utils.hook import HookHandler
logger = get_logger(__name__)
......@@ -57,19 +58,51 @@ class Module(metaclass=ABCMeta):
"""
def __init__(self):
# runtime attributes
self.training = True
self.quantize_diabled = False
# hooks
self._forward_pre_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
@abstractmethod
def forward(self, inputs):
pass
def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
"""Register a hook to handle forward inputs. `hook` should be a function
Note that `inputs` keyword inputs
:param hook: a function that receive `module` and `inputs`, then return
a modified `inputs` or `None`.
:return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
"""
return HookHandler(self._forward_pre_hooks, hook)
def register_forward_hook(self, hook: Callable) -> HookHandler:
"""Register a hook to handle forward results. `hook` should be a function that
receive `module`, `inputs` and `outputs`, then return a modified `outputs` or `None`.
This method return a handler with :meth:`~.HookHandler.remove` interface to delete the hook.
"""
return HookHandler(self._forward_hooks, hook)
def __call__(self, *inputs, **kwargs):
# ToDo: Convert numpy or scalar
# Maybe ToDo: set training phase
# Maybe ToDo: set computing graph
for hook in self._forward_pre_hooks.values():
modified_inputs = hook(self, inputs)
if modified_inputs is not None:
if not isinstance(modified_inputs, tuple):
modified_inputs = (modified_inputs,)
inputs = modified_inputs
outputs = self.forward(*inputs, **kwargs)
# Maybe ToDo: set connectivity metadata
for hook in self._forward_hooks.values():
modified_outputs = hook(self, inputs, outputs)
if modified_outputs is not None:
outputs = modified_outputs
return outputs
def _flatten(
......@@ -191,29 +224,6 @@ class Module(metaclass=ABCMeta):
with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
)
def replace_param(
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
):
offset = 0
if seen is None:
seen = set([id(self)])
module_dict = vars(self)
for key in sorted(module_dict):
hash_id = id(module_dict[key])
if hash_id in seen:
continue
seen.add(hash_id)
if isinstance(module_dict[key], Parameter):
if start_pos + offset in params:
assert module_dict[key].shape == params[start_pos + offset].shape
module_dict[key] = params[start_pos + offset]
offset += 1
if isinstance(module_dict[key], Module):
offset += module_dict[key].replace_param(
params, start_pos + offset, seen
)
return offset
def named_buffers(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Buffer]]:
......@@ -327,6 +337,32 @@ class Module(metaclass=ABCMeta):
self.apply(fn)
def replace_param(
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
):
"""Replace module's parameters with `params`, used by :class:`~.ParamPack` to
speedup multimachine training.
"""
offset = 0
if seen is None:
seen = set([id(self)])
module_dict = vars(self)
for key in sorted(module_dict):
hash_id = id(module_dict[key])
if hash_id in seen:
continue
seen.add(hash_id)
if isinstance(module_dict[key], Parameter):
if start_pos + offset in params:
assert module_dict[key].shape == params[start_pos + offset].shape
module_dict[key] = params[start_pos + offset]
offset += 1
if isinstance(module_dict[key], Module):
offset += module_dict[key].replace_param(
params, start_pos + offset, seen
)
return offset
def state_dict(self, rst=None, prefix="", keep_var=False):
r"""Returns a dictionary containing whole states of the module.
"""
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import weakref
class HookHandler:
hook_num = 0
def __init__(self, source_dict, hook):
self.id = HookHandler.hook_num
HookHandler.hook_num += 1
source_dict[self.id] = hook
self.source_ref = weakref.ref(source_dict)
def remove(self):
source_dict = self.source_ref()
if source_dict is not None and self.id in source_dict:
del source_dict[self.id]
......@@ -17,6 +17,7 @@ from helpers import MLP
import megengine as mge
import megengine._internal as mgb
import megengine.functional as F
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import (
BatchNorm1d,
......@@ -37,7 +38,7 @@ class MyModule(Module):
self.bn = BatchNorm2d(4)
def forward(self, x):
x = self.bn(x)
return self.bn(x)
def __init__(self):
super().__init__()
......@@ -145,6 +146,57 @@ def test_module_api_iterable_stability():
assert list(m.modules()) == l
def test_module_api_hooks():
net = MyModule()
pre_hook_num = 0
post_hook_num = 0
hooks = []
def pre_hook(module, inputs):
nonlocal pre_hook_num
pre_hook_num += 1
modified_inputs = tuple(inp + 1 for inp in inputs)
return modified_inputs
def post_hook(module, inputs, outputs):
nonlocal post_hook_num
post_hook_num += 1
outputs += 1
return outputs
net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook)))
net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook)))
shape = (1, 4, 1, 1)
x = tensor(np.zeros(shape, dtype=np.float32))
y = net(x)
assert pre_hook_num == 4
assert post_hook_num == 4
mean1 = Parameter(np.zeros(shape), dtype=np.float32)
bn1 = F.batch_norm2d(
x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
)
assertTensorClose(
net.i.bn.running_mean, mean1,
)
mean2 = Parameter(np.zeros(shape), dtype=np.float32)
bn2 = F.batch_norm2d(
bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
)
assertTensorClose(
net.bn.running_mean, mean2,
)
assertTensorClose(bn2 + 2, y)
assert len(hooks) == 8
for handler in hooks:
handler.remove()
y = net(x)
assert pre_hook_num == 4
assert post_hook_num == 4
class MyModule2(Module):
class InnerModule(Module):
def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册