diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index a69c0acca10952e926a022f547ed61a302394579..b9842c7e852cd399457f7c6f675875e9be15959c 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -42,8 +42,9 @@ class TheseusLayer(nn.Layer): layer_name = self._sub_layers[layer_i].full_name() for return_pattern in return_patterns: if return_patterns is not None and re.match(return_pattern, layer_name): - self._sub_layers[layer_i].register_forward_post_hook( - self._save_sub_res_hook) + if isinstance(self._sub_layers[layer_i], TheseusLayer): + self._sub_layers[layer_i].register_forward_post_hook( + self._sub_layers[layer_i]._save_sub_res_hook) if isinstance(self._sub_layers[layer_i], TheseusLayer): self._sub_layers[layer_i]._update_res(return_patterns) diff --git a/ppcls/arch/backbone/legendary_models/vgg.py b/ppcls/arch/backbone/legendary_models/vgg.py index e70dfa6a72dee24a21103f8ac3417ea0eaf8a256..7d88fe9c628ab88b5cc66338610c587dc7788f5b 100644 --- a/ppcls/arch/backbone/legendary_models/vgg.py +++ b/ppcls/arch/backbone/legendary_models/vgg.py @@ -84,7 +84,8 @@ class ConvBlock(TheseusLayer): self.max_pool = MaxPool2D(kernel_size=2, stride=2, padding=0) self.relu = nn.ReLU() - def forward(self, inputs): + def forward(self, inputs, res_dict=None): + super(ConvBlock, self).forward(inputs, res_dict) x = self.conv1(inputs) x = self.relu(x) if self.groups == 2 or self.groups == 3 or self.groups == 4: @@ -140,11 +141,12 @@ class VGGNet(TheseusLayer): self._update_res(return_patterns) def forward(self, inputs, res_dict=None): - x = self.conv_block_1(inputs) - x = self.conv_block_2(x) - x = self.conv_block_3(x) - x = self.conv_block_4(x) - x = self.conv_block_5(x) + super(VGGNet, self).forward(inputs, res_dict=res_dict) + x = self.conv_block_1(inputs, res_dict) + x = self.conv_block_2(x, res_dict) + x = self.conv_block_3(x, res_dict) + x = self.conv_block_4(x, res_dict) + x = self.conv_block_5(x, res_dict) x = self.flatten(x) x = self.fc1(x) x = self.relu(x)