theseus_layer.py 8.4 KB
Newer Older
1
from typing import List, Dict, Union, Callable, Any
W
weishengyu 已提交
2
from paddle import nn
3
from ppcls.utils import logger
W
weishengyu 已提交
4 5 6 7 8 9 10 11 12 13


class Identity(nn.Layer):
    """Pass-through layer: ``forward`` hands back exactly what it receives."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs`` unchanged."""
        return inputs


W
dbg  
weishengyu 已提交
14
class TheseusLayer(nn.Layer):
    """Layer base class whose sub-layers can be replaced, truncated or tapped by name.

    Sub-layers are addressed with a *pattern*: a dot-separated path of
    attribute names starting at this layer, where each item may carry an
    index into a container sub-layer, e.g. ``"blocks[11].depthwise_conv.conv"``.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: '*args' / '**kwargs' are accepted but not forwarded to nn.Layer.
        super(TheseusLayer, self).__init__()
        # Intermediate outputs captured by forward post-hooks, keyed by pattern.
        self.res_dict = {}
        # Key under which this layer stores its own output in 'res_dict'.
        self.res_name = self.full_name()
        # Not used inside this class; presumably assigned by external
        # pruning / quantization tooling -- TODO confirm.
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Build a dict holding the final output plus all captured sub-layer outputs.

        The signature matches a forward post-hook; registration is not done in
        this class (presumably done by subclasses -- verify against callers).

        Args:
            layer: The layer the hook fired on (unused).
            input: The inputs of the layer (unused).
            output: The output of the layer, stored under the key "output".

        Returns:
            dict: ``{"output": output, <pattern>: <captured output>, ...}``.
            Captured entries are removed from 'self.res_dict' as they are copied.
        """
        res_dict = {"output": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def _save_sub_res_hook(self, layer, input, output):
        """Store 'output' in 'self.res_dict' under 'self.res_name'."""
        self.res_dict[self.res_name] = output

    @staticmethod
    def _split_pattern_item(pattern_item):
        """Split one pattern item, e.g. 'blocks[11]' -> ('blocks', '11').

        The index is returned as the string found between the brackets, or
        'None' when the item has no '[' in it.
        """
        if '[' in pattern_item:
            pieces = pattern_item.split('[')
            return pieces[0], pieces[1].split(']')[0]
        return pattern_item, None

    @staticmethod
    def _is_valid_index(container, index):
        """Return True if 'index' is an integer string within [0, len(container))."""
        try:
            return 0 <= int(index) < len(container)
        except (TypeError, ValueError):
            # Non-numeric index, or 'container' does not support len().
            return False

    def _find_child(self, parent, pattern_item):
        """Resolve one pattern item below 'parent'.

        Returns:
            The addressed sub-layer, or 'None' if the attribute does not exist
            or the index is not valid for it.
        """
        name, index = self._split_pattern_item(pattern_item)
        child = getattr(parent, name, None)
        if child is None or index is None:
            return child
        if not self._is_valid_index(child, index):
            return None
        # The index is kept as the string parsed from the pattern.
        return child[index]

    def replace_sub(self,
                    layer_name_pattern: Union[str, List[str]],
                    handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
                        str, nn.Layer]:
        """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Patterns that cannot be resolved (unknown name, invalid index) are
        skipped with a warning and do not appear in the returned dict.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'.
                It receives the target sub-layer and its pattern, and its return value replaces the target.

        Returns:
            Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func'.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(sub_layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=sub_layer._in_channels,
                    out_channels=sub_layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # a dict mapping each resolved pattern to the layer returned by 'rep_func'
        """
        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        handle_res_dict = {}
        for pattern in layer_name_pattern:
            pattern_list = pattern.split(".")

            # find parent layer of sub-layer specified by pattern
            sub_layer_parent = self
            while len(pattern_list) > 1:
                sub_layer_parent = self._find_child(sub_layer_parent,
                                                    pattern_list[0])
                if sub_layer_parent is None:
                    break
                # Look through wrappers so the path continues in the real layer.
                if isinstance(sub_layer_parent, WrapLayer):
                    sub_layer_parent = sub_layer_parent.sub_layer
                pattern_list = pattern_list[1:]
            if sub_layer_parent is None:
                msg = f"Not found parent layer of sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
                logger.warning(msg)
                continue

            # find sub-layer specified by pattern
            sub_layer_name, sub_layer_index = self._split_pattern_item(
                pattern_list[0])
            sub_layer = getattr(sub_layer_parent, sub_layer_name, None)

            # 'is None' (not truthiness): an empty container still exists.
            if sub_layer is None:
                msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
                logger.warning(msg)
                continue

            if sub_layer_index is not None:
                if not self._is_valid_index(sub_layer, sub_layer_index):
                    msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
                    logger.warning(msg)
                    continue
                sub_layer = sub_layer[sub_layer_index]

            new_sub_layer = handle_func(sub_layer, pattern)

            if sub_layer_index is not None:
                getattr(sub_layer_parent,
                        sub_layer_name)[sub_layer_index] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            handle_res_dict[pattern] = new_sub_layer
        return handle_res_dict

    def _set_identity(self, layer, layer_name, layer_index=None):
        """Replace every sub-layer registered after the given one with 'Identity'.

        Args:
            layer: The parent layer whose '_sub_layers' are modified in place.
            layer_name: Name of the sub-layer to stop after; later siblings are replaced.
            layer_index: Optional index inside the container 'layer_name'; when given,
                the container's entries after that index are replaced as well.

        Returns:
            bool: 'True' if 'layer_name' (and 'layer_index', when given) was found.
        """
        stop_after = False
        # Iterate over a snapshot of the keys because entries are reassigned.
        for sub_layer_name in list(layer._sub_layers):
            if stop_after:
                layer._sub_layers[sub_layer_name] = Identity()
                continue
            if sub_layer_name == layer_name:
                stop_after = True

        if layer_index is not None and stop_after:
            stop_after = False
            container = layer._sub_layers[layer_name]
            # '_sub_layers' keys are strings, so compare against str(index).
            target_index = str(layer_index)
            for sub_layer_index in list(container._sub_layers):
                if stop_after:
                    container[sub_layer_index] = Identity()
                    continue
                if target_index == sub_layer_index:
                    stop_after = True

        return stop_after

    def stop_after(self, stop_layer_name: str) -> bool:
        """stop forward and backward after 'stop_layer_name'.

        Every layer that is registered after the target, on each level of the
        path, is replaced by 'Identity'.

        Args:
            stop_layer_name (str): The name of layer that stop forward and backward after this layer.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """
        pattern_list = stop_layer_name.split(".")
        to_identity_list = []

        layer = self
        while len(pattern_list) > 0:
            layer_parent = layer
            sub_layer_name, sub_layer_index = self._split_pattern_item(
                pattern_list[0])
            # Unknown names and invalid indexes both resolve to None here.
            layer = self._find_child(layer_parent, pattern_list[0])
            if layer is None:
                msg = f"Not found layer by name({pattern_list[0]}) specifed in stop_layer_name({stop_layer_name})."
                logger.warning(msg)
                return False

            to_identity_list.append(
                (layer_parent, sub_layer_name, sub_layer_index))
            # Same wrapper handling as 'replace_sub': continue inside the wrapped layer.
            if isinstance(layer, WrapLayer):
                layer = layer.sub_layer
            pattern_list = pattern_list[1:]

        for to_identity_layer in to_identity_list:
            if not self._set_identity(*to_identity_layer):
                msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
                logger.warning(msg)
                return False
        return True

    def update_res(self,
                   return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]:
        """update the results to be returned.

        Each layer matched by 'return_patterns' gets a forward post-hook that
        stores its output in this layer's 'res_dict' under the pattern.

        Args:
            return_patterns (Union[str, List[str]]): The name of layer to return output.

        Returns:
            Dict[str, nn.Layer]: Maps every pattern that was set successfully to its layer.
            Patterns that could not be resolved are absent.
        """
        shared_res_dict = self.res_dict

        def save_sub_res_hook(layer, input, output):
            # Fallback for targets that are not TheseusLayer and therefore
            # have no '_save_sub_res_hook' method of their own.
            layer.res_dict[layer.res_name] = output

        def handle_func(layer, pattern):
            # Share this layer's dict so the outputs end up in one place.
            layer.res_dict = shared_res_dict
            layer.res_name = pattern
            hook = getattr(layer, "_save_sub_res_hook", None)
            if hook is None:
                hook = save_sub_res_hook
            layer.register_forward_post_hook(hook)
            return layer

        return self.replace_sub(return_patterns, handle_func=handle_func)
W
weishengyu 已提交
210 211 212


class WrapLayer(TheseusLayer):
    """TheseusLayer subclass that holds another layer and delegates ``forward`` to it."""

    def __init__(self, sub_layer):
        super().__init__()
        # The wrapped layer; all forward calls are passed straight through to it.
        self.sub_layer = sub_layer

    def forward(self, *inputs, **kwargs):
        """Call the wrapped layer with the given arguments and return its result."""
        wrapped = self.sub_layer
        return wrapped(*inputs, **kwargs)
W
weishengyu 已提交
219

220 221 222

def wrap_theseus(sub_layer):
    """Return a new ``WrapLayer`` that wraps ``sub_layer``."""
    return WrapLayer(sub_layer)