diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 12a7a2729dfc682a5a58d96207b7e8a0f3ab4002..40f5d317140036f21c265681b0f64a4e687b0bfb 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Union, Callable, Any +from typing import Tuple, List, Dict, Union, Callable, Any from paddle import nn from ppcls.utils import logger @@ -61,23 +61,28 @@ class TheseusLayer(nn.Layer): print(res) # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True} """ + if not isinstance(layer_name_pattern, list): layer_name_pattern = [layer_name_pattern] handle_res_dict = {} for pattern in layer_name_pattern: - # pattern_list = pattern.split(".") - # find parent layer of sub-layer specified by pattern - sub_layer_parent, _, _ = parse_pattern_str( - pattern=pattern, idx=(0, -1), sub_layer_parent=self) + sub_layer_parent = None + for target_layer_dict in parse_pattern_str( + pattern=pattern, idx=(0, -1), parent_layer=self): + sub_layer_parent = target_layer_dict["target_layer"] if not sub_layer_parent: continue # find sub-layer specified by pattern - sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str( - pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent) + sub_layer = None + for target_layer_dict in parse_pattern_str( + pattern=pattern, idx=-1, parent_layer=sub_layer_parent): + sub_layer = target_layer_dict["target_layer"] + sub_layer_name = target_layer_dict["target_layer_name"] + sub_layer_index = target_layer_dict["target_layer_index"] if not sub_layer: continue @@ -93,26 +98,6 @@ class TheseusLayer(nn.Layer): handle_res_dict[pattern] = new_sub_layer return handle_res_dict - def _set_identity(self, layer, layer_name, layer_index=None): - stop_after = False - for sub_layer_name in layer._sub_layers: - if stop_after: - layer._sub_layers[sub_layer_name] = Identity() - continue - if sub_layer_name == layer_name: - stop_after = True - - if layer_index and stop_after: - stop_after = False - for sub_layer_index in layer._sub_layers[layer_name]._sub_layers: - if stop_after: - layer._sub_layers[layer_name][sub_layer_index] = Identity() - continue - if layer_index == sub_layer_index: - stop_after = True - - return stop_after - def stop_after(self, stop_layer_name: str) -> bool: """stop forward and backward after 'stop_layer_name'. @@ -122,32 +107,18 @@ class TheseusLayer(nn.Layer): Returns: bool: 'True' if successful, 'False' otherwise. """ - pattern_list = stop_layer_name.split(".") - to_identity_list = [] - # TODO(gaotingquan): replace code by self._parse_pattern_str() - layer = self - while len(pattern_list) > 0: - layer_parent = layer - if '[' in pattern_list[0]: - sub_layer_name = pattern_list[0].split('[')[0] - sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] - layer = getattr(layer, sub_layer_name)[sub_layer_index] - else: - sub_layer_name = pattern_list[0] - sub_layer_index = None - layer = getattr(layer, sub_layer_name, None) - if layer is None: - msg = f"Not found layer by name({pattern_list[0]}) specifed in stop_layer_name({stop_layer_name})." - logger.warning(msg) - return False + to_identity_list = [] + for target_layer_dict in parse_pattern_str(stop_layer_name, self): + sub_layer_name = target_layer_dict["target_layer_name"] + sub_layer_index = target_layer_dict["target_layer_index"] + parent_layer = target_layer_dict["parent_layer"] to_identity_list.append( - (layer_parent, sub_layer_name, sub_layer_index)) - pattern_list = pattern_list[1:] + (parent_layer, sub_layer_name, sub_layer_index)) for to_identity_layer in to_identity_list: - if not self._set_identity(*to_identity_layer): + if not set_identity(*to_identity_layer): msg = "Failed to set the layers that after stop_layer_name to IdentityLayer." logger.warning(msg) return False @@ -198,58 +169,135 @@ def unwrap_theseus(sub_layer): return sub_layer -def slice_pattern(pattern, idx): +def set_identity(parent_layer: nn.Layer, + layer_name: str, + layer_index: str=None) -> bool: + """set the layer specified by layer_name and layer_index to Indentity. + + Args: + parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index. + layer_name (str): The name of target layer to be set to Indentity. + layer_index (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None. + + Returns: + bool: True if successfully, False otherwise. + """ + + stop_after = False + for sub_layer_name in parent_layer._sub_layers: + if stop_after: + parent_layer._sub_layers[sub_layer_name] = Identity() + continue + if sub_layer_name == layer_name: + stop_after = True + + if layer_index and stop_after: + stop_after = False + for sub_layer_index in parent_layer._sub_layers[ + layer_name]._sub_layers: + if stop_after: + parent_layer._sub_layers[layer_name][ + sub_layer_index] = Identity() + continue + if layer_index == sub_layer_index: + stop_after = True + + return stop_after + + +def slice_pattern(pattern: str, idx: Union[Tuple, int]=None) -> List: + """slice the string type "pattern" to list type by separator ".". + + Args: + pattern (str): The pattern to discribe layer name. + idx (Union[Tuple, int], optional): The index(s) of sub-list of list sliced. Defaults to None. + + Returns: + List: The sub-list of list sliced by "pattern". + """ + pattern_list = pattern.split(".") if idx: - if isinstance(idx, tuple): + if isinstance(idx, Tuple): if len(idx) == 1: return pattern_list[idx[0]] elif len(idx) == 2: return pattern_list[idx[0]:idx[1]] else: - msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a tuple." + msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a Tuple." logger.warning(msg) return None elif isinstance(idx, int): return [pattern_list[idx]] else: - msg = f"Only support type of 'idx' is int or tuple." + msg = f"Only support type of 'idx' is int or Tuple." logger.warning(msg) return None return pattern_list -def parse_pattern_str(pattern, sub_layer_parent, idx=None): +def parse_pattern_str(pattern: str, parent_layer: nn.Layer, + idx=None) -> Dict[str, Union[nn.Layer, None, str]]: + """parse the string type pattern. + + Args: + pattern (str): The pattern to discribe layer name. + parent_layer (nn.Layer): The parent layer of target layer(s) specified by "pattern". + idx ([type], optional): [description]. The index(s) of sub-list of list sliced. Defaults to None. + + Returns: + Dict[str, Union[nn.Layer, None, str]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer] + + Yields: + Iterator[Dict[str, Union[nn.Layer, None, str]]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer] + """ + pattern_list = slice_pattern(pattern, idx) if not pattern_list: return None, None, None while len(pattern_list) > 0: if '[' in pattern_list[0]: - sub_layer_name = pattern_list[0].split('[')[0] - sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] + target_layer_name = pattern_list[0].split('[')[0] + target_layer_index = pattern_list[0].split('[')[1].split(']')[0] else: - sub_layer_name = pattern_list[0] - sub_layer_index = None + target_layer_name = pattern_list[0] + target_layer_index = None - sub_layer_parent = getattr(sub_layer_parent, sub_layer_name, None) - sub_layer_parent = unwrap_theseus(sub_layer_parent) + target_layer = getattr(parent_layer, target_layer_name, None) + target_layer = unwrap_theseus(target_layer) - if sub_layer_parent is None: - msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})." + if target_layer is None: + msg = f"Not found layer named({target_layer_name}) specifed in pattern({pattern})." logger.warning(msg) - return None, sub_layer_name, sub_layer_index - - if sub_layer_index and sub_layer_parent: - if int(sub_layer_index) < 0 or int(sub_layer_index) >= len( - sub_layer_parent): - msg = f"Not found layer by index({sub_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(sub_layer_parent)}' and > '0'." + return { + "target_layer": None, + "target_layer_name": target_layer_name, + "target_layer_index": target_layer_index, + "parent_layer": parent_layer + } + + if target_layer_index and target_layer: + if int(target_layer_index) < 0 or int(target_layer_index) >= len( + target_layer): + msg = f"Not found layer by index({target_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(parent_layer)}' and > '0'." logger.warning(msg) - return None, sub_layer_name, sub_layer_index - sub_layer_parent = sub_layer_parent[sub_layer_index] - sub_layer_parent = unwrap_theseus(sub_layer_parent) + return { + "target_layer": None, + "target_layer_name": target_layer_name, + "target_layer_index": target_layer_index, + "parent_layer": parent_layer + } + target_layer = target_layer[target_layer_index] + target_layer = unwrap_theseus(target_layer) + + yield { + "target_layer": target_layer, + "target_layer_name": target_layer_name, + "target_layer_index": target_layer_index, + "parent_layer": parent_layer + } pattern_list = pattern_list[1:] - - return sub_layer_parent, sub_layer_name, sub_layer_index + parent_layer = target_layer