theseus_layer.py 11.5 KB
Newer Older
G
gaotingquan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from typing import Tuple, List, Dict, Union, Callable, Any
W
weishengyu 已提交
16
from paddle import nn
17
from ppcls.utils import logger
W
weishengyu 已提交
18 19 20 21 22 23 24 25 26 27


class Identity(nn.Layer):
    """Pass-through layer: 'forward' hands back exactly what it receives."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # No transformation is applied to the inputs.
        return inputs


W
dbg  
weishengyu 已提交
28
class TheseusLayer(nn.Layer):
    """Layer base class with helpers to return intermediate results, replace sub-layers and truncate the network.

    Sub-layers are addressed by string patterns such as "blocks[11].depthwise_conv.conv",
    which are resolved by 'parse_pattern_str()' relative to this layer.
    """

    def __init__(self, *args, **kwargs):
        super(TheseusLayer, self).__init__()
        # Outputs saved by '_save_sub_res_hook', keyed by the 'res_name' of the layer that produced them.
        self.res_dict = {}
        # Key under which this layer's own output is stored in 'res_dict'.
        self.res_name = self.full_name()
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Forward post hook that wraps the output in a dict together with the saved sub-layer results.

        Returns:
            dict: {"output": output} plus every entry currently stored in 'self.res_dict'.
        """
        res_dict = {"output": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            # clear the res_dict because the forward process may change according to input
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def _save_sub_res_hook(self, layer, input, output):
        """Forward post hook that stores this layer's output in 'self.res_dict' under 'self.res_name'."""
        self.res_dict[self.res_name] = output

    def init_res(self,
                 stages_pattern,
                 return_patterns=None,
                 return_stages=None):
        """Select the intermediate results to return, by pattern or by stage index.

        Args:
            stages_pattern: The patterns of the stages, indexed by stage number.
            return_patterns: The pattern(s) whose outputs should be returned. Takes precedence over
                'return_stages' when both are set.
            return_stages: 'True' to return every stage, or an int / list of int indexing 'stages_pattern'.
                Indices outside [0, len(stages_pattern)) are ignored with a warning.
        """
        if return_patterns and return_stages:
            # The code keeps 'return_patterns' and drops 'return_stages', so the message says exactly that.
            msg = "The 'return_stages' would be ignored when 'return_patterns' is set."
            logger.warning(msg)
            return_stages = None

        if isinstance(return_stages, bool):
            # 'bool' is a subclass of 'int', so it must be handled before the int branch,
            # otherwise 'True' would be treated as the stage index 1.
            if return_stages:
                return_patterns = stages_pattern
        else:
            if isinstance(return_stages, int):
                return_stages = [return_stages]
            if isinstance(return_stages, list):
                num_stages = len(stages_pattern)
                valid_stages = [
                    val for val in return_stages if 0 <= val < num_stages
                ]
                if len(valid_stages) != len(return_stages):
                    msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
                    logger.warning(msg)
                return_patterns = [stages_pattern[i] for i in valid_stages]

        if return_patterns:
            self.update_res(return_patterns)

    def replace_sub(self, *args, **kwargs) -> None:
        """Deprecated entry point; always raises. Use 'upgrade_sublayer()' instead.

        Raises:
            DeprecationWarning: Always.
        """
        msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
        logger.error(DeprecationWarning(msg))
        raise DeprecationWarning(msg)

    def upgrade_sublayer(self,
                         layer_name_pattern: Union[str, List[str]],
                         handle_func: Callable[[nn.Layer, str], nn.Layer]
                         ) -> Dict[str, nn.Layer]:
        """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.

        Returns:
            Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'. Patterns that can not be parsed are skipped.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=layer._in_channels,
                    out_channels=layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
        """

        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        handle_res_dict = {}
        for pattern in layer_name_pattern:
            # parse pattern to find target layer and its parent
            layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
            if not layer_list:
                continue
            # A single-level pattern has no intermediate parent, so the parent is this layer.
            sub_layer_parent = layer_list[-2]["layer"] if len(
                layer_list) > 1 else self

            sub_layer = layer_list[-1]["layer"]
            sub_layer_name = layer_list[-1]["name"]
            sub_layer_index = layer_list[-1]["index"]

            new_sub_layer = handle_func(sub_layer, pattern)

            # The index is kept as a string (e.g. "0"), so index zero is still truthy here.
            if sub_layer_index:
                getattr(sub_layer_parent,
                        sub_layer_name)[sub_layer_index] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            handle_res_dict[pattern] = new_sub_layer
        return handle_res_dict

    def stop_after(self, stop_layer_name: str) -> bool:
        """stop forward and backward after 'stop_layer_name'.

        For each level of the parsed pattern, 'set_identity()' is applied to the
        parent of that level before descending into the matched layer.

        Args:
            stop_layer_name (str): The name of layer that stop forward and backward after this layer.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """

        layer_list = parse_pattern_str(stop_layer_name, self)
        if not layer_list:
            return False

        parent_layer = self
        for layer_dict in layer_list:
            name, index = layer_dict["name"], layer_dict["index"]
            if not set_identity(parent_layer, name, index):
                msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
                logger.warning(msg)
                return False
            parent_layer = layer_dict["layer"]

        return True

    def update_res(
            self,
            return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]:
        """update the result(s) to be returned.

        Args:
            return_patterns (Union[str, List[str]]): The name of layer to return output.

        Returns:
            Dict[str, nn.Layer]: The pattern(str) and corresponding layer(nn.Layer) that have been set successfully.
        """

        # clear res_dict that could have been set
        self.res_dict = {}

        class Handler(object):
            """Callable passed to 'upgrade_sublayer()' that attaches the result-saving hook to a layer."""

            def __init__(self, res_dict):
                # res_dict is a reference
                self.res_dict = res_dict

            def __call__(self, layer, pattern):
                layer.res_dict = self.res_dict
                layer.res_name = pattern
                # Remove the hook registered by an earlier call so it is not registered twice.
                if hasattr(layer, "hook_remove_helper"):
                    layer.hook_remove_helper.remove()
                layer.hook_remove_helper = layer.register_forward_post_hook(
                    layer._save_sub_res_hook)
                return layer

        handle_func = Handler(self.res_dict)

        res_dict = self.upgrade_sublayer(
            return_patterns, handle_func=handle_func)

        # Re-register the hook that turns this layer's output into a dict of results.
        if hasattr(self, "hook_remove_helper"):
            self.hook_remove_helper.remove()
        self.hook_remove_helper = self.register_forward_post_hook(
            self._return_dict_hook)

        return res_dict
204 205


206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
def set_identity(parent_layer: nn.Layer,
                 layer_name: str,
                 layer_index: str=None) -> bool:
    """Replace the sub-layers that come after the target layer with Identity.

    The sub-layers of 'parent_layer' are walked in order; every one that follows
    the sub-layer named 'layer_name' is replaced. When 'layer_index' is given, the
    same is done inside that sub-layer for every member that follows 'layer_index'.

    Args:
        parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index.
        layer_name (str): The name of the target layer in parent_layer.
        layer_index (str, optional): The index of the target member inside the layer named layer_name. Defaults to None.

    Returns:
        bool: True if the target (name and, when given, index) was found, False otherwise.
    """

    siblings = parent_layer._sub_layers
    found = False
    for name in siblings:
        if found:
            # Only values of existing keys are reassigned, so iterating the dict stays valid.
            siblings[name] = Identity()
        elif name == layer_name:
            found = True

    # Nothing more to do without an index, or when the name itself was not found.
    if not (layer_index and found):
        return found

    found = False
    for member_index in siblings[layer_name]._sub_layers:
        if found:
            siblings[layer_name][member_index] = Identity()
        elif layer_index == member_index:
            found = True

    return found


G
gaotingquan 已提交
242 243
def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
        None, List[Dict[str, Union[nn.Layer, str, None]]]]:
    """parse the string type pattern.

    The pattern is a dot-separated path of attribute names, each optionally
    followed by one index in square brackets, e.g. "blocks[11].depthwise_conv.conv".

    Args:
        pattern (str): The pattern to describe layer.
        parent_layer (nn.Layer): The root layer relative to the pattern.

    Returns:
        Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
                                                                [
                                                                    {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
                                                                    {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
                                                                    ...
                                                                ]
    """

    # 'str.split' never returns an empty list, so the pattern itself has to be checked.
    if not isinstance(pattern, str) or not pattern:
        msg = f"The pattern('{pattern}') is illegal. Please check and retry."
        logger.warning(msg)
        return None

    layer_list = []
    for segment in pattern.split("."):
        if '[' in segment:
            target_layer_name = segment.split('[')[0]
            target_layer_index = segment.split('[')[1].split(']')[0]
        else:
            target_layer_name = segment
            target_layer_index = None

        target_layer = getattr(parent_layer, target_layer_name, None)

        if target_layer is None:
            msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')."
            logger.warning(msg)
            return None

        if target_layer_index:
            # A non-numeric index or a target without a length is a parse failure, not a crash.
            try:
                index_value = int(target_layer_index)
                num_sub_layers = len(target_layer)
            except (TypeError, ValueError):
                msg = f"The index('{target_layer_index}') specifed in pattern('{pattern}') is illegal. The layer named('{target_layer_name}') must be indexable and the index must be an integer."
                logger.warning(msg)
                return None

            if index_value < 0 or index_value >= num_sub_layers:
                msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {num_sub_layers} and >= 0."
                logger.warning(msg)
                return None

            # The index is kept as the string parsed from the pattern, for the lookup and in the result.
            target_layer = target_layer[target_layer_index]

        layer_list.append({
            "layer": target_layer,
            "name": target_layer_name,
            "index": target_layer_index
        })

        # The next segment is resolved relative to the layer just found.
        parent_layer = target_layer
    return layer_list