# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, List, Dict, Union, Callable, Any

from paddle import nn
from ppcls.utils import logger


class Identity(nn.Layer):
    """Pass-through layer: 'forward' hands back exactly what it receives."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # No transformation is applied; the input object itself is returned.
        return inputs


class TheseusLayer(nn.Layer):
    """Layer base class whose sub-layers can be addressed by a name pattern.

    A pattern is a dotted path of attribute names, each optionally followed by
    an index in brackets, e.g. "blocks[11].depthwise_conv.conv". On top of
    'nn.Layer' this class offers:

    * 'layer_wrench()': replace/modify the sub-layer(s) a pattern points at.
    * 'stop_after()': replace every layer after a given one with 'Identity'.
    * 'update_res()': additionally return the output of chosen sub-layers.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: '*args' / '**kwargs' are accepted but not forwarded to nn.Layer.
        super(TheseusLayer, self).__init__()
        # Intermediate outputs saved by '_save_sub_res_hook', keyed by 'res_name'.
        self.res_dict = {}
        # Key under which this layer's own output is stored in 'res_dict'.
        self.res_name = self.full_name()
        self.pruner = None
        self.quanter = None

    def _return_dict_hook(self, layer, input, output):
        """Forward post-hook that packs the output and the saved sub-results.

        Registered on 'self' by 'update_res()'. Presumably Paddle substitutes
        the value returned here for the layer's output (see the documentation
        of 'register_forward_post_hook').

        Returns:
            dict: {"output": output} plus every entry currently in 'self.res_dict'.
        """
        res_dict = {"output": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            # clear the res_dict because the forward process may change according to input
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    def _save_sub_res_hook(self, layer, input, output):
        """Forward post-hook that stores this layer's output under 'self.res_name'."""
        self.res_dict[self.res_name] = output

    def replace_sub(self, *args, **kwargs) -> None:
        """Deprecated entry point: always logs an error and raises.

        Raises:
            DeprecationWarning: always; use 'layer_wrench()' instead.
        """
        msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead."
        logger.error(DeprecationWarning(msg))
        raise DeprecationWarning(msg)

    # TODO(gaotingquan): what is a good name?
    def layer_wrench(self,
                     layer_name_pattern: Union[str, List[str]],
                     handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
                         str, nn.Layer]:
        """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Patterns that cannot be resolved by 'parse_pattern_str()' are skipped
        and do not appear in the returned dict.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.

        Returns:
            Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.

        Examples:

            from paddle import nn
            import paddleclas

            def rep_func(layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=layer._in_channels,
                    out_channels=layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.layer_wrench(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
        """

        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        handle_res_dict = {}
        for pattern in layer_name_pattern:
            # parse pattern to find target layer and its parent
            layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
            if not layer_list:
                continue

            # A single-segment pattern names a direct child of 'self'.
            sub_layer_parent = layer_list[-2]["layer"] if len(
                layer_list) > 1 else self

            target = layer_list[-1]
            sub_layer = target["layer"]
            sub_layer_name = target["name"]
            sub_layer_index = target["index"]

            new_sub_layer = handle_func(sub_layer, pattern)

            if sub_layer_index:
                # The target lives inside an indexable container attribute.
                getattr(sub_layer_parent,
                        sub_layer_name)[sub_layer_index] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            handle_res_dict[pattern] = new_sub_layer
        return handle_res_dict

    def stop_after(self, stop_layer_name: str) -> bool:
        """stop forward and backward after 'stop_layer_name'.

        Every level of the parsed pattern is handed to 'set_identity()', which
        replaces the layers following the named one with 'Identity'.

        Args:
            stop_layer_name (str): The name of layer that stop forward and backward after this layer.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """

        layer_list = parse_pattern_str(stop_layer_name, self)
        if not layer_list:
            return False

        parent_layer = self
        for layer_dict in layer_list:
            name, index = layer_dict["name"], layer_dict["index"]
            if not set_identity(parent_layer, name, index):
                msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
                logger.warning(msg)
                return False
            # Descend one level: the layer just handled is the next parent.
            parent_layer = layer_dict["layer"]

        return True

    def update_res(
            self,
            return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]:
        """update the result(s) to be returned.

        Registers '_save_sub_res_hook' on every sub-layer matched by
        'return_patterns' and '_return_dict_hook' on 'self'. Hooks installed by
        a previous call on the same layer are removed first.

        Args:
            return_patterns (Union[str, List[str]]): The name of layer to return output.

        Returns:
            Dict[str, nn.Layer]: The pattern(str) and corresponding layer(nn.Layer) that have been set successfully.
        """

        # clear res_dict that could have been set
        self.res_dict = {}

        class Handler:
            """Callable handed to 'layer_wrench()' that wires up one sub-layer."""

            def __init__(self, res_dict):
                # res_dict is a reference
                self.res_dict = res_dict

            def __call__(self, layer, pattern):
                # Share the root's dict so the sub-layer's hook writes into it.
                layer.res_dict = self.res_dict
                layer.res_name = pattern
                # Drop a hook left over from an earlier call before adding a new one.
                if hasattr(layer, "hook_remove_helper"):
                    layer.hook_remove_helper.remove()
                layer.hook_remove_helper = layer.register_forward_post_hook(
                    layer._save_sub_res_hook)
                return layer

        handle_func = Handler(self.res_dict)

        res_dict = self.layer_wrench(return_patterns, handle_func=handle_func)

        if hasattr(self, "hook_remove_helper"):
            self.hook_remove_helper.remove()
        self.hook_remove_helper = self.register_forward_post_hook(
            self._return_dict_hook)

        return res_dict


def set_identity(parent_layer: nn.Layer,
                 layer_name: str,
                 layer_index: str=None) -> bool:
    """Replace every layer that follows the target layer with Identity.

    The sub-layers of 'parent_layer' registered after 'layer_name' are
    replaced first. When 'layer_index' is given and 'layer_name' was found,
    the same is done inside that sub-layer for the entries after 'layer_index'.

    Args:
        parent_layer (nn.Layer): The parent layer of the target layer specified by layer_name and layer_index.
        layer_name (str): The name of the target layer in parent_layer.
        layer_index (str, optional): The index of the target layer inside the sub-layer named layer_name. Defaults to None.

    Returns:
        bool: True if the target (name, and index when given) was found, False otherwise.
    """

    name_found = False
    for child_name in parent_layer._sub_layers:
        if name_found:
            # Everything registered after the target becomes a pass-through.
            parent_layer._sub_layers[child_name] = Identity()
        elif child_name == layer_name:
            name_found = True

    # Without an index (or without a name match) the name result is final.
    if not (layer_index and name_found):
        return name_found

    index_found = False
    for child_index in parent_layer._sub_layers[layer_name]._sub_layers:
        if index_found:
            parent_layer._sub_layers[layer_name][child_index] = Identity()
        elif layer_index == child_index:
            index_found = True

    return index_found


def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
        None, List[Dict[str, Union[nn.Layer, str, None]]]]:
    """parse the string type pattern.

    The pattern is split on '.'; each segment is an attribute name that may be
    followed by an index in brackets (e.g. "blocks[11]"). Segments are resolved
    one after another, each relative to the layer found by the previous one.

    Args:
        pattern (str): The pattern to describe layer.
        parent_layer (nn.Layer): The root layer relative to the pattern.

    Returns:
        Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed
        (empty pattern, unknown attribute name, non-integer or out-of-range
        index). If successful, the members are the layers parsed in order:
            [
                {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
                {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
                ...
            ]
        "index" is kept as the string written in the pattern, or None.
    """

    # 'str.split' never returns an empty list, so check the pattern itself.
    if not pattern:
        msg = f"The pattern('{pattern}') is illegal. Please check and retry."
        logger.warning(msg)
        return None

    layer_list = []
    for segment in pattern.split("."):
        if '[' in segment:
            name_and_index = segment.split('[')
            target_layer_name = name_and_index[0]
            target_layer_index = name_and_index[1].split(']')[0]
        else:
            target_layer_name = segment
            target_layer_index = None

        target_layer = getattr(parent_layer, target_layer_name, None)

        if target_layer is None:
            msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')."
            logger.warning(msg)
            return None

        if target_layer_index and target_layer:
            # A non-numeric index is reported as a parse failure, not raised.
            try:
                index_value = int(target_layer_index)
            except ValueError:
                msg = f"The index('{target_layer_index}') specifed in pattern('{pattern}') is not an integer."
                logger.warning(msg)
                return None

            if not 0 <= index_value < len(target_layer):
                msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and >= 0."
                logger.warning(msg)
                return None

            # The container is indexed with the string form, as written in the pattern.
            target_layer = target_layer[target_layer_index]

        layer_list.append({
            "layer": target_layer,
            "name": target_layer_name,
            "index": target_layer_index
        })

        # The next segment is resolved relative to the layer just found.
        parent_layer = target_layer
    return layer_list