from typing import Any, Callable, Dict, List, Union

from paddle import nn

from ppcls.utils import logger


class Identity(nn.Layer):
    """Pass-through layer whose ``forward`` hands its input back unchanged.

    ``TheseusLayer._set_identity`` installs it in place of layers that should
    no longer do any work.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # No computation: the very same object that came in goes out.
        return inputs


class TheseusLayer(nn.Layer):
    """Base layer that supports pattern-based surgery on its sub-layers.

    Sub-layers are addressed with dotted patterns such as
    ``"blocks[11].depthwise_conv.conv"``. Each dot-separated component is an
    attribute name, optionally followed by ``[index]`` to pick an item out of
    a container layer; the index is passed on as the string written in the
    pattern. Matched sub-layers that are not already ``TheseusLayer``
    instances are wrapped in a ``WrapLayer`` in place, so that hooks and
    replacements can be attached to them.
    """

    def __init__(self, *args, **kwargs):
        super(TheseusLayer, self).__init__()
        # Store for intermediate outputs captured by forward hooks;
        # ``update_res`` points matched sub-layers at this same dict.
        self.res_dict = {}
        # Key under which ``_save_sub_res_hook`` stores this layer's output.
        self.res_name = self.full_name()
        # Not used in this class; presumably populated by external tooling
        # (pruning / quantization) — TODO confirm.
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Return ``{"output": output}`` merged with all captured results.

        Every entry in ``self.res_dict`` is moved into the returned dict, which
        leaves ``self.res_dict`` empty for the next call. A captured key named
        ``"output"`` would overwrite the layer's own output. The signature
        matches a forward post-hook; registration presumably happens in
        subclasses — TODO confirm.
        """
        res_dict = {"output": output}
        for res_key in list(self.res_dict):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def _save_sub_res_hook(self, layer, input, output):
        """Forward post-hook: store ``output`` in ``res_dict`` under ``res_name``."""
        self.res_dict[self.res_name] = output

    @staticmethod
    def _parse_pattern_item(item):
        """Split one dot-separated pattern component into ``(name, index)``.

        ``"blocks[3]"`` gives ``("blocks", "3")`` and ``"conv"`` gives
        ``("conv", None)``. The index stays a string, exactly as written in the
        pattern, and only the first ``[...]`` is honoured.
        """
        if '[' not in item:
            return item, None
        name, _, rest = item.partition('[')
        return name, rest.split(']')[0]

    def _find_layers_handle(self, patterns, handle_func):
        """Resolve each pattern to a sub-layer and call ``handle_func`` on it.

        Args:
            patterns (List[str]): Dotted layer patterns, e.g. ``"blocks[1].conv"``.
            handle_func (Callable[[nn.Layer, str], Any]): Called as
                ``handle_func(sub_layer, pattern)``. ``sub_layer`` is the target
                itself if it is a ``TheseusLayer``; otherwise it is the
                ``WrapLayer`` that now holds the target.

        Returns:
            Dict[str, Any]: ``handle_func``'s return value for every pattern that
            was resolved. Patterns that cannot be resolved are logged as
            warnings and left out of the result.
        """
        handle_res_dict = {}
        for pattern in patterns:
            pattern_list = pattern.split(".")

            # Walk down to the parent of the target layer. Wrappers inserted by
            # earlier calls are looked through, so the rest of the path still
            # resolves against the real layer.
            sub_layer_parent = self
            missing_item = None
            for item in pattern_list[:-1]:
                name, index = self._parse_pattern_item(item)
                child = getattr(sub_layer_parent, name, None)
                if child is not None and index is not None:
                    try:
                        child = child[index]
                    except (KeyError, IndexError, TypeError):
                        child = None
                if child is None:
                    missing_item = item
                    break
                if isinstance(child, WrapLayer):
                    child = child.sub_layer
                sub_layer_parent = child
            if missing_item is not None:
                msg = f"Not found parent layer of sub-layer by name({missing_item}) specifed in pattern({pattern})."
                logger.warning(msg)
                continue

            # Resolve the target layer itself.
            target_item = pattern_list[-1]
            sub_layer_name, sub_layer_index = self._parse_pattern_item(
                target_item)
            sub_layer = getattr(sub_layer_parent, sub_layer_name, None)
            if sub_layer is None:
                msg = f"Not found sub-layer by name({target_item}) specifed in pattern({pattern})."
                logger.warning(msg)
                continue
            if sub_layer_index is not None:
                try:
                    sub_layer = sub_layer[sub_layer_index]
                except (KeyError, IndexError, TypeError):
                    msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
                    logger.warning(msg)
                    continue

            # Wrap plain layers and write the wrapper back into the parent, so
            # that whatever handle_func attaches to it is part of the network.
            if not isinstance(sub_layer, TheseusLayer):
                sub_layer = wrap_theseus(sub_layer)
                if sub_layer_index is not None:
                    getattr(sub_layer_parent,
                            sub_layer_name)[sub_layer_index] = sub_layer
                else:
                    setattr(sub_layer_parent, sub_layer_name, sub_layer)

            handle_res_dict[pattern] = handle_func(sub_layer, pattern)
        return handle_res_dict

    def _set_identity(self, layer, layer_name, layer_index=None):
        """Replace every sub-layer of ``layer`` that follows ``layer_name`` with ``Identity``.

        "Follows" means later in the iteration order of ``layer._sub_layers``.
        When ``layer_index`` is given, the items of the container
        ``layer.<layer_name>`` that follow ``layer_index`` are replaced as well.

        Returns:
            bool: True if ``layer_name`` (and ``layer_index``, when given) was
            found, False otherwise. If only ``layer_index`` is missing, the
            replacements at the outer level have already been applied.
        """
        stop_after = False
        for sub_layer_name in layer._sub_layers:
            if stop_after:
                # Reassigning an existing key does not change the dict's size,
                # so doing it while iterating is safe.
                layer._sub_layers[sub_layer_name] = Identity()
                continue
            if sub_layer_name == layer_name:
                stop_after = True

        if layer_index is not None and stop_after:
            stop_after = False
            for sub_layer_index in layer._sub_layers[layer_name]._sub_layers:
                if stop_after:
                    layer._sub_layers[layer_name][sub_layer_index] = Identity()
                    continue
                if layer_index == sub_layer_index:
                    stop_after = True

        return stop_after

    def replace_sub(self,
                    layer_name_pattern: Union[str, List[str]],
                    replace_function: Callable[[nn.Layer, str], Any]
                    ) -> Dict[str, Any]:
        """Use 'replace_function' to modify the layers matched by 'layer_name_pattern'.

        Args:
            layer_name_pattern (Union[str, List[str]]): Dotted pattern(s) of the
                target layer(s), e.g. ``"blocks[11].depthwise_conv.conv"``.
            replace_function (Callable[[nn.Layer, str], Any]): Called as
                ``replace_function(layer, pattern)``. ``layer`` is the
                ``WrapLayer`` holding the target, or the target itself if it is
                already a ``TheseusLayer``. Assign ``layer.sub_layer`` on the
                wrapper to swap in a new layer.

        Returns:
            Dict[str, Any]: Maps each resolved pattern to the value returned by
            ``replace_function``. Patterns that cannot be resolved are logged
            and omitted.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(warp_layer: nn.Layer, pattern: str):
                sub_layer = warp_layer.sub_layer
                new_layer = nn.Conv2D(
                    in_channels=sub_layer._in_channels,
                    out_channels=sub_layer._out_channels,
                    kernel_size=5
                )
                warp_layer.sub_layer = new_layer
                return True

            net = paddleclas.MobileNetV1()
            res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], replace_function=rep_func)
            print(res)
            # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
        """

        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]
        return self._find_layers_handle(
            layer_name_pattern, handle_func=replace_function)

    # TODO(weishengyu): stop doesn't work when stop layer has a parallel branch.
    def stop_after(self, stop_layer_name: str) -> bool:
        """Stop forward and backward after 'stop_layer_name'.

        At every level of the dotted path, all sub-layers that follow the path
        component are replaced with ``Identity``.

        Args:
            stop_layer_name (str): Dotted pattern of the target layer, e.g.
                ``"blocks[3].conv"``.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """
        to_identity_list = []
        layer = self
        for item in stop_layer_name.split("."):
            # Look through wrappers inserted by replace_sub / update_res, so
            # the rest of the path resolves against the real layer.
            layer_parent = layer.sub_layer if isinstance(layer,
                                                         WrapLayer) else layer
            sub_layer_name, sub_layer_index = self._parse_pattern_item(item)
            layer = getattr(layer_parent, sub_layer_name, None)
            if layer is not None and sub_layer_index is not None:
                try:
                    layer = layer[sub_layer_index]
                except (KeyError, IndexError, TypeError):
                    layer = None
            if layer is None:
                msg = f"Not found layer by name({item}) specifed in stop_layer_name({stop_layer_name})."
                logger.warning(msg)
                return False
            to_identity_list.append(
                (layer_parent, sub_layer_name, sub_layer_index))

        for to_identity_layer in to_identity_list:
            if not self._set_identity(*to_identity_layer):
                msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
                logger.warning(msg)
                return False
        return True

    def update_res(self,
                   return_patterns: Union[str, List[str]]) -> Dict[str, bool]:
        """Capture the outputs of the layers matched by 'return_patterns'.

        Each matched layer shares this layer's ``res_dict`` and gets
        ``_save_sub_res_hook`` registered as a forward post-hook. That hook
        stores the layer's output in ``res_dict`` under its pattern string.
        NOTE(review): calling this twice for the same pattern registers the
        hook twice.

        Args:
            return_patterns (Union[str, List[str]]): Dotted pattern(s) of the
                layers whose outputs should be captured.

        Returns:
            Dict[str, bool]: The pattern(str) is set successfully if True(bool).
            Patterns that cannot be resolved are logged and omitted.
        """

        class Handler(object):
            def __init__(self, res_dict):
                self.res_dict = res_dict

            def __call__(self, layer, pattern):
                # Share the root's dict so every hooked layer writes into it.
                layer.res_dict = self.res_dict
                layer.res_name = pattern
                layer.register_forward_post_hook(layer._save_sub_res_hook)
                return True

        handle_func = Handler(self.res_dict)

        if not isinstance(return_patterns, list):
            return_patterns = [return_patterns]

        return self._find_layers_handle(
            return_patterns, handle_func=handle_func)


class WrapLayer(TheseusLayer):
    """Adapter that gives an arbitrary layer the ``TheseusLayer`` interface.

    The wrapped layer is kept in ``self.sub_layer``, and every call is
    forwarded to it untouched. Code that holds the wrapper can therefore swap
    ``sub_layer`` or attach hooks without touching the parent network's
    attributes again.
    """

    def __init__(self, sub_layer):
        super().__init__()
        self.sub_layer = sub_layer

    def forward(self, *inputs, **kwargs):
        wrapped = self.sub_layer
        return wrapped(*inputs, **kwargs)

def wrap_theseus(sub_layer):
    """Return ``sub_layer`` wrapped in a freshly created ``WrapLayer``."""
    return WrapLayer(sub_layer)