提交 43fce425 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: change the returned result from dict to list

上级 3b9f6292
...@@ -112,7 +112,7 @@ class TheseusLayer(nn.Layer): ...@@ -112,7 +112,7 @@ class TheseusLayer(nn.Layer):
if not isinstance(layer_name_pattern, list): if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern] layer_name_pattern = [layer_name_pattern]
handle_res_dict = {} hit_layer_pattern_list = []
for pattern in layer_name_pattern: for pattern in layer_name_pattern:
# parse pattern to find target layer and its parent # parse pattern to find target layer and its parent
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self) layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
...@@ -133,8 +133,8 @@ class TheseusLayer(nn.Layer): ...@@ -133,8 +133,8 @@ class TheseusLayer(nn.Layer):
else: else:
setattr(sub_layer_parent, sub_layer_name, new_sub_layer) setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
handle_res_dict[pattern] = new_sub_layer hit_layer_pattern_list.append(pattern)
return handle_res_dict return hit_layer_pattern_list
def stop_after(self, stop_layer_name: str) -> bool: def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'. """stop forward and backward after 'stop_layer_name'.
...@@ -192,7 +192,7 @@ class TheseusLayer(nn.Layer): ...@@ -192,7 +192,7 @@ class TheseusLayer(nn.Layer):
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)
res_dict = self.upgrade_sublayer( hit_layer_pattern_list = self.upgrade_sublayer(
return_patterns, handle_func=handle_func) return_patterns, handle_func=handle_func)
if hasattr(self, "hook_remove_helper"): if hasattr(self, "hook_remove_helper"):
...@@ -200,7 +200,7 @@ class TheseusLayer(nn.Layer): ...@@ -200,7 +200,7 @@ class TheseusLayer(nn.Layer):
self.hook_remove_helper = self.register_forward_post_hook( self.hook_remove_helper = self.register_forward_post_hook(
self._return_dict_hook) self._return_dict_hook)
return res_dict return hit_layer_pattern_list
def set_identity(parent_layer: nn.Layer, def set_identity(parent_layer: nn.Layer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册