theseus_layer.py 9.1 KB
Newer Older
1
from typing import List, Dict, Union, Callable, Any
W
weishengyu 已提交
2
from paddle import nn
3
from ppcls.utils import logger
W
weishengyu 已提交
4 5 6 7 8 9 10 11 12 13


class Identity(nn.Layer):
    """Pass-through layer: 'forward' hands its argument back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # No computation is performed on the input.
        return inputs


W
dbg  
weishengyu 已提交
14
class TheseusLayer(nn.Layer):
    """Layer whose sub-layers can be replaced, truncated or tapped by name.

    Sub-layers are addressed by dotted patterns such as
    "blocks[11].depthwise_conv.conv": every component is an attribute name,
    optionally followed by one "[index]" into an indexable container.
    """

    def __init__(self, *args, **kwargs):
        super(TheseusLayer, self).__init__()
        # Outputs stored by '_save_sub_res_hook', keyed by 'res_name'.
        self.res_dict = {}
        self.res_name = self.full_name()
        # Initialised to None here; not used anywhere else in this class.
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Return a dict holding 'output' plus every result stored in 'self.res_dict'.

        The stored results are popped, so 'self.res_dict' is empty afterwards.
        NOTE(review): the (layer, input, output) signature looks like a forward
        post-hook, but no registration is visible in this file - confirm.
        """
        res_dict = {"output": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def _save_sub_res_hook(self, layer, input, output):
        """Store 'output' in 'self.res_dict' under 'self.res_name'.

        Registered as a forward post-hook by 'update_res'.
        """
        self.res_dict[self.res_name] = output

    def replace_sub(self,
                    layer_name_pattern: Union[str, List[str]],
                    handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
                        str, nn.Layer]:
        """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Patterns that cannot be resolved are skipped (a warning is logged by
        'parse_pattern_str') and do not appear in the returned dict.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'.

        Returns:
            Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func'.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(sub_layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=sub_layer._in_channels,
                    out_channels=sub_layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # a dict mapping each successfully handled pattern to the layer returned by 'rep_func'
        """
        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        handle_res_dict = {}
        for pattern in layer_name_pattern:
            # find parent layer of sub-layer specified by pattern
            if "." in pattern:
                sub_layer_parent, _, _ = parse_pattern_str(
                    pattern=pattern, idx=(0, -1), sub_layer_parent=self)
            else:
                # A single-component pattern names a direct child of 'self'.
                # Slicing off the last component would leave nothing to
                # parse and the pattern would be skipped silently.
                sub_layer_parent = self

            # Compare with None: a container layer of length 0 is falsy but
            # is still a valid lookup result.
            if sub_layer_parent is None:
                continue

            # find sub-layer specified by pattern
            sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str(
                pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent)

            if sub_layer is None:
                continue

            new_sub_layer = handle_func(sub_layer, pattern)

            if sub_layer_index:
                # Unwrap so the assignment reaches the indexable container,
                # mirroring how 'parse_pattern_str' resolved the sub-layer.
                container = unwrap_theseus(
                    getattr(sub_layer_parent, sub_layer_name))
                container[sub_layer_index] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            handle_res_dict[pattern] = new_sub_layer
        return handle_res_dict

    def _set_identity(self, layer, layer_name, layer_index=None):
        """Replace the sub-layers of 'layer' that follow 'layer_name' with Identity.

        "Follow" means later in the iteration order of 'layer._sub_layers'.
        When 'layer_index' is given, the entries that follow that index inside
        the container 'layer._sub_layers[layer_name]' are replaced as well.

        Returns:
            bool: 'True' if 'layer_name' (and 'layer_index', when given) was found.
        """
        stop_after = False
        for sub_layer_name in layer._sub_layers:
            if stop_after:
                layer._sub_layers[sub_layer_name] = Identity()
                continue
            if sub_layer_name == layer_name:
                stop_after = True

        if layer_index and stop_after:
            # Restart the search, this time inside the indexed container.
            stop_after = False
            for sub_layer_index in layer._sub_layers[layer_name]._sub_layers:
                if stop_after:
                    layer._sub_layers[layer_name][sub_layer_index] = Identity()
                    continue
                if layer_index == sub_layer_index:
                    stop_after = True

        return stop_after

    def stop_after(self, stop_layer_name: str) -> bool:
        """stop forward and backward after 'stop_layer_name'.

        Args:
            stop_layer_name (str): The name of layer that stop forward and backward after this layer.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """
        to_identity_list = []

        # TODO(gaotingquan): replace code by self._parse_pattern_str()
        layer = self
        for pattern_item in stop_layer_name.split("."):
            layer_parent = layer
            if '[' in pattern_item:
                sub_layer_name = pattern_item.split('[')[0]
                sub_layer_index = pattern_item.split('[')[1].split(']')[0]
            else:
                sub_layer_name = pattern_item
                sub_layer_index = None

            # A missing attribute or index is reported as "not found" below
            # instead of escaping as an exception.
            layer = getattr(layer_parent, sub_layer_name, None)
            if layer is not None and sub_layer_index is not None:
                try:
                    layer = layer[sub_layer_index]
                except (KeyError, IndexError, TypeError):
                    layer = None

            if layer is None:
                msg = f"Not found layer by name({pattern_item}) specifed in stop_layer_name({stop_layer_name})."
                logger.warning(msg)
                return False

            to_identity_list.append(
                (layer_parent, sub_layer_name, sub_layer_index))

        for to_identity_layer in to_identity_list:
            if not self._set_identity(*to_identity_layer):
                msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
                logger.warning(msg)
                return False
        return True

    def update_res(self,
                   return_patterns: Union[str, List[str]]) -> Dict[
                       str, nn.Layer]:
        """update the results to be returned.

        Args:
            return_patterns (Union[str, List[str]]): The name of layer to return output.

        Returns:
            Dict[str, nn.Layer]: Maps every pattern that was set successfully to its layer; patterns that failed are absent.
        """

        class Handler(object):
            def __init__(self, res_dict):
                self.res_dict = res_dict

            def __call__(self, layer, pattern):
                # Share this model's result dict with the matched layer so
                # its output is stored there under 'pattern'.
                layer.res_dict = self.res_dict
                layer.res_name = pattern
                # NOTE(review): requires the matched layer to provide
                # '_save_sub_res_hook' (i.e. to be a TheseusLayer) - confirm.
                layer.register_forward_post_hook(layer._save_sub_res_hook)
                return layer

        handle_func = Handler(self.res_dict)

        return self.replace_sub(return_patterns, handle_func=handle_func)
W
weishengyu 已提交
180 181 182


class WrapLayer(TheseusLayer):
    """TheseusLayer that holds one sub-layer and delegates 'forward' to it."""

    def __init__(self, sub_layer):
        super().__init__()
        self.sub_layer = sub_layer

    def forward(self, *inputs, **kwargs):
        # Pass every positional and keyword argument through untouched.
        wrapped = self.sub_layer
        return wrapped(*inputs, **kwargs)
W
weishengyu 已提交
189

190 191

def wrap_theseus(sub_layer):
    """Return a new 'WrapLayer' holding 'sub_layer'."""
    wrapped = WrapLayer(sub_layer)
    return wrapped


def unwrap_theseus(sub_layer):
    """Return the layer held by a 'WrapLayer'; any other value is returned as-is."""
    if not isinstance(sub_layer, WrapLayer):
        return sub_layer
    return sub_layer.sub_layer


def slice_pattern(pattern, idx):
    """Split a dotted layer pattern and select the components addressed by 'idx'.

    Args:
        pattern (str): Dotted pattern, e.g. "blocks[11].depthwise_conv.conv".
        idx (Union[int, tuple, None]): 'None' selects every component. An int
            selects that single component. A 1-tuple '(i,)' selects component
            'i'. A 2-tuple '(start, stop)' selects 'components[start:stop]'.

    Returns:
        Optional[List[str]]: The selected components, always as a list. 'None'
            (with a logged warning) when 'idx' has an unsupported type or
            tuple length.
    """
    pattern_list = pattern.split(".")

    # Compare with None explicitly: 0 is a valid index but is falsy.
    if idx is None:
        return pattern_list

    if isinstance(idx, tuple):
        if len(idx) == 1:
            # Wrap in a list so every branch returns the same type; a bare
            # str would be iterated character by character by callers.
            return [pattern_list[idx[0]]]
        if len(idx) == 2:
            return pattern_list[idx[0]:idx[1]]
        msg = "Only support length of 'idx' is 1 or 2 when 'idx' is a tuple."
        logger.warning(msg)
        return None

    if isinstance(idx, int):
        return [pattern_list[idx]]

    msg = "Only support type of 'idx' is int or tuple."
    logger.warning(msg)
    return None


def parse_pattern_str(pattern, sub_layer_parent, idx=None):
    """Resolve (part of) a dotted pattern to a sub-layer, starting from 'sub_layer_parent'.

    Each selected component is looked up as an attribute of the current
    layer. A component of the form "name[index]" additionally indexes into
    the attribute. 'WrapLayer' wrappers are unwrapped after every step.

    Args:
        pattern (str): Dotted pattern, e.g. "blocks[11].depthwise_conv.conv".
        sub_layer_parent (nn.Layer): The layer the lookup starts from.
        idx (Union[int, tuple, None]): Which components of 'pattern' to
            resolve; passed to 'slice_pattern'. 'None' resolves all of them.

    Returns:
        tuple: '(layer, name, index)' for the last resolved component, where
            'index' is the bracketed index as a str or 'None'. 'layer' is
            'None' when the lookup fails (a warning is logged). All three
            are 'None' when no component is selected.
    """
    pattern_list = slice_pattern(pattern, idx)
    if not pattern_list:
        return None, None, None

    for pattern_item in pattern_list:
        if '[' in pattern_item:
            sub_layer_name = pattern_item.split('[')[0]
            sub_layer_index = pattern_item.split('[')[1].split(']')[0]
        else:
            sub_layer_name = pattern_item
            sub_layer_index = None

        sub_layer_parent = getattr(sub_layer_parent, sub_layer_name, None)
        sub_layer_parent = unwrap_theseus(sub_layer_parent)

        if sub_layer_parent is None:
            msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})."
            logger.warning(msg)
            return None, sub_layer_name, sub_layer_index

        # Only the index decides whether to index. Testing the container's
        # truthiness as well would skip the bounds check for a container of
        # length 0 and return the container itself as the match.
        if sub_layer_index:
            try:
                index_value = int(sub_layer_index)
            except ValueError:
                # Follow the function's warn-and-return-None convention
                # instead of letting the conversion error escape.
                msg = f"Index({sub_layer_index}) specified in pattern({pattern}) is not an integer."
                logger.warning(msg)
                return None, sub_layer_name, sub_layer_index

            if index_value < 0 or index_value >= len(sub_layer_parent):
                msg = f"Not found layer by index({sub_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(sub_layer_parent)}' and > '0'."
                logger.warning(msg)
                return None, sub_layer_name, sub_layer_index

            # The container is indexed with the original str index.
            sub_layer_parent = sub_layer_parent[sub_layer_index]
            sub_layer_parent = unwrap_theseus(sub_layer_parent)

    return sub_layer_parent, sub_layer_name, sub_layer_index