diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 04871e48be179406312b1e0c5a1d36077c14193d..f328aed6f0a5109b9bb5e7e5ccc37438ff734d51 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -50,11 +50,11 @@ class TheseusLayer(nn.Layer): self._sub_layers[layer_i]._save_sub_res_hook) self._sub_layers[layer_i].res_dict = self.res_dict if isinstance(self._sub_layers[layer_i], TheseusLayer): + self._sub_layers[layer_i].res_dict = self.res_dict self._sub_layers[layer_i].update_res(return_patterns) def _save_sub_res_hook(self, layer, input, output): - if self.res_dict is not None: - self.res_dict[layer.full_name()] = output + self.res_dict[layer.full_name()] = output def replace_sub(self, layer_name_pattern, replace_function, recursive=True): for layer_i in self._sub_layers: @@ -109,7 +109,7 @@ class WrapLayer(TheseusLayer): for return_pattern in return_patterns: if re.match(return_pattern, layer_name): self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict - self._sub_layers[layer_i].register_forward_post_hook( + self.sub_layer._sub_layers[layer_i].register_forward_post_hook( self._sub_layers[layer_i]._save_sub_res_hook) if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer): diff --git a/ppcls/arch/backbone/legendary_models/vgg.py b/ppcls/arch/backbone/legendary_models/vgg.py index c8a1693227f4e93781eee9dea44e1b6866c7f552..a45637d2a6a89bad081ca4452497e540872dfafe 100644 --- a/ppcls/arch/backbone/legendary_models/vgg.py +++ b/ppcls/arch/backbone/legendary_models/vgg.py @@ -152,7 +152,7 @@ class VGGNet(TheseusLayer): x = self.relu(x) x = self.drop(x) x = self.fc3(x) - if self.res_dict: + if self.res_dict and res_dict is not None: for res_key in list(self.res_dict): res_dict[res_key] = self.res_dict.pop(res_key) return x