提交 6963c258 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix bug

bug is that error rasied by calling update_res() when sublayer does not have _save_sub_res_hook() because nn.Sequential and nn.LayerList is not TheseusLayer.
上级 a6a76f3d
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Tuple, List, Dict, Union, Callable, Any from typing import Tuple, List, Dict, Union, Callable, Any
from paddle import nn from paddle import nn
from ppcls.utils import logger from ppcls.utils import logger
...@@ -41,9 +42,6 @@ class TheseusLayer(nn.Layer): ...@@ -41,9 +42,6 @@ class TheseusLayer(nn.Layer):
res_dict[res_key] = self.res_dict.pop(res_key) res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict return res_dict
def _save_sub_res_hook(self, layer, input, output):
self.res_dict[self.res_name] = output
def init_res(self, def init_res(self,
stages_pattern, stages_pattern,
return_patterns=None, return_patterns=None,
...@@ -55,7 +53,8 @@ class TheseusLayer(nn.Layer): ...@@ -55,7 +53,8 @@ class TheseusLayer(nn.Layer):
if return_stages is True: if return_stages is True:
return_patterns = stages_pattern return_patterns = stages_pattern
if isinstance(return_stages, int): # return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages] return_stages = [return_stages]
if isinstance(return_stages, list): if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min( if max(return_stages) > len(stages_pattern) or min(
...@@ -187,7 +186,7 @@ class TheseusLayer(nn.Layer): ...@@ -187,7 +186,7 @@ class TheseusLayer(nn.Layer):
if hasattr(layer, "hook_remove_helper"): if hasattr(layer, "hook_remove_helper"):
layer.hook_remove_helper.remove() layer.hook_remove_helper.remove()
layer.hook_remove_helper = layer.register_forward_post_hook( layer.hook_remove_helper = layer.register_forward_post_hook(
layer._save_sub_res_hook) save_sub_res_hook)
return layer return layer
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)
...@@ -203,6 +202,10 @@ class TheseusLayer(nn.Layer): ...@@ -203,6 +202,10 @@ class TheseusLayer(nn.Layer):
return hit_layer_pattern_list return hit_layer_pattern_list
def save_sub_res_hook(layer, input, output):
layer.res_dict[layer.res_name] = output
def set_identity(parent_layer: nn.Layer, def set_identity(parent_layer: nn.Layer,
layer_name: str, layer_name: str,
layer_index: str=None) -> bool: layer_index: str=None) -> bool:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册