提交 721ac0bf 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: simplify code

1. remove WrapLayer and wrap_theseus;
2. support calling update_res() more than once;
3. optimize parse_pattern_str() to return the list of parsed layers.
上级 56911b57
...@@ -23,6 +23,7 @@ class TheseusLayer(nn.Layer): ...@@ -23,6 +23,7 @@ class TheseusLayer(nn.Layer):
res_dict = {"output": output} res_dict = {"output": output}
# 'list' is needed to avoid error raised by popping self.res_dict # 'list' is needed to avoid error raised by popping self.res_dict
for res_key in list(self.res_dict): for res_key in list(self.res_dict):
# clear the res_dict because the forward process may change according to input
res_dict[res_key] = self.res_dict.pop(res_key) res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict return res_dict
...@@ -30,7 +31,7 @@ class TheseusLayer(nn.Layer): ...@@ -30,7 +31,7 @@ class TheseusLayer(nn.Layer):
self.res_dict[self.res_name] = output self.res_dict[self.res_name] = output
def replace_sub(self, *args, **kwargs) -> None: def replace_sub(self, *args, **kwargs) -> None:
msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead." msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead."
logger.error(DeprecationWarning(msg)) logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg) raise DeprecationWarning(msg)
...@@ -43,20 +44,20 @@ class TheseusLayer(nn.Layer): ...@@ -43,20 +44,20 @@ class TheseusLayer(nn.Layer):
Args: Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'. layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
Returns: Returns:
Dict[str, nn.Layer]: The key is the patter and corresponding value is the result returned by 'handle_func'. Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.
Examples: Examples:
from paddle import nn from paddle import nn
import paddleclas import paddleclas
def rep_func(sub_layer: nn.Layer, pattern: str): def rep_func(layer: nn.Layer, pattern: str):
new_layer = nn.Conv2D( new_layer = nn.Conv2D(
in_channels=sub_layer._in_channels, in_channels=layer._in_channels,
out_channels=sub_layer._out_channels, out_channels=layer._out_channels,
kernel_size=5, kernel_size=5,
padding=2 padding=2
) )
...@@ -73,25 +74,16 @@ class TheseusLayer(nn.Layer): ...@@ -73,25 +74,16 @@ class TheseusLayer(nn.Layer):
handle_res_dict = {} handle_res_dict = {}
for pattern in layer_name_pattern: for pattern in layer_name_pattern:
# find parent layer of sub-layer specified by pattern # parse pattern to find target layer and its parent
sub_layer_parent = None layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
for target_layer_dict in parse_pattern_str( if not layer_list:
pattern=pattern, idx=(0, -1), parent_layer=self):
sub_layer_parent = target_layer_dict["target_layer"]
if not sub_layer_parent:
continue continue
sub_layer_parent = layer_list[-2]["layer"] if len(
layer_list) > 1 else self
# find sub-layer specified by pattern sub_layer = layer_list[-1]["layer"]
sub_layer = None sub_layer_name = layer_list[-1]["name"]
for target_layer_dict in parse_pattern_str( sub_layer_index = layer_list[-1]["index"]
pattern=pattern, idx=-1, parent_layer=sub_layer_parent):
sub_layer = target_layer_dict["target_layer"]
sub_layer_name = target_layer_dict["target_layer_name"]
sub_layer_index = target_layer_dict["target_layer_index"]
if not sub_layer:
continue
new_sub_layer = handle_func(sub_layer, pattern) new_sub_layer = handle_func(sub_layer, pattern)
...@@ -114,65 +106,60 @@ class TheseusLayer(nn.Layer): ...@@ -114,65 +106,60 @@ class TheseusLayer(nn.Layer):
bool: 'True' if successful, 'False' otherwise. bool: 'True' if successful, 'False' otherwise.
""" """
to_identity_list = [] layer_list = parse_pattern_str(stop_layer_name, self)
if not layer_list:
for target_layer_dict in parse_pattern_str(stop_layer_name, self): return False
sub_layer_name = target_layer_dict["target_layer_name"]
sub_layer_index = target_layer_dict["target_layer_index"]
parent_layer = target_layer_dict["parent_layer"]
to_identity_list.append(
(parent_layer, sub_layer_name, sub_layer_index))
for to_identity_layer in to_identity_list: parent_layer = self
if not set_identity(*to_identity_layer): for layer_dict in layer_list:
msg = "Failed to set the layers that after stop_layer_name to IdentityLayer." name, index = layer_dict["name"], layer_dict["index"]
if not set_identity(parent_layer, name, index):
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
logger.warning(msg) logger.warning(msg)
return False return False
parent_layer = layer_dict["layer"]
return True return True
def update_res(self, def update_res(
return_patterns: Union[str, List[str]]) -> Dict[str, bool]: self,
"""update the results to be returned. return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]:
"""update the result(s) to be returned.
Args: Args:
return_patterns (Union[str, List[str]]): The name of layer to return output. return_patterns (Union[str, List[str]]): The name of layer to return output.
Returns: Returns:
Dict[str, bool]: The pattern(str) is be set successfully if 'True'(bool), failed if 'False'(bool). Dict[str, nn.Layer]: The pattern(str) and corresponding layer(nn.Layer) that have been set successfully.
""" """
# clear res_dict that could have been set
self.res_dict = {}
class Handler(object): class Handler(object):
def __init__(self, res_dict): def __init__(self, res_dict):
# res_dict is a reference
self.res_dict = res_dict self.res_dict = res_dict
def __call__(self, layer, pattern): def __call__(self, layer, pattern):
layer.res_dict = self.res_dict layer.res_dict = self.res_dict
layer.res_name = pattern layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook) if hasattr(layer, "hook_remove_helper"):
layer.hook_remove_helper.remove()
layer.hook_remove_helper = layer.register_forward_post_hook(
layer._save_sub_res_hook)
return layer return layer
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)
return self.replace_sub(return_patterns, handle_func=handle_func) res_dict = self.layer_wrench(return_patterns, handle_func=handle_func)
class WrapLayer(TheseusLayer): if hasattr(self, "hook_remove_helper"):
def __init__(self, sub_layer): self.hook_remove_helper.remove()
super(WrapLayer, self).__init__() self.hook_remove_helper = self.register_forward_post_hook(
self.sub_layer = sub_layer self._return_dict_hook)
def forward(self, *inputs, **kwargs): return res_dict
return self.sub_layer(*inputs, **kwargs)
def wrap_theseus(sub_layer):
return WrapLayer(sub_layer)
def unwrap_theseus(sub_layer):
if isinstance(sub_layer, WrapLayer):
sub_layer = sub_layer.sub_layer
return sub_layer
def set_identity(parent_layer: nn.Layer, def set_identity(parent_layer: nn.Layer,
...@@ -211,58 +198,30 @@ def set_identity(parent_layer: nn.Layer, ...@@ -211,58 +198,30 @@ def set_identity(parent_layer: nn.Layer,
return stop_after return stop_after
def slice_pattern(pattern: str, idx: Union[Tuple, int]=None) -> List: def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
"""slice the string type "pattern" to list type by separator ".". None, List[Dict[str, Union[nn.Layer, str, None]]]]:
Args:
pattern (str): The pattern to discribe layer name.
idx (Union[Tuple, int], optional): The index(s) of sub-list of list sliced. Defaults to None.
Returns:
List: The sub-list of list sliced by "pattern".
"""
pattern_list = pattern.split(".")
if idx:
if isinstance(idx, Tuple):
if len(idx) == 1:
return pattern_list[idx[0]]
elif len(idx) == 2:
return pattern_list[idx[0]:idx[1]]
else:
msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a Tuple."
logger.warning(msg)
return None
elif isinstance(idx, int):
return [pattern_list[idx]]
else:
msg = f"Only support type of 'idx' is int or Tuple."
logger.warning(msg)
return None
return pattern_list
def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
idx=None) -> Dict[str, Union[nn.Layer, None, str]]:
"""parse the string type pattern. """parse the string type pattern.
Args: Args:
pattern (str): The pattern to discribe layer name. pattern (str): The pattern to discribe layer.
parent_layer (nn.Layer): The parent layer of target layer(s) specified by "pattern". parent_layer (nn.Layer): The root layer relative to the pattern.
idx ([type], optional): [description]. The index(s) of sub-list of list sliced. Defaults to None.
Returns: Returns:
Dict[str, Union[nn.Layer, None, str]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer] Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
[
Yields: {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
Iterator[Dict[str, Union[nn.Layer, None, str]]]: Dict["target_layer": Union[nn.Layer, None], "target_layer_name": str, "target_layer_index": str, "parent_layer": nn.Layer] {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
...
]
""" """
pattern_list = slice_pattern(pattern, idx) pattern_list = pattern.split(".")
if not pattern_list: if not pattern_list:
return None, None, None msg = f"The pattern('{pattern}') is illegal. Please check and retry."
logger.warning(msg)
return None
layer_list = []
while len(pattern_list) > 0: while len(pattern_list) > 0:
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
target_layer_name = pattern_list[0].split('[')[0] target_layer_name = pattern_list[0].split('[')[0]
...@@ -272,38 +231,27 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer, ...@@ -272,38 +231,27 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer,
target_layer_index = None target_layer_index = None
target_layer = getattr(parent_layer, target_layer_name, None) target_layer = getattr(parent_layer, target_layer_name, None)
target_layer = unwrap_theseus(target_layer)
if target_layer is None: if target_layer is None:
msg = f"Not found layer named({target_layer_name}) specifed in pattern({pattern})." msg = f"Not found layer named('{target_layer_name}') specifed in pattern('{pattern}')."
logger.warning(msg) logger.warning(msg)
return { return None
"target_layer": None,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
if target_layer_index and target_layer: if target_layer_index and target_layer:
if int(target_layer_index) < 0 or int(target_layer_index) >= len( if int(target_layer_index) < 0 or int(target_layer_index) >= len(
target_layer): target_layer):
msg = f"Not found layer by index({target_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(parent_layer)}' and > '0'." msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
logger.warning(msg) logger.warning(msg)
return { return None
"target_layer": None,
"target_layer_name": target_layer_name,
"target_layer_index": target_layer_index,
"parent_layer": parent_layer
}
target_layer = target_layer[target_layer_index] target_layer = target_layer[target_layer_index]
target_layer = unwrap_theseus(target_layer)
yield { layer_list.append({
"target_layer": target_layer, "layer": target_layer,
"target_layer_name": target_layer_name, "name": target_layer_name,
"target_layer_index": target_layer_index, "index": target_layer_index
"parent_layer": parent_layer })
}
pattern_list = pattern_list[1:] pattern_list = pattern_list[1:]
parent_layer = target_layer parent_layer = target_layer
return layer_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册