diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 41599cd99e088015eda54304af328ae1ead6dd16..04871e48be179406312b1e0c5a1d36077c14193d 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -109,6 +109,8 @@ class WrapLayer(TheseusLayer): for return_pattern in return_patterns: if re.match(return_pattern, layer_name): self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict + self._sub_layers[layer_i].register_forward_post_hook( + self._sub_layers[layer_i]._save_sub_res_hook) if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer): self.sub_layer._sub_layers[layer_i].update_res(return_patterns) diff --git a/ppcls/arch/backbone/legendary_models/vgg.py b/ppcls/arch/backbone/legendary_models/vgg.py index fbfdaca0366d87bcd7bce85769c1d4866c2926c0..c8a1693227f4e93781eee9dea44e1b6866c7f552 100644 --- a/ppcls/arch/backbone/legendary_models/vgg.py +++ b/ppcls/arch/backbone/legendary_models/vgg.py @@ -137,7 +137,6 @@ class VGGNet(TheseusLayer): self.fc1 = Linear(7 * 7 * 512, 4096) self.fc2 = Linear(4096, 4096) self.fc3 = Linear(4096, class_num) - self.update_res(return_patterns) def forward(self, inputs, res_dict=None): x = self.conv_block_1(inputs) diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index eb588ed0208a571929c32323dfbdb77d8db99345..0fcb9c3241268cf730457f0c6736d225e0c663b2 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -77,6 +77,7 @@ class Trainer(object): self.model = build_model(self.config["Arch"]) if "return_patterns" in self.config["Arch"] and isinstance(self.model, TheseusLayer): + self.model.update_res(self.config["Arch"]["return_patterns"]) self.return_inter = True else: self.return_inter = False