提交 f8ee6c0f 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix result returned by stop_after

上级 cf205e13
from typing import List, Union, Callable, Any from typing import List, Dict, Union, Callable, Any
from paddle import nn from paddle import nn
from ppcls.utils import logger from ppcls.utils import logger
...@@ -19,7 +19,6 @@ class TheseusLayer(nn.Layer): ...@@ -19,7 +19,6 @@ class TheseusLayer(nn.Layer):
self.pruner = None self.pruner = None
self.quanter = None self.quanter = None
# TODO(gaotingquan): weishengyu
def _return_dict_hook(self, layer, input, output): def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output} res_dict = {"output": output}
for res_key in list(self.res_dict): for res_key in list(self.res_dict):
...@@ -54,7 +53,7 @@ class TheseusLayer(nn.Layer): ...@@ -54,7 +53,7 @@ class TheseusLayer(nn.Layer):
sub_layer_parent = sub_layer_parent.sub_layer sub_layer_parent = sub_layer_parent.sub_layer
pattern_list = pattern_list[1:] pattern_list = pattern_list[1:]
if sub_layer_parent is None: if sub_layer_parent is None:
msg = f"Not found layer by name({pattern_list[0]}) specifed in pattern({pattern})." msg = f"Not found parent layer of sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg) logger.warning(msg)
continue continue
...@@ -62,17 +61,33 @@ class TheseusLayer(nn.Layer): ...@@ -62,17 +61,33 @@ class TheseusLayer(nn.Layer):
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
sub_layer = getattr(sub_layer_parent, else:
sub_layer_name)[sub_layer_index] sub_layer_name = pattern_list[0]
sub_layer_index = None
sub_layer = getattr(sub_layer_parent, sub_layer_name, False)
if sub_layer is False:
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue
try:
sub_layer = sub_layer[
sub_layer_index] if sub_layer_index is not None else sub_layer
except KeyError as e:
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
logger.warning(msg)
continue
if not isinstance(sub_layer, TheseusLayer): if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer) sub_layer = wrap_theseus(sub_layer)
if sub_layer_index:
getattr(sub_layer_parent, getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = sub_layer sub_layer_name)[sub_layer_index] = sub_layer
else: else:
sub_layer = getattr(sub_layer_parent, pattern_list[0]) setattr(sub_layer_parent, sub_layer_name, sub_layer)
if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer)
setattr(sub_layer_parent, pattern_list[0], sub_layer)
handle_res = handle_func(sub_layer, pattern) handle_res = handle_func(sub_layer, pattern)
handle_res_dict[pattern] = handle_res handle_res_dict[pattern] = handle_res
...@@ -136,7 +151,7 @@ class TheseusLayer(nn.Layer): ...@@ -136,7 +151,7 @@ class TheseusLayer(nn.Layer):
return self._find_layers_handle( return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function) layer_name_pattern, handle_func=replace_function)
# stop doesn't work when stop layer has a parallel branch. # TODO(weishengyu): stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool: def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'. """stop forward and backward after 'stop_layer_name'.
...@@ -176,14 +191,15 @@ class TheseusLayer(nn.Layer): ...@@ -176,14 +191,15 @@ class TheseusLayer(nn.Layer):
return False return False
return True return True
def update_res(self, return_patterns: Union[str, List[str]]) -> bool: def update_res(self,
return_patterns: Union[str, List[str]]) -> Dict[str, bool]:
"""update the results needed returned. """update the results needed returned.
Args: Args:
return_patterns (Union[str, List[str]]): The layer(s)' name to be retruened. return_patterns (Union[str, List[str]]): [description]
Returns: Returns:
bool: 'True' if successful, 'False' otherwise. Dict[str, bool]: The pattern(str) is be set successfully if True(bool), failed otherwise.
""" """
class Handler(object): class Handler(object):
...@@ -194,6 +210,7 @@ class TheseusLayer(nn.Layer): ...@@ -194,6 +210,7 @@ class TheseusLayer(nn.Layer):
layer.res_dict = self.res_dict layer.res_dict = self.res_dict
layer.res_name = pattern layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook) layer.register_forward_post_hook(layer._save_sub_res_hook)
return True
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册