import re
from abc import ABC

from paddle import nn


class Identity(nn.Layer):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, inputs):
        return inputs


class TheseusLayer(nn.Layer):
    def __init__(self, *args, **kwargs):
        super(TheseusLayer, self).__init__()
        self.res_dict = {}

    # NOTE: stop_after does not work when the stop layer has a parallel
    # branch.
    def stop_after(self, stop_layer_name: str):
        """Replace every sub-layer registered after `stop_layer_name` with
        Identity, recursing into nested TheseusLayers. Returns True if the
        stop layer was found, False otherwise.
        """
        after_stop = False
        for layer_i in self._sub_layers:
            if after_stop:
                self._sub_layers[layer_i] = Identity()
                continue
            layer_name = self._sub_layers[layer_i].full_name()
            if layer_name == stop_layer_name:
                after_stop = True
                continue
            if isinstance(self._sub_layers[layer_i], TheseusLayer):
                after_stop = self._sub_layers[layer_i].stop_after(
                    stop_layer_name)
        return after_stop
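
    # Rough usage sketch for stop_after (the model class and the layer name
    # "sequential_0" are hypothetical placeholders; real names are whatever
    # full_name() reports for the sub-layers):
    #
    #     net = SomeTheseusBackbone()              # any TheseusLayer subclass
    #     found = net.stop_after("sequential_0")   # later sub-layers -> Identity
    #     assert found, "stop layer not found"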

    def _update_res(self, return_patterns):
        """Register a forward post-hook on every sub-layer whose full_name()
        matches one of `return_patterns`, so its output is written into
        self.res_dict during the forward pass.
        """
        if not return_patterns:
            return
        for layer_i in self._sub_layers:
            if isinstance(self._sub_layers[layer_i],
                          (nn.Sequential, nn.LayerList)):
                self._sub_layers[layer_i] = wrap_theseus(
                    self._sub_layers[layer_i], return_patterns)
            layer_name = self._sub_layers[layer_i].full_name()
            for return_pattern in return_patterns:
                if re.match(return_pattern, layer_name):
                    if not isinstance(self._sub_layers[layer_i], TheseusLayer):
                        self._sub_layers[layer_i] = wrap_theseus(
                            self._sub_layers[layer_i], return_patterns)
                    self._sub_layers[layer_i].register_forward_post_hook(
                        self._sub_layers[layer_i]._save_sub_res_hook)
                    self._sub_layers[layer_i].res_dict = self.res_dict
            if isinstance(self._sub_layers[layer_i], TheseusLayer):
                self._sub_layers[layer_i]._update_res(return_patterns)

    def _save_sub_res_hook(self, layer, input, output):
        """Forward post-hook that stores `layer`'s output in res_dict."""
        if self.res_dict is not None:
            self.res_dict[layer.full_name()] = output
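
    # Rough sketch of the res_dict mechanism in isolation (the model, its
    # "stem" attribute, and the printed key are hypothetical); _update_res
    # automates a version of this for every sub-layer whose full_name()
    # matches one of return_patterns:
    #
    #     net = SomeTheseusBackbone()
    #     net.stem.register_forward_post_hook(net._save_sub_res_hook)
    #     y = net(x)                   # forward pass fills net.res_dict
    #     print(net.res_dict)          # {'conv2d_0': <output of net.stem>}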

    def replace_sub(self, layer_name_pattern, replace_function,
                    recursive=True):
        """Replace every sub-layer whose full_name() matches
        `layer_name_pattern` with the result of `replace_function(sub_layer)`.
        See the example replace function below the method body.
        """
        for layer_i in self._sub_layers:
            layer_name = self._sub_layers[layer_i].full_name()
            if re.match(layer_name_pattern, layer_name):
                self._sub_layers[layer_i] = replace_function(
                    self._sub_layers[layer_i])
            if recursive:
                if isinstance(self._sub_layers[layer_i], TheseusLayer):
                    self._sub_layers[layer_i].replace_sub(
                        layer_name_pattern, replace_function, recursive)
                elif isinstance(self._sub_layers[layer_i],
                                (nn.Sequential, nn.LayerList)):
                    for layer_j in self._sub_layers[layer_i]._sub_layers:
                        self._sub_layers[layer_i]._sub_layers[layer_j].replace_sub(
                            layer_name_pattern, replace_function, recursive)

    '''
    example of replace function:
    def replace_conv(origin_conv: nn.Conv2D):
        new_conv = nn.Conv2D(
            in_channels=origin_conv._in_channels,
            out_channels=origin_conv._out_channels,
            kernel_size=origin_conv._kernel_size,
            stride=2
        )
        return new_conv

    '''
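
    # Rough usage sketch for replace_sub together with the replace_conv
    # example above (the pattern "conv2d_0" is a hypothetical placeholder;
    # it is matched with re.match against each sub-layer's full_name()):
    #
    #     net = SomeTheseusBackbone()
    #     net.replace_sub("conv2d_0", replace_conv, recursive=True)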


class WrapLayer(TheseusLayer):
    """Thin TheseusLayer wrapper exposing the wrapped layer's full_name()."""

    def __init__(self, sub_layer):
        super(WrapLayer, self).__init__()
        self.sub_layer = sub_layer
        self.name = sub_layer.full_name()

    def full_name(self):
        return self.name

    def forward(self, *inputs, **kwargs):
        return self.sub_layer(*inputs, **kwargs)


def wrap_theseus(sub_layer, return_patterns):
    """Wrap `sub_layer` in a WrapLayer (a TheseusLayer); children of a
    Sequential/LayerList that are containers themselves or whose full_name()
    matches one of `return_patterns` are wrapped first.
    """
    if isinstance(sub_layer, (nn.Sequential, nn.LayerList)):
        for layer_i in sub_layer._sub_layers:
            if isinstance(sub_layer._sub_layers[layer_i], TheseusLayer):
                continue
            elif isinstance(sub_layer._sub_layers[layer_i],
                            (nn.Sequential, nn.LayerList)):
                # Assign the wrapped result back; otherwise wrapping is lost.
                sub_layer._sub_layers[layer_i] = wrap_theseus(
                    sub_layer._sub_layers[layer_i], return_patterns)
            elif isinstance(sub_layer._sub_layers[layer_i], nn.Layer):
                layer_name = sub_layer._sub_layers[layer_i].full_name()
                for return_pattern in return_patterns:
                    if re.match(return_pattern, layer_name):
                        sub_layer._sub_layers[layer_i] = wrap_theseus(
                            sub_layer._sub_layers[layer_i], return_patterns)
                        break  # avoid wrapping the same layer twice
    wrapped_layer = WrapLayer(sub_layer)
    return wrapped_layer
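

# Minimal end-to-end sketch (TinyNet and its layer sizes are illustrative
# assumptions, not part of this module). wrap_theseus turns a plain container
# into a TheseusLayer so hooks can be registered on it, while full_name()
# still reports the original container's name:
#
#     class TinyNet(TheseusLayer):
#         def __init__(self):
#             super(TinyNet, self).__init__()
#             self.stem = nn.Conv2D(3, 8, 3)
#             self.blocks = nn.Sequential(nn.Conv2D(8, 8, 3), nn.ReLU())
#
#         def forward(self, x):
#             return self.blocks(self.stem(x))
#
#     net = TinyNet()
#     wrapped = wrap_theseus(net.blocks, return_patterns=["conv2d.*"])
#     assert isinstance(wrapped, TheseusLayer)
#     assert wrapped.full_name() == net.blocks.full_name()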