提交 f8ee6c0f 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix result returned by stop_after

上级 cf205e13
from typing import List, Union, Callable, Any
from typing import List, Dict, Union, Callable, Any
from paddle import nn
from ppcls.utils import logger
......@@ -19,7 +19,6 @@ class TheseusLayer(nn.Layer):
self.pruner = None
self.quanter = None
# TODO(gaotingquan): weishengyu
def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output}
for res_key in list(self.res_dict):
......@@ -54,7 +53,7 @@ class TheseusLayer(nn.Layer):
sub_layer_parent = sub_layer_parent.sub_layer
pattern_list = pattern_list[1:]
if sub_layer_parent is None:
msg = f"Not found layer by name({pattern_list[0]}) specifed in pattern({pattern})."
msg = f"Not found parent layer of sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue
......@@ -62,17 +61,33 @@ class TheseusLayer(nn.Layer):
if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
sub_layer = getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index]
if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer)
else:
sub_layer_name = pattern_list[0]
sub_layer_index = None
sub_layer = getattr(sub_layer_parent, sub_layer_name, False)
if sub_layer is False:
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue
try:
sub_layer = sub_layer[
sub_layer_index] if sub_layer_index is not None else sub_layer
except KeyError as e:
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
logger.warning(msg)
continue
if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer)
if sub_layer_index:
getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = sub_layer
else:
sub_layer = getattr(sub_layer_parent, pattern_list[0])
if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer)
setattr(sub_layer_parent, pattern_list[0], sub_layer)
setattr(sub_layer_parent, sub_layer_name, sub_layer)
handle_res = handle_func(sub_layer, pattern)
handle_res_dict[pattern] = handle_res
......@@ -136,7 +151,7 @@ class TheseusLayer(nn.Layer):
return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function)
# stop doesn't work when stop layer has a parallel branch.
# TODO(weishengyu): stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'.
......@@ -176,14 +191,15 @@ class TheseusLayer(nn.Layer):
return False
return True
def update_res(self, return_patterns: Union[str, List[str]]) -> bool:
def update_res(self,
return_patterns: Union[str, List[str]]) -> Dict[str, bool]:
"""update the results needed returned.
Args:
return_patterns (Union[str, List[str]]): The layer(s)' name to be retruened.
return_patterns (Union[str, List[str]]): [description]
Returns:
bool: 'True' if successful, 'False' otherwise.
Dict[str, bool]: The pattern(str) is be set successfully if True(bool), failed otherwise.
"""
class Handler(object):
......@@ -194,6 +210,7 @@ class TheseusLayer(nn.Layer):
layer.res_dict = self.res_dict
layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook)
return True
handle_func = Handler(self.res_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册