提交 43fce425 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: change the returned result from dict to list

上级 3b9f6292
......@@ -112,7 +112,7 @@ class TheseusLayer(nn.Layer):
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
handle_res_dict = {}
hit_layer_pattern_list = []
for pattern in layer_name_pattern:
# parse pattern to find target layer and its parent
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
......@@ -133,8 +133,8 @@ class TheseusLayer(nn.Layer):
else:
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
handle_res_dict[pattern] = new_sub_layer
return handle_res_dict
hit_layer_pattern_list.append(pattern)
return hit_layer_pattern_list
def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'.
......@@ -192,7 +192,7 @@ class TheseusLayer(nn.Layer):
handle_func = Handler(self.res_dict)
res_dict = self.upgrade_sublayer(
hit_layer_pattern_list = self.upgrade_sublayer(
return_patterns, handle_func=handle_func)
if hasattr(self, "hook_remove_helper"):
......@@ -200,7 +200,7 @@ class TheseusLayer(nn.Layer):
self.hook_remove_helper = self.register_forward_post_hook(
self._return_dict_hook)
return res_dict
return hit_layer_pattern_list
def set_identity(parent_layer: nn.Layer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册