from abc import ABC
from paddle import nn
import re


class Identity(nn.Layer):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, inputs):
        return inputs


class TheseusLayer(nn.Layer):
    def __init__(self, *args, return_patterns=None, **kwargs):
        super(TheseusLayer, self).__init__()
        self.res_dict = None
        # clear res_dict automatically after each forward pass of this layer
        self.register_forward_post_hook(self._disconnect_res_dict_hook)
        if return_patterns is not None:
            self._update_res(return_patterns)

    def forward(self, *input, res_dict=None, **kwargs):
        # Connect the caller-provided dict so that hooks registered by
        # _update_res can write intermediate outputs into it; subclasses are
        # expected to call super().forward(...) before running their sub-layers.
        if res_dict is not None:
            self.res_dict = res_dict
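
    # A minimal sketch (assumed names, not part of the original file) of a
    # subclass that threads res_dict through forward; note that _update_res is
    # called only after the sub-layers exist, since it registers hooks on them
    # by full_name().
    '''
    class HypoNet(TheseusLayer):
        def __init__(self, return_patterns=None):
            super().__init__()
            self.stem = nn.Conv2D(3, 16, 3)
            self.head = nn.Linear(16, 10)
            if return_patterns is not None:
                self._update_res(return_patterns)

        def forward(self, x, res_dict=None):
            super().forward(x, res_dict=res_dict)  # connect res_dict
            x = self.stem(x)
            x = x.mean(axis=[2, 3])  # global average pooling
            return self.head(x)
    '''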

    # NOTE: stop_after does not work when the stop layer has a parallel branch.
    def stop_after(self, stop_layer_name: str):
        after_stop = False
        for layer_i in self._sub_layers:
            if after_stop:
                self._sub_layers[layer_i] = Identity()
                continue
            layer_name = self._sub_layers[layer_i].full_name()
            if layer_name == stop_layer_name:
                after_stop = True
                continue
            if isinstance(self._sub_layers[layer_i], TheseusLayer):
                after_stop = self._sub_layers[layer_i].stop_after(stop_layer_name)
        return after_stop
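
    # A minimal usage sketch; "net" stands for any TheseusLayer-based model and
    # "conv2d_5" is an assumed full_name(). Every sub-layer after the matched
    # one is replaced with Identity, so it passes its input through unchanged.
    '''
    found = net.stop_after("conv2d_5")  # returns False if no sub-layer matched
    '''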

    def _update_res(self, return_layers):
        for layer_i in self._sub_layers:
            layer_name = self._sub_layers[layer_i].full_name()
            for return_pattern in return_layers:
                if re.match(return_pattern, layer_name):
                    self._sub_layers[layer_i].register_forward_post_hook(
                        self._save_sub_res_hook)

    def _save_sub_res_hook(self, layer, input, output):
        self.res_dict[layer.full_name()] = output

    def _disconnect_res_dict_hook(self, layer, input, output):
        # Forward post hooks receive (layer, input, output); detach the
        # externally provided res_dict once this layer's forward has finished.
        self.res_dict = None
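
    # A minimal end-to-end sketch (assumed names, pattern, and shapes):
    # "HypoNet" is the hypothetical subclass sketched above. Passing a dict as
    # res_dict lets _save_sub_res_hook fill it with the output of every
    # sub-layer whose full_name() matches a return pattern.
    '''
    import paddle

    net = HypoNet(return_patterns=["conv2d_.*"])
    feats = {}
    logits = net(paddle.rand([1, 3, 32, 32]), res_dict=feats)
    # feats now maps each matched sub-layer's full_name() to its output
    '''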

    def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
        for layer_i in self._sub_layers:
            layer_name = self._sub_layers[layer_i].full_name()
            if re.match(layer_name_pattern, layer_name):
                self._sub_layers[layer_i] = replace_function(self._sub_layers[layer_i])
            if recursive and isinstance(self._sub_layers[layer_i], TheseusLayer):
                self._sub_layers[layer_i].replace_sub(layer_name_pattern, replace_function, recursive)

    '''
    example of a replace function for replace_sub:

    def replace_conv(origin_conv: nn.Conv2D):
        new_conv = nn.Conv2D(
            in_channels=origin_conv._in_channels,
            out_channels=origin_conv._out_channels,
            kernel_size=origin_conv._kernel_size,
            stride=2)
        return new_conv
    '''
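
    # A minimal usage sketch: "net" and the pattern "conv2d_.*" are
    # assumptions; the pattern is matched with re.match against each
    # sub-layer's full_name(), at every nesting level when recursive=True.
    '''
    net.replace_sub("conv2d_.*", replace_conv, recursive=True)
    '''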