From 8c1515b5db24f7d7ef2f1937bf041e3c4ee3ef56 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Sun, 8 Aug 2021 00:37:53 +0800 Subject: [PATCH] add hook --- ppcls/arch/backbone/base/theseus_layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 28620820..7dbc3dbc 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -15,6 +15,7 @@ class TheseusLayer(nn.Layer): def __init__(self, *args, return_patterns=None, **kwargs): super(TheseusLayer, self).__init__() self.res_dict = None + self.register_forward_post_hook(self._disconnect_res_dict_hook) if return_patterns is not None: self._update_res(return_patterns) @@ -47,7 +48,8 @@ class TheseusLayer(nn.Layer): self._save_sub_res_hook) def _save_sub_res_hook(self, layer, input, output): - self.res_dict[layer.full_name()] = output + if self.res_dict is not None: + self.res_dict[layer.full_name()] = output def _disconnect_res_dict_hook(self, input, output): self.res_dict = None -- GitLab