From 13e8f00a377ad7efa821310e7f51784ccced6d95 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 13 Aug 2020 14:11:59 +0800 Subject: [PATCH] feat(mge/module): add forward hook support GitOrigin-RevId: c0db58df13ce12ee293026aad30b2de93e9c6f80 --- python_module/megengine/module/module.py | 90 +++++++++++++------ python_module/megengine/utils/hook.py | 23 +++++ python_module/test/unit/module/test_module.py | 54 ++++++++++- 3 files changed, 139 insertions(+), 28 deletions(-) create mode 100644 python_module/megengine/utils/hook.py diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 66aac1e94..ecd8b1851 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -14,6 +14,7 @@ import numpy as np from .._internal.dtype import is_quantize from ..core import Buffer, Parameter, Tensor from ..logger import get_logger +from ..utils.hook import HookHandler logger = get_logger(__name__) @@ -57,19 +58,51 @@ class Module(metaclass=ABCMeta): """ def __init__(self): + # runtime attributes self.training = True self.quantize_diabled = False + # hooks + self._forward_pre_hooks = OrderedDict() + self._forward_hooks = OrderedDict() + @abstractmethod def forward(self, inputs): pass + def register_forward_pre_hook(self, hook: Callable) -> HookHandler: + """Register a hook to handle forward inputs. `hook` should be a function. + + Note that `inputs` does not include keyword inputs. + + :param hook: a function that receives `module` and `inputs`, then returns + a modified `inputs` or `None`. + :return: a handler with :meth:`~.HookHandler.remove` interface to delete the hook. + """ + return HookHandler(self._forward_pre_hooks, hook) + + def register_forward_hook(self, hook: Callable) -> HookHandler: + """Register a hook to handle forward results. `hook` should be a function that + receives `module`, `inputs` and `outputs`, then returns a modified `outputs` or `None`. 
+ + This method returns a handler with :meth:`~.HookHandler.remove` interface to delete the hook. + """ + return HookHandler(self._forward_hooks, hook) + def __call__(self, *inputs, **kwargs): - # ToDo: Convert numpy or scalar - # Maybe ToDo: set training phase - # Maybe ToDo: set computing graph + for hook in self._forward_pre_hooks.values(): + modified_inputs = hook(self, inputs) + if modified_inputs is not None: + if not isinstance(modified_inputs, tuple): + modified_inputs = (modified_inputs,) + inputs = modified_inputs + outputs = self.forward(*inputs, **kwargs) - # Maybe ToDo: set connectivity metadata + + for hook in self._forward_hooks.values(): + modified_outputs = hook(self, inputs, outputs) + if modified_outputs is not None: + outputs = modified_outputs return outputs def _flatten( @@ -191,29 +224,6 @@ class Module(metaclass=ABCMeta): with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs ) - def replace_param( - self, params: dict, start_pos: int, seen: Optional[Set[int]] = None - ): - offset = 0 - if seen is None: - seen = set([id(self)]) - module_dict = vars(self) - for key in sorted(module_dict): - hash_id = id(module_dict[key]) - if hash_id in seen: - continue - seen.add(hash_id) - if isinstance(module_dict[key], Parameter): - if start_pos + offset in params: - assert module_dict[key].shape == params[start_pos + offset].shape - module_dict[key] = params[start_pos + offset] - offset += 1 - if isinstance(module_dict[key], Module): - offset += module_dict[key].replace_param( - params, start_pos + offset, seen - ) - return offset - def named_buffers( self, prefix: Optional[str] = None, recursive: bool = True, **kwargs ) -> Iterable[Tuple[str, Buffer]]: @@ -327,6 +337,32 @@ class Module(metaclass=ABCMeta): self.apply(fn) + def replace_param( + self, params: dict, start_pos: int, seen: Optional[Set[int]] = None + ): + """Replace module's parameters with `params`, used by :class:`~.ParamPack` to + speed up multi-machine training. 
+ """ + offset = 0 + if seen is None: + seen = set([id(self)]) + module_dict = vars(self) + for key in sorted(module_dict): + hash_id = id(module_dict[key]) + if hash_id in seen: + continue + seen.add(hash_id) + if isinstance(module_dict[key], Parameter): + if start_pos + offset in params: + assert module_dict[key].shape == params[start_pos + offset].shape + module_dict[key] = params[start_pos + offset] + offset += 1 + if isinstance(module_dict[key], Module): + offset += module_dict[key].replace_param( + params, start_pos + offset, seen + ) + return offset + def state_dict(self, rst=None, prefix="", keep_var=False): r"""Returns a dictionary containing whole states of the module. """ diff --git a/python_module/megengine/utils/hook.py b/python_module/megengine/utils/hook.py new file mode 100644 index 000000000..9864a94a1 --- /dev/null +++ b/python_module/megengine/utils/hook.py @@ -0,0 +1,23 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+import weakref + + +class HookHandler: + hook_num = 0 + + def __init__(self, source_dict, hook): + self.id = HookHandler.hook_num + HookHandler.hook_num += 1 + source_dict[self.id] = hook + self.source_ref = weakref.ref(source_dict) + + def remove(self): + source_dict = self.source_ref() + if source_dict is not None and self.id in source_dict: + del source_dict[self.id] diff --git a/python_module/test/unit/module/test_module.py b/python_module/test/unit/module/test_module.py index 2181407ac..0766f6eee 100644 --- a/python_module/test/unit/module/test_module.py +++ b/python_module/test/unit/module/test_module.py @@ -17,6 +17,7 @@ from helpers import MLP import megengine as mge import megengine._internal as mgb +import megengine.functional as F from megengine.core import Buffer, Parameter, Tensor, tensor from megengine.module import ( BatchNorm1d, @@ -37,7 +38,7 @@ class MyModule(Module): self.bn = BatchNorm2d(4) def forward(self, x): - x = self.bn(x) + return self.bn(x) def __init__(self): super().__init__() @@ -145,6 +146,57 @@ def test_module_api_iterable_stability(): assert list(m.modules()) == l +def test_module_api_hooks(): + net = MyModule() + pre_hook_num = 0 + post_hook_num = 0 + hooks = [] + + def pre_hook(module, inputs): + nonlocal pre_hook_num + pre_hook_num += 1 + modified_inputs = tuple(inp + 1 for inp in inputs) + return modified_inputs + + def post_hook(module, inputs, outputs): + nonlocal post_hook_num + post_hook_num += 1 + outputs += 1 + return outputs + + net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook))) + net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook))) + + shape = (1, 4, 1, 1) + x = tensor(np.zeros(shape, dtype=np.float32)) + y = net(x) + + assert pre_hook_num == 4 + assert post_hook_num == 4 + mean1 = Parameter(np.zeros(shape), dtype=np.float32) + bn1 = F.batch_norm2d( + x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True + ) + assertTensorClose( + 
net.i.bn.running_mean, mean1, + ) + mean2 = Parameter(np.zeros(shape), dtype=np.float32) + bn2 = F.batch_norm2d( + bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True + ) + assertTensorClose( + net.bn.running_mean, mean2, + ) + assertTensorClose(bn2 + 2, y) + + assert len(hooks) == 8 + for handler in hooks: + handler.remove() + y = net(x) + assert pre_hook_num == 4 + assert post_hook_num == 4 + + class MyModule2(Module): class InnerModule(Module): def __init__(self): -- GitLab