# hook.py (last committed by LielinJiang)
import warnings

import numpy as np

import paddle
import paddle.nn as nn

def is_listy(x):
    """Return True when `x` is a tuple or a list (subclasses included)."""
    listy_types = (list, tuple)
    return isinstance(x, listy_types)


class Hook():
    """Create a hook on `m` with `hook_func` and keep its latest result.

    Args:
        m: the layer to hook. It must provide ``register_forward_post_hook``,
            or ``register_backward_hook`` when ``is_forward`` is False.
        hook_func: callable ``(module, input, output) -> value``. Its return
            value is written to ``self.stored`` every time the hook fires.
        is_forward: register through ``register_forward_post_hook`` when
            True, otherwise through ``register_backward_hook``.
        detach: when True, ``input`` and ``output`` are detached before being
            handed to ``hook_func``.

    The instance is a context manager; the hook is removed on exit.
    """

    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func, self.detach, self.stored = hook_func, detach, None
        # NOTE(review): paddle.nn.Layer may not expose
        # `register_backward_hook`; confirm before relying on
        # is_forward=False.
        f = m.register_forward_post_hook if is_forward else m.register_backward_hook
        # Keep the handle returned by the registration so `remove` can
        # detach the hook later.
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module, input, output):
        "Applies `hook_func` to `module`, `input`, `output`."
        if self.detach:
            # Materialize the detached tensors in lists instead of generator
            # expressions. A generator can be iterated only once, so a
            # `hook_func` that walked `input`/`output` twice, or stored it to
            # read later, would silently get an empty sequence.
            input = [o.detach()
                     for o in input] if is_listy(input) else input.detach()
            output = [o.detach()
                      for o in output] if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)
        # Nothing is returned on purpose. NOTE(review): paddle post-hooks
        # may treat a non-None return as a replacement layer output.

    def remove(self):
        "Remove the hook from the model."
        # Idempotent: the underlying handle is removed at most once.
        if not self.removed:
            self.hook.remove()
            self.removed = True

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()


class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."
L
LielinJiang 已提交
44

L
LielinJiang 已提交
45 46 47 48 49 50
    def __init__(self, ms, hook_func, is_forward=True, detach=True):
        self.hooks = []
        try:
            for m in ms:
                self.hooks.append(Hook(m, hook_func, is_forward, detach))
        except Exception as e:
L
LielinJiang 已提交
51 52 53 54 55 56 57 58 59 60
            pass

    def __getitem__(self, i: int) -> Hook:
        return self.hooks[i]

    def __len__(self) -> int:
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)
L
LielinJiang 已提交
61 62

    @property
L
LielinJiang 已提交
63 64
    def stored(self):
        return [o.stored for o in self]
L
LielinJiang 已提交
65 66 67

    def remove(self):
        "Remove the hooks from the model."
L
LielinJiang 已提交
68 69
        for h in self.hooks:
            h.remove()
L
LielinJiang 已提交
70

L
LielinJiang 已提交
71 72
    def __enter__(self, *args):
        return self
L
LielinJiang 已提交
73

L
LielinJiang 已提交
74 75
    def __exit__(self, *args):
        self.remove()
L
LielinJiang 已提交
76

L
LielinJiang 已提交
77 78 79 80 81 82 83

def _hook_inner(m, i, o):
    """Default `hook_func`: return the layer output `o` in a storable form.

    A tensor (`paddle.framework.Variable`) or a tuple/list is returned
    unchanged. Any other iterable is materialized into a list. The `m`
    (module) and `i` (input) arguments are ignored.
    """
    if isinstance(o, paddle.framework.Variable) or is_listy(o):
        return o
    return list(o)


def hook_output(module, detach=True, grad=False):
    """Return a `Hook` that stores activations of `module` in `self.stored`.

    Args:
        module: the layer to hook.
        detach: passed through to `Hook`.
        grad: when True the hook is registered as a backward hook instead
            of a forward post-hook.
    """
    use_forward = not grad
    return Hook(module, _hook_inner, is_forward=use_forward, detach=detach)


def hook_outputs(modules, detach=True, grad=False):
    """Return `Hooks` that store activations of all `modules` in `self.stored`.

    Args:
        modules: iterable of layers to hook.
        detach: passed through to every `Hook`.
        grad: when True the hooks are registered as backward hooks instead
            of forward post-hooks.
    """
    use_forward = not grad
    return Hooks(modules, _hook_inner, is_forward=use_forward, detach=detach)


def model_sizes(m, size=(64, 64)):
    """Pass a dummy input through the model `m` to get the various sizes of activations.

    Args:
        m: callable model whose elements (iterated by `hook_outputs`) are
            the layers to hook.
        size: spatial size of the dummy batch, forwarded to `dummy_eval`.

    Returns:
        list holding the `.shape` of each hooked layer's stored activation,
        in iteration order.

    Side effect: `dummy_eval` calls `m.eval()`, so `m` is left in eval mode.
    The hooks are removed before returning.
    """
    with hook_outputs(m) as hooks:
        # The forward pass is run only so the hooks fire; its return value
        # is not needed.
        dummy_eval(m, size)
        return [o.stored.shape for o in hooks]


def dummy_eval(m, size=(64, 64)):
    """Run `m` in eval mode on a random batch from `dummy_batch(size)`.

    Switches `m` to eval mode (it is not switched back) and returns the
    result of calling `m` on the batch.
    """
    m.eval()
    batch = dummy_batch(size)
    return m(batch)


def dummy_batch(size=(64, 64), ch_in=3):
    """Create a random dummy batch with `ch_in` channels and spatial `size`.

    Returns a paddle tensor of shape ``(1, ch_in, *size)``. Its float32
    values come from ``np.random.rand`` rescaled to the range [-1, 1].
    """
    shape = (1, ch_in) + tuple(size)
    # Cast before rescaling so the arithmetic happens in float32.
    noise = np.random.rand(*shape).astype('float32')
    return paddle.to_tensor(noise * 2 - 1)