提交 8c1515b5 编写于 作者: W weishengyu

add hook

上级 0040a973
...@@ -15,6 +15,7 @@ class TheseusLayer(nn.Layer): ...@@ -15,6 +15,7 @@ class TheseusLayer(nn.Layer):
def __init__(self, *args, return_patterns=None, **kwargs): def __init__(self, *args, return_patterns=None, **kwargs):
super(TheseusLayer, self).__init__() super(TheseusLayer, self).__init__()
self.res_dict = None self.res_dict = None
self.register_forward_post_hook(self._disconnect_res_dict_hook)
if return_patterns is not None: if return_patterns is not None:
self._update_res(return_patterns) self._update_res(return_patterns)
...@@ -47,6 +48,7 @@ class TheseusLayer(nn.Layer): ...@@ -47,6 +48,7 @@ class TheseusLayer(nn.Layer):
self._save_sub_res_hook) self._save_sub_res_hook)
def _save_sub_res_hook(self, layer, input, output): def _save_sub_res_hook(self, layer, input, output):
if self.res_dict is not None:
self.res_dict[layer.full_name()] = output self.res_dict[layer.full_name()] = output
def _disconnect_res_dict_hook(self, input, output): def _disconnect_res_dict_hook(self, input, output):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册