from abc import ABC
from paddle import nn
import re


class Identity(nn.Layer):
    """A no-op layer: ``forward`` returns its input unchanged.

    Used by ``TheseusLayer.stop_after`` to neutralize sub-layers that
    come after the designated stop layer.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, inputs):
        # Pass-through: no parameters, no transformation.
        return inputs


class TheseusLayer(nn.Layer):
    """Base layer adding model-surgery utilities on top of ``nn.Layer``.

    Subclasses gain the ability to:
      * truncate the network after a named sub-layer (``stop_after``),
      * record intermediate outputs selected by dotted patterns
        (``update_res`` + ``_save_sub_res_hook``),
      * replace sub-layers matching a regex in place (``replace_sub``).

    Intermediate results are accumulated in ``self.res_dict`` under the
    key ``self.res_name``.
    """

    def __init__(self, *args, **kwargs):
        super(TheseusLayer, self).__init__()
        # Shared dict that forward-post hooks write intermediate outputs into.
        self.res_dict = {}
        # Key under which this layer's output is stored in res_dict.
        self.res_name = self.full_name()

    # stop doesn't work when stop layer has a parallel branch.
    def stop_after(self, stop_layer_name: str):
        """Replace every sub-layer after ``stop_layer_name`` with Identity.

        Args:
            stop_layer_name: the ``full_name()`` of the layer to stop after.

        Returns:
            bool: True once the stop layer has been found (possibly inside
            a nested TheseusLayer), False otherwise.
        """
        after_stop = False
        for layer_i in self._sub_layers:
            if after_stop:
                # Everything past the stop layer becomes a no-op.
                self._sub_layers[layer_i] = Identity()
                continue
            layer_name = self._sub_layers[layer_i].full_name()
            if layer_name == stop_layer_name:
                after_stop = True
                continue
            if isinstance(self._sub_layers[layer_i], TheseusLayer):
                # Recurse; the child reports whether it contained the stop
                # layer, which flips truncation on at this level too.
                after_stop = self._sub_layers[layer_i].stop_after(
                    stop_layer_name)
        return after_stop

    def update_res(self, return_patterns):
        """Attach hooks that record outputs of the layers named by
        ``return_patterns`` into ``self.res_dict``.

        Each pattern is a dotted path such as ``"blocks[3].conv"``, where a
        ``name[index]`` component indexes into a container attribute.
        Layers that are not TheseusLayer instances are wrapped via
        ``wrap_theseus`` so hooks can be attached.

        Args:
            return_patterns: iterable of dotted-path strings.
        """
        for return_pattern in return_patterns:
            pattern_list = return_pattern.split(".")
            if not pattern_list:
                continue
            sub_layer_parent = self
            # Walk down to the parent of the final path component.
            while len(pattern_list) > 1:
                if '[' in pattern_list[0]:
                    sub_layer_name = pattern_list[0].split('[')[0]
                    # NOTE(review): the index stays a string; paddle's
                    # Sequential/LayerList __getitem__ accepts string keys —
                    # confirm for other container types.
                    sub_layer_index = pattern_list[0].split('[')[1].split(']')[
                        0]
                    sub_layer_parent = getattr(sub_layer_parent,
                                               sub_layer_name)[sub_layer_index]
                else:
                    sub_layer_parent = getattr(sub_layer_parent,
                                               pattern_list[0], None)
                    if sub_layer_parent is None:
                        # Path does not exist on this model; skip pattern.
                        break
                if isinstance(sub_layer_parent, WrapLayer):
                    # Descend through wrappers to the real layer.
                    sub_layer_parent = sub_layer_parent.sub_layer
                pattern_list = pattern_list[1:]
            if sub_layer_parent is None:
                continue
            # Resolve the final component, wrapping it if necessary so the
            # forward-post hook below can be registered.
            if '[' in pattern_list[0]:
                sub_layer_name = pattern_list[0].split('[')[0]
                sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
                sub_layer = getattr(sub_layer_parent,
                                    sub_layer_name)[sub_layer_index]
                if not isinstance(sub_layer, TheseusLayer):
                    sub_layer = wrap_theseus(sub_layer)
                getattr(sub_layer_parent,
                        sub_layer_name)[sub_layer_index] = sub_layer
            else:
                sub_layer = getattr(sub_layer_parent, pattern_list[0])
                if not isinstance(sub_layer, TheseusLayer):
                    sub_layer = wrap_theseus(sub_layer)
                setattr(sub_layer_parent, pattern_list[0], sub_layer)

            # Share this model's res_dict so all hooks write to one place;
            # the layer's result is keyed by its original pattern string.
            sub_layer.res_dict = self.res_dict
            sub_layer.res_name = return_pattern
            sub_layer.register_forward_post_hook(sub_layer._save_sub_res_hook)

    def _save_sub_res_hook(self, layer, input, output):
        """Forward-post hook: store this layer's output under its res_name."""
        self.res_dict[self.res_name] = output

    def _return_dict_hook(self, layer, input, output):
        """Forward-post hook: bundle the final output together with all
        recorded intermediate results into one dict.

        Keys are drained (popped) from ``res_dict`` so state does not leak
        into the next forward pass.
        """
        res_dict = {"output": output}
        for res_key in list(self.res_dict):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def replace_sub(self, layer_name_pattern, replace_function,
                    recursive=True):
        """Replace every sub-layer whose ``full_name()`` matches
        ``layer_name_pattern`` with ``replace_function(layer)``.

        Args:
            layer_name_pattern: regex matched (via ``re.match``) against
                each sub-layer's full name.
            replace_function: callable taking the old layer and returning
                its replacement.
            recursive: if True, also descend into nested TheseusLayer,
                Sequential and LayerList children.

        example of replace function:

            def replace_conv(origin_conv: nn.Conv2D):
                new_conv = nn.Conv2D(
                    in_channels=origin_conv._in_channels,
                    out_channels=origin_conv._out_channels,
                    kernel_size=origin_conv._kernel_size,
                    stride=2
                )
                return new_conv
        """
        for layer_i in self._sub_layers:
            layer_name = self._sub_layers[layer_i].full_name()
            if re.match(layer_name_pattern, layer_name):
                self._sub_layers[layer_i] = replace_function(self._sub_layers[
                    layer_i])
            if recursive:
                if isinstance(self._sub_layers[layer_i], TheseusLayer):
                    self._sub_layers[layer_i].replace_sub(
                        layer_name_pattern, replace_function, recursive)
                elif isinstance(self._sub_layers[layer_i],
                                (nn.Sequential, nn.LayerList)):
                    # Containers are not TheseusLayer themselves; recurse
                    # into each of their children instead.
                    for layer_j in self._sub_layers[layer_i]._sub_layers:
                        self._sub_layers[layer_i]._sub_layers[
                            layer_j].replace_sub(layer_name_pattern,
                                                 replace_function, recursive)


class WrapLayer(TheseusLayer):
    """Adapter that turns an arbitrary ``nn.Layer`` into a TheseusLayer.

    Used by ``TheseusLayer.update_res`` (via ``wrap_theseus``) so that
    result-saving hooks can be attached to plain layers. Forward simply
    delegates to the wrapped layer.
    """

    def __init__(self, sub_layer):
        super(WrapLayer, self).__init__()
        # The original layer being adapted.
        self.sub_layer = sub_layer

    def forward(self, *inputs, **kwargs):
        # Pure delegation: same inputs, same outputs as the wrapped layer.
        return self.sub_layer(*inputs, **kwargs)


def wrap_theseus(sub_layer):
    """Wrap ``sub_layer`` in a WrapLayer so it gains TheseusLayer hooks.

    Args:
        sub_layer: any ``nn.Layer`` instance.

    Returns:
        WrapLayer: a TheseusLayer-compatible adapter around ``sub_layer``.
    """
    wrapped_layer = WrapLayer(sub_layer)
    return wrapped_layer