#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.nn as nn


def is_listy(x):
    return isinstance(x, (tuple, list))


class Hook():
    "Create a hook on `m` with `hook_func`."

    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func, self.detach, self.stored = hook_func, detach, None
        f = m.register_forward_post_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module, input, output):
        "Applies `hook_func` to `module`, `input`, `output`."
        if self.detach:
            input = (o.detach()
                     for o in input) if is_listy(input) else input.detach()
            output = (o.detach()
                      for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the model."
        if not self.removed:
            self.hook.remove()
            self.removed = True

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()
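
# A minimal usage sketch for `Hook` (the `nn.Conv2D` layer and the identity
# `hook_func` below are illustrative, not part of this module):
#
#     conv = nn.Conv2D(3, 8, 3)
#     hook = Hook(conv, lambda m, i, o: o)   # stores the detached output
#     conv(paddle.rand([1, 3, 16, 16]))
#     print(hook.stored.shape)
#     hook.remove()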


class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."

    def __init__(self, ms, hook_func, is_forward=True, detach=True):
        self.hooks = []
        try:
            for m in ms:
                self.hooks.append(Hook(m, hook_func, is_forward, detach))
        except Exception:
            # Registration failed part-way; keep whatever hooks were attached
            # before the error and continue silently.
            pass

    def __getitem__(self, i: int) -> Hook:
        return self.hooks[i]

    def __len__(self) -> int:
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        return [o.stored for o in self]

    def remove(self):
        "Remove the hooks from the model."
        for h in self.hooks:
            h.remove()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()
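
# `Hooks` is usually used as a context manager so every hook is removed on
# exit. A sketch, assuming a small throwaway model (not part of this module):
#
#     net = nn.Sequential(nn.Conv2D(3, 8, 3), nn.ReLU())
#     with Hooks(net.sublayers(), lambda m, i, o: o.shape) as hooks:
#         net(paddle.rand([1, 3, 32, 32]))
#         print(hooks.stored)   # one activation shape per hooked sublayer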


def _hook_inner(m, i, o):
    # Store tensor and list-like outputs as-is; wrap anything else in a list.
    if isinstance(o, paddle.fluid.framework.Variable) or is_listy(o):
        return o
    return list(o)


def hook_output(module, detach=True, grad=False):
    "Return a `Hook` that stores activations of `module` in `self.stored`"
    return Hook(module, _hook_inner, detach=detach, is_forward=not grad)


def hook_outputs(modules, detach=True, grad=False):
    "Return `Hooks` that store activations of all `modules` in `self.stored`"
    return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
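
# Usage sketch for `hook_output` (`backbone` and `img` are placeholders for a
# real model and input, not names defined here):
#
#     feat_hook = hook_output(backbone.sublayers()[-1])
#     backbone(img)
#     features = feat_hook.stored   # detached activation of the hooked layer
#     feat_hook.remove()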


def model_sizes(m, size=(64, 64)):
    "Pass a dummy input through the model `m` to get the various sizes of activations."
    with hook_outputs(m) as hooks:
        x = dummy_eval(m, size)
        return [o.stored.shape for o in hooks]


def dummy_eval(m, size=(64, 64)):
    "Pass a `dummy_batch` in evaluation mode in `m` with `size`."
    m.eval()
    return m(dummy_batch(size))


def dummy_batch(size=(64, 64), ch_in=3):
    "Create a dummy batch to go through `m` with `size`."
    arr = np.random.rand(1, ch_in, *size).astype('float32') * 2 - 1
    return paddle.to_tensor(arr)
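

# A minimal self-check sketch, run only when this file is executed directly.
# The tiny Conv2D stack is an illustrative stand-in for a real backbone and is
# not part of the library; `model_sizes(demo_net)` would report the same
# shapes, assuming the model can be iterated to yield the layers to hook.
if __name__ == '__main__':
    demo_net = nn.Sequential(
        nn.Conv2D(3, 8, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2D(8, 16, 3, stride=2, padding=1))
    demo_net.eval()
    # Capture every sublayer's activation for a dummy 1x3x64x64 input.
    with hook_outputs(demo_net.sublayers()) as hooks:
        demo_net(dummy_batch(size=(64, 64)))
        print([o.stored.shape for o in hooks])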