diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index b4f60f36bfaa87fa3df2b7ce6ca7263c52188870..de00c2a2a18f335ffa1ed043283a8ea6e5bfe80f 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -47,3 +47,4 @@ from ppcls.arch.backbone.model_zoo.distillation_models import ResNet50_vd_distil from ppcls.arch.backbone.model_zoo.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384 from ppcls.arch.backbone.model_zoo.mixnet import MixNet_S, MixNet_M, MixNet_L from ppcls.arch.backbone.model_zoo.rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0 +from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1 diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 58965f83be1c7f9c0210f233baec2f88867c909b..fe0fac15e4457c941f3f8aa3b67858c0e2904e97 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -35,15 +35,18 @@ class TheseusLayer(nn.Layer): after_stop = True continue if isinstance(self._sub_layers[layer_i], TheseusLayer): - after_stop = self._sub_layers[layer_i].stop_after(stop_layer_name) + after_stop = self._sub_layers[layer_i].stop_after( + stop_layer_name) return after_stop def _update_res(self, return_layers): for layer_i in self._sub_layers: layer_name = self._sub_layers[layer_i].full_name() for return_pattern in return_layers: - if return_layers is not None and re.match(return_pattern, layer_name): - self._sub_layers[layer_i].register_forward_post_hook(self._save_sub_res_hook) + if return_layers is not None and re.match(return_pattern, + layer_name): + self._sub_layers[layer_i].register_forward_post_hook( + self._save_sub_res_hook) # def _save_sub_res_hook(self, layer, input, output): # self.res_dict[layer.full_name()] = output @@ -51,13 +54,24 @@ class TheseusLayer(nn.Layer): # def _disconnect_res_dict_hook(self, input, output): # self.res_dict = None - def replace_sub(self, layer_name_pattern, replace_function, recursive=True): - for layer_i in self._sub_layers: - layer_name = self._sub_layers[layer_i].full_name() + def replace_sub(self, layer_name_pattern, replace_function, + recursive=True): + for k in self._sub_layers.keys(): + layer_name = self._sub_layers[k].full_name() if re.match(layer_name_pattern, layer_name): - self._sub_layers[layer_i] = replace_function(self._sub_layers[layer_i]) - if recursive and isinstance(self._sub_layers[layer_i], TheseusLayer): - self._sub_layers[layer_i].replace_sub(layer_name_pattern, replace_function, recursive) + self._sub_layers[k] = replace_function(self._sub_layers[k]) + if recursive: + if isinstance(self._sub_layers[k], TheseusLayer): + self._sub_layers[k].replace_sub( + layer_name_pattern, replace_function, recursive) + elif isinstance(self._sub_layers[k], + nn.Sequential) or isinstance( + self._sub_layers[k], nn.LayerList): + for kk in self._sub_layers[k]._sub_layers.keys(): + self._sub_layers[k]._sub_layers[kk].replace_sub( + layer_name_pattern, replace_function, recursive) + else: + pass ''' example of replace function: @@ -70,4 +84,4 @@ class TheseusLayer(nn.Layer): ) return new_conv - ''' \ No newline at end of file + ''' diff --git a/ppcls/arch/backbone/variant_models/__init__.py b/ppcls/arch/backbone/variant_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e87a0e098cf19e4066c39c060d174512b64451a2 --- /dev/null +++ b/ppcls/arch/backbone/variant_models/__init__.py @@ -0,0 +1 @@ +from .resnet_variant import ResNet50_last_stage_stride1 diff --git a/ppcls/arch/backbone/variant_models/resnet_variant.py b/ppcls/arch/backbone/variant_models/resnet_variant.py new file mode 100644 index 0000000000000000000000000000000000000000..81eb71bb8d73017e1a555d496f23a1dbb02088f9 --- /dev/null +++ b/ppcls/arch/backbone/variant_models/resnet_variant.py @@ -0,0 +1,22 @@ +from paddle.nn import Conv2D +from ppcls.arch.backbone.legendary_models.resnet import ResNet50 + +__all__ = ["ResNet50_last_stage_stride1"] + + +def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): + def replace_function(conv): + new_conv = Conv2D( + in_channels=conv._in_channels, + out_channels=conv._out_channels, + kernel_size=conv._kernel_size, + stride=1, + padding=conv._padding, + groups=conv._groups, + bias_attr=conv._bias_attr) + return new_conv + + match_re = "conv2d_4[4|6]" + model = ResNet50(pretrained=pretrained, use_ssld=use_ssld, **kwargs) + model.replace_sub(match_re, replace_function, True) + return model