提交 0f126b75 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: strengthen parse_pattern_str() func

上级 8d0b0d4b
from typing import List, Dict, Union, Callable, Any from typing import Tuple, List, Dict, Union, Callable, Any
from paddle import nn from paddle import nn
from ppcls.utils import logger from ppcls.utils import logger
...@@ -61,23 +61,28 @@ class TheseusLayer(nn.Layer): ...@@ -61,23 +61,28 @@ class TheseusLayer(nn.Layer):
print(res) print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True} # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
""" """
if not isinstance(layer_name_pattern, list): if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern] layer_name_pattern = [layer_name_pattern]
handle_res_dict = {} handle_res_dict = {}
for pattern in layer_name_pattern: for pattern in layer_name_pattern:
# pattern_list = pattern.split(".")
# find parent layer of sub-layer specified by pattern # find parent layer of sub-layer specified by pattern
sub_layer_parent, _, _ = parse_pattern_str( sub_layer_parent = None
pattern=pattern, idx=(0, -1), sub_layer_parent=self) for target_layer_dict in parse_pattern_str(
pattern=pattern, idx=(0, -1), parent_layer=self):
sub_layer_parent = target_layer_dict["target_layer"]
if not sub_layer_parent: if not sub_layer_parent:
continue continue
# find sub-layer specified by pattern # find sub-layer specified by pattern
sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str( sub_layer = None
pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent) for target_layer_dict in parse_pattern_str(
pattern=pattern, idx=-1, parent_layer=sub_layer_parent):
sub_layer = target_layer_dict["target_layer"]
sub_layer_name = target_layer_dict["target_layer_name"]
sub_layer_index = target_layer_dict["target_layer_index"]
if not sub_layer: if not sub_layer:
continue continue
...@@ -93,26 +98,6 @@ class TheseusLayer(nn.Layer): ...@@ -93,26 +98,6 @@ class TheseusLayer(nn.Layer):
handle_res_dict[pattern] = new_sub_layer handle_res_dict[pattern] = new_sub_layer
return handle_res_dict return handle_res_dict
def _set_identity(self, layer, layer_name, layer_index=None):
    """Replace the sub-layers that come after the target layer with Identity.

    Args:
        layer: Parent layer whose ``_sub_layers`` mapping is walked.
        layer_name: Key in ``layer._sub_layers`` naming the target sub-layer.
        layer_index: Optional key inside the target's own ``_sub_layers``;
            when given, the entries that follow it in that container are
            replaced as well.

    Returns:
        bool: True when the target (name, plus index if one was given) was
        found, False otherwise.
    """
    # NOTE(review): nesting was reconstructed from a paste that lost its
    # indentation; the index pass is assumed to run after the name pass.
    reached = False
    for name in layer._sub_layers:
        if reached:
            # Everything after the target is swapped out, not the target itself.
            layer._sub_layers[name] = Identity()
        elif name == layer_name:
            reached = True

    if not (layer_index and reached):
        return reached

    # Second pass: repeat the same scan inside the named container.
    reached = False
    for index in layer._sub_layers[layer_name]._sub_layers:
        if reached:
            layer._sub_layers[layer_name][index] = Identity()
        elif layer_index == index:
            reached = True
    return reached
def stop_after(self, stop_layer_name: str) -> bool: def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'. """stop forward and backward after 'stop_layer_name'.
...@@ -122,32 +107,18 @@ class TheseusLayer(nn.Layer): ...@@ -122,32 +107,18 @@ class TheseusLayer(nn.Layer):
Returns: Returns:
bool: 'True' if successful, 'False' otherwise. bool: 'True' if successful, 'False' otherwise.
""" """
pattern_list = stop_layer_name.split(".")
to_identity_list = []
# TODO(gaotingquan): replace code by self._parse_pattern_str() to_identity_list = []
layer = self
while len(pattern_list) > 0:
layer_parent = layer
if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
layer = getattr(layer, sub_layer_name)[sub_layer_index]
else:
sub_layer_name = pattern_list[0]
sub_layer_index = None
layer = getattr(layer, sub_layer_name, None)
if layer is None:
msg = f"Not found layer by name({pattern_list[0]}) specifed in stop_layer_name({stop_layer_name})."
logger.warning(msg)
return False
for target_layer_dict in parse_pattern_str(stop_layer_name, self):
sub_layer_name = target_layer_dict["target_layer_name"]
sub_layer_index = target_layer_dict["target_layer_index"]
parent_layer = target_layer_dict["parent_layer"]
to_identity_list.append( to_identity_list.append(
(layer_parent, sub_layer_name, sub_layer_index)) (parent_layer, sub_layer_name, sub_layer_index))
pattern_list = pattern_list[1:]
for to_identity_layer in to_identity_list: for to_identity_layer in to_identity_list:
if not self._set_identity(*to_identity_layer): if not set_identity(*to_identity_layer):
msg = "Failed to set the layers that after stop_layer_name to IdentityLayer." msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
logger.warning(msg) logger.warning(msg)
return False return False
...@@ -198,58 +169,135 @@ def unwrap_theseus(sub_layer): ...@@ -198,58 +169,135 @@ def unwrap_theseus(sub_layer):
return sub_layer return sub_layer
def slice_pattern(pattern, idx): def set_identity(parent_layer: nn.Layer,
layer_name: str,
layer_index: str=None) -> bool:
"""set the layer specified by layer_name and layer_index to Indentity.
Args:
parent_layer (nn.Layer): The parent layer of the target layer specified by layer_name and layer_index.
layer_name (str): The name of the target layer; the sub-layers that follow it in parent_layer are set to Identity.
layer_index (str, optional): The index of the target layer inside the sub-layer named layer_name; the entries that follow it are set to Identity as well. Defaults to None.
Returns:
bool: True if successful, False otherwise.
"""
stop_after = False
for sub_layer_name in parent_layer._sub_layers:
if stop_after:
parent_layer._sub_layers[sub_layer_name] = Identity()
continue
if sub_layer_name == layer_name:
stop_after = True
if layer_index and stop_after:
stop_after = False
for sub_layer_index in parent_layer._sub_layers[
layer_name]._sub_layers:
if stop_after:
parent_layer._sub_layers[layer_name][
sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
return stop_after
def slice_pattern(pattern: str, idx: Union[Tuple, int]=None) -> List:
"""slice the string type "pattern" to list type by separator ".".
Args:
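pattern (str): The pattern describing the layer name, e.g. "blocks[11].depthwise_conv.conv".
idx (Union[Tuple, int], optional): The index, or (start, stop) pair, selecting part of the split list. Defaults to None.
Returns:
List: The list obtained by splitting "pattern" on ".", or the part of it selected by "idx".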
"""
pattern_list = pattern.split(".") pattern_list = pattern.split(".")
if idx: if idx:
if isinstance(idx, tuple): if isinstance(idx, Tuple):
if len(idx) == 1: if len(idx) == 1:
return pattern_list[idx[0]] return pattern_list[idx[0]]
elif len(idx) == 2: elif len(idx) == 2:
return pattern_list[idx[0]:idx[1]] return pattern_list[idx[0]:idx[1]]
else: else:
msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a tuple." msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a Tuple."
logger.warning(msg) logger.warning(msg)
return None return None
elif isinstance(idx, int): elif isinstance(idx, int):
return [pattern_list[idx]] return [pattern_list[idx]]
else: else:
msg = f"Only support type of 'idx' is int or tuple." msg = f"Only support type of 'idx' is int or Tuple."
logger.warning(msg) logger.warning(msg)
return None return None
return pattern_list return pattern_list
def parse_pattern_str(pattern, sub_layer_parent, idx=None): def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
idx=None) -> Dict[str, Union[nn.Layer, None, str]]:
"""parse the string type pattern.
Args:
pattern (str): The pattern describing the layer name.
parent_layer (nn.Layer): The parent layer of the target layer(s) specified by "pattern".
idx (Union[Tuple, int], optional): The index(es) used to select part of the list produced by splitting "pattern". Defaults to None.
Returns:
Dict[str, Union[nn.Layer, None, str]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
Yields:
Iterator[Dict[str, Union[nn.Layer, None, str]]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer]
"""
pattern_list = slice_pattern(pattern, idx) pattern_list = slice_pattern(pattern, idx)
if not pattern_list: if not pattern_list:
return None, None, None return None, None, None
while len(pattern_list) > 0: while len(pattern_list) > 0:
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] target_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] target_layer_index = pattern_list[0].split('[')[1].split(']')[0]
else: else:
sub_layer_name = pattern_list[0] target_layer_name = pattern_list[0]
sub_layer_index = None target_layer_index = None
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name, None) target_layer = getattr(parent_layer, target_layer_name, None)
sub_layer_parent = unwrap_theseus(sub_layer_parent) target_layer = unwrap_theseus(target_layer)
if sub_layer_parent is None: if target_layer is None:
msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})." msg = f"Not found layer named({target_layer_name}) specifed in pattern({pattern})."
logger.warning(msg) logger.warning(msg)
return None, sub_layer_name, sub_layer_index return {
"target_layer": None,
if sub_layer_index and sub_layer_parent: "target_layer_name": target_layer_name,
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len( "target_layer_index": target_layer_index,
sub_layer_parent): "parent_layer": parent_layer
msg = f"Not found layer by index({sub_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(sub_layer_parent)}' and > '0'." }
if target_layer_index and target_layer:
if int(target_layer_index) < 0 or int(target_layer_index) >= len(
target_layer):
msg = f"Not found layer by index({target_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(parent_layer)}' and > '0'."
logger.warning(msg) logger.warning(msg)
return None, sub_layer_name, sub_layer_index return {
sub_layer_parent = sub_layer_parent[sub_layer_index] "target_layer": None,
sub_layer_parent = unwrap_theseus(sub_layer_parent) "target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
target_layer = target_layer[target_layer_index]
target_layer = unwrap_theseus(target_layer)
yield {
"target_layer": target_layer,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
pattern_list = pattern_list[1:] pattern_list = pattern_list[1:]
parent_layer = target_layer
return sub_layer_parent, sub_layer_name, sub_layer_index
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册