theseus_layer.py 11.5 KB
Newer Older
G
gaotingquan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from typing import Tuple, List, Dict, Union, Callable, Any
G
gaotingquan 已提交
16

W
weishengyu 已提交
17
from paddle import nn
18
from ppcls.utils import logger
W
weishengyu 已提交
19 20 21 22 23 24 25 26 27 28


class Identity(nn.Layer):
    """Pass-through layer: ``forward`` hands back its input unchanged.

    Used in this file as a placeholder that replaces sub-layers which should
    no longer perform any computation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # No computation at all: the input is the output.
        return inputs


W
dbg  
weishengyu 已提交
29
class TheseusLayer(nn.Layer):
    """Base layer with helpers to return intermediate outputs, replace
    sub-layers and cut the network after a given sub-layer.

    Sub-layers are addressed by a "pattern" string such as
    ``"blocks[11].depthwise_conv.conv"``: attribute names joined by ``.``,
    each optionally followed by an ``[index]`` into a container layer
    (see ``parse_pattern_str``).
    """

    def __init__(self, *args, **kwargs):
        # NOTE: '*args' and '**kwargs' are accepted but not used here.
        super(TheseusLayer, self).__init__()
        # outputs collected by the forward post-hooks, keyed by pattern
        self.res_dict = {}
        # key under which this layer's own output would be stored
        self.res_name = self.full_name()
        # not used inside this class; presumably set by external
        # pruning / quantization tooling -- TODO confirm
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Forward post-hook that bundles the collected sub-layer outputs.

        Args:
            layer: The layer the hook is registered on (unused).
            input: The input of the forward call (unused).
            output: The original output of the forward call.

        Returns:
            dict: ``{"output": output}`` plus one entry for every key that
                is currently stored in ``self.res_dict``.
        """
        res_dict = {"output": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            # clear the res_dict because the forward process may change according to input
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def init_res(self,
                 stages_pattern,
                 return_patterns=None,
                 return_stages=None):
        """Set up which intermediate results should be returned.

        Args:
            stages_pattern (List[str]): The patterns of the stages of the network, in order.
            return_patterns (Union[str, List[str]], optional): The pattern(s) of the layers whose
                output should be returned. Takes precedence over 'return_stages'. Defaults to None.
            return_stages (Union[bool, int, List[int]], optional): 'True' to return the output of
                every stage, or the index(es) into 'stages_pattern' of the stages to return.
                Out-of-range indexes are ignored with a warning. Defaults to None.
        """
        if return_patterns and return_stages:
            # The code has always dropped 'return_stages' in this case; the
            # message used to claim the opposite, so it is fixed to match.
            msg = "The 'return_stages' would be ignored when 'return_patterns' is set."
            logger.warning(msg)
            return_stages = None

        if return_stages is True:
            return_patterns = stages_pattern

        # 'type(...) is int' on purpose: bool is a subclass of int and
        # 'True' has already been handled above.
        if type(return_stages) is int:
            return_stages = [return_stages]
        if isinstance(return_stages, list):
            # Valid indexes are 0 .. len(stages_pattern) - 1. Filtering first
            # also copes with an empty list, on which max()/min() would raise.
            legal_stages = [
                val for val in return_stages
                if val >= 0 and val < len(stages_pattern)
            ]
            if len(legal_stages) != len(return_stages):
                msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
                logger.warning(msg)
            return_patterns = [stages_pattern[i] for i in legal_stages]

        if return_patterns:
            self.update_res(return_patterns)

    def replace_sub(self, *args, **kwargs) -> None:
        """Deprecated, always raises. Use 'upgrade_sublayer()' instead.

        Raises:
            DeprecationWarning: Always.
        """
        msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
        logger.error(DeprecationWarning(msg))
        raise DeprecationWarning(msg)

    def upgrade_sublayer(self,
                         layer_name_pattern: Union[str, List[str]],
                         handle_func: Callable[[nn.Layer, str], nn.Layer]
                         ) -> List[str]:
        """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.

        Returns:
            List[str]: The pattern(s) whose target layer was found and has been replaced by the result of 'handle_func()'. Patterns that could not be resolved are skipped.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=layer._in_channels,
                    out_channels=layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # ['blocks[11].depthwise_conv.conv', 'blocks[12].depthwise_conv.conv']
        """

        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        hit_layer_pattern_list = []
        for pattern in layer_name_pattern:
            # parse pattern to find target layer and its parent
            layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
            if not layer_list:
                continue
            # the parent is the second to last parsed layer, or this layer
            # itself when the pattern has a single component
            sub_layer_parent = layer_list[-2]["layer"] if len(
                layer_list) > 1 else self

            sub_layer = layer_list[-1]["layer"]
            sub_layer_name = layer_list[-1]["name"]
            sub_layer_index = layer_list[-1]["index"]

            new_sub_layer = handle_func(sub_layer, pattern)

            if sub_layer_index:
                # the target lives inside a container attribute of the parent
                getattr(sub_layer_parent,
                        sub_layer_name)[sub_layer_index] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            hit_layer_pattern_list.append(pattern)
        return hit_layer_pattern_list

    def stop_after(self, stop_layer_name: str) -> bool:
        """stop forward and backward after 'stop_layer_name'.

        Args:
            stop_layer_name (str): The name of layer that stop forward and backward after this layer.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """

        layer_list = parse_pattern_str(stop_layer_name, self)
        if not layer_list:
            return False

        # walk down the parsed path; at every level the sub-layers that come
        # after the one on the path are replaced by 'set_identity()'
        parent_layer = self
        for layer_dict in layer_list:
            name, index = layer_dict["name"], layer_dict["index"]
            if not set_identity(parent_layer, name, index):
                msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
                logger.warning(msg)
                return False
            parent_layer = layer_dict["layer"]

        return True

    def update_res(
            self,
            return_patterns: Union[str, List[str]]) -> List[str]:
        """update the result(s) to be returned.

        Args:
            return_patterns (Union[str, List[str]]): The name of layer to return output.

        Returns:
            List[str]: The pattern(s) for which the output hook has been set successfully.
        """

        # clear res_dict that could have been set
        self.res_dict = {}
        # every hooked sub-layer shares this very dict object
        shared_res_dict = self.res_dict

        def handle_func(layer, pattern):
            layer.res_dict = shared_res_dict
            layer.res_name = pattern
            # drop the hook of a previous call so it is not registered twice
            if hasattr(layer, "hook_remove_helper"):
                layer.hook_remove_helper.remove()
            layer.hook_remove_helper = layer.register_forward_post_hook(
                save_sub_res_hook)
            return layer

        hit_layer_pattern_list = self.upgrade_sublayer(
            return_patterns, handle_func=handle_func)

        # (re-)register the hook that bundles the collected results into the
        # output of this layer
        if hasattr(self, "hook_remove_helper"):
            self.hook_remove_helper.remove()
        self.hook_remove_helper = self.register_forward_post_hook(
            self._return_dict_hook)

        return hit_layer_pattern_list
203 204


G
gaotingquan 已提交
205 206 207 208
def save_sub_res_hook(layer, input, output):
    """Forward post-hook that records 'output' in the layer's result dict.

    Args:
        layer: The hooked layer; must carry 'res_dict' and 'res_name' attributes.
        input: The input of the forward call (unused).
        output: The output of the forward call, stored under 'layer.res_name'.
    """
    collected = layer.res_dict
    result_key = layer.res_name
    collected[result_key] = output


209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
def set_identity(parent_layer: nn.Layer,
                 layer_name: str,
                 layer_index: str=None) -> bool:
    """Replace every sub-layer that comes after the target layer by Identity.

    The direct sub-layers of 'parent_layer' registered after 'layer_name' are
    replaced first. When 'layer_index' is given as well, the members of the
    'layer_name' container that come after 'layer_index' are replaced too.

    Args:
        parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index.
        layer_name (str): The name of target layer; the sub-layers after it are set to Identity.
        layer_index (str, optional): The index of target layer in the container named layer_name. Defaults to None.

    Returns:
        bool: True if the target layer (and its index, when given) has been found, False otherwise.
    """

    children = parent_layer._sub_layers

    # first pass: siblings registered after 'layer_name'
    name_found = False
    for child_name in children:
        if name_found:
            children[child_name] = Identity()
        elif child_name == layer_name:
            name_found = True

    if not (layer_index and name_found):
        return name_found

    # second pass: members of the container that come after 'layer_index'
    container = children[layer_name]
    index_found = False
    for member_index in container._sub_layers:
        if index_found:
            container[member_index] = Identity()
        elif layer_index == member_index:
            index_found = True

    return index_found


G
gaotingquan 已提交
245 246
def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
        None, List[Dict[str, Union[nn.Layer, str, None]]]]:
    """parse the string type pattern.

    A pattern is a chain of attribute names joined by '.', each optionally
    followed by an '[index]' into a container layer, for example
    "blocks[11].depthwise_conv.conv".

    Args:
        pattern (str): The pattern to describe layer.
        parent_layer (nn.Layer): The root layer relative to the pattern.

    Returns:
        Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed (a warning is logged). If successfully, the members are layers parsed in order:
                                                                [
                                                                    {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
                                                                    {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
                                                                    ...
                                                                ]
                                                                The "index" is kept as the string found between the brackets, or None.
    """

    # 'str.split()' never returns an empty list, so the pattern itself has to
    # be tested to reject an empty (or missing) pattern.
    if not pattern:
        msg = f"The pattern('{pattern}') is illegal. Please check and retry."
        logger.warning(msg)
        return None

    layer_list = []
    for pattern_item in pattern.split("."):
        if '[' in pattern_item:
            target_layer_name = pattern_item.split('[')[0]
            target_layer_index = pattern_item.split('[')[1].split(']')[0]
        else:
            target_layer_name = pattern_item
            target_layer_index = None

        target_layer = getattr(parent_layer, target_layer_name, None)

        if target_layer is None:
            msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')."
            logger.warning(msg)
            return None

        # 'target_layer' is known to be not None here. Its truthiness is
        # deliberately not tested, so that an index into an empty container
        # is reported as out of range instead of being silently accepted.
        if target_layer_index:
            try:
                index_value = int(target_layer_index)
                layer_count = len(target_layer)
            except (TypeError, ValueError):
                # non-integer index, or the layer is not a sized container
                msg = f"The index('{target_layer_index}') specified in pattern('{pattern}') is illegal. It should be an integer and the layer('{target_layer_name}') should be a container."
                logger.warning(msg)
                return None

            if index_value < 0 or index_value >= layer_count:
                msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {layer_count} and >= 0."
                logger.warning(msg)
                return None

            # the container is indexed by the string form, as before
            target_layer = target_layer[target_layer_index]

        layer_list.append({
            "layer": target_layer,
            "name": target_layer_name,
            "index": target_layer_index
        })

        # the next component is resolved relative to the layer just found
        parent_layer = target_layer
    return layer_list