提交 6963c258 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix bug

bug is that error rasied by calling update_res() when sublayer does not have _save_sub_res_hook() because nn.Sequential and nn.LayerList is not TheseusLayer.
上级 a6a76f3d
......@@ -13,6 +13,7 @@
# limitations under the License.
from typing import Tuple, List, Dict, Union, Callable, Any
from paddle import nn
from ppcls.utils import logger
......@@ -41,9 +42,6 @@ class TheseusLayer(nn.Layer):
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict
def _save_sub_res_hook(self, layer, input, output):
self.res_dict[self.res_name] = output
def init_res(self,
stages_pattern,
return_patterns=None,
......@@ -55,7 +53,8 @@ class TheseusLayer(nn.Layer):
if return_stages is True:
return_patterns = stages_pattern
if isinstance(return_stages, int):
# return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages]
if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min(
......@@ -187,7 +186,7 @@ class TheseusLayer(nn.Layer):
if hasattr(layer, "hook_remove_helper"):
layer.hook_remove_helper.remove()
layer.hook_remove_helper = layer.register_forward_post_hook(
layer._save_sub_res_hook)
save_sub_res_hook)
return layer
handle_func = Handler(self.res_dict)
......@@ -203,6 +202,10 @@ class TheseusLayer(nn.Layer):
return hit_layer_pattern_list
def save_sub_res_hook(layer, input, output):
layer.res_dict[layer.res_name] = output
def set_identity(parent_layer: nn.Layer,
layer_name: str,
layer_index: str=None) -> bool:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册