diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 0117def4cea0fabf559305dbc21fb6928d0d4230..a0dcd0999acf0d950e4edf1e772969a48f0bc57a 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -23,6 +23,7 @@ from . import backbone, gears from .backbone import * from .gears import build_gear from .utils import * +from ppcls.arch.backbone.base.theseus_layer import TheseusLayer from ppcls.utils import logger from ppcls.utils.save_load import load_dygraph_pretrain @@ -32,8 +33,11 @@ __all__ = ["build_model", "RecModel", "DistillationModel"] def build_model(config): config = copy.deepcopy(config) model_type = config.pop("name") + return_patterns = config.pop("return_patterns", None) mod = importlib.import_module(__name__) arch = getattr(mod, model_type)(**config) + if return_patterns is not None and isinstance(arch, TheseusLayer): + arch.update_res(return_patterns=return_patterns, return_dict=True) return arch @@ -55,7 +59,10 @@ class RecModel(nn.Layer): super().__init__() backbone_config = config["Backbone"] backbone_name = backbone_config.pop("name") + return_patterns = config.pop("return_patterns", None) self.backbone = eval(backbone_name)(**backbone_config) + if return_patterns is not None and isinstance(self.backbone, TheseusLayer): + self.backbone.update_res(return_patterns=return_patterns, return_dict=True) if "BackboneStopLayer" in config: backbone_stop_layer = config["BackboneStopLayer"]["name"] self.backbone.stop_after(backbone_stop_layer) diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 29697e6509a953030a3a3086e51c4c9c9a712df0..35eac5f083bae1a371119ccf35c390441f0d1f8e 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -57,6 +57,12 @@ class TheseusLayer(nn.Layer): def _save_sub_res_hook(self, layer, input, output): self.res_dict[layer.full_name()] = output + def _return_dict_hook(self, layer, input, output): + res_dict = {"output": output} + for res_key in list(self.res_dict): + res_dict[res_key] = self.res_dict.pop(res_key) + return res_dict + def replace_sub(self, layer_name_pattern, replace_function, recursive=True): for layer_i in self._sub_layers: layer_name = self._sub_layers[layer_i].full_name() diff --git a/ppcls/arch/backbone/legendary_models/hrnet.py b/ppcls/arch/backbone/legendary_models/hrnet.py index 51ad4e4f51b7104de30962d72849c9e032229e67..7c4898a1387ec5346ba06ac0995ac1ac17c9624e 100644 --- a/ppcls/arch/backbone/legendary_models/hrnet.py +++ b/ppcls/arch/backbone/legendary_models/hrnet.py @@ -367,7 +367,7 @@ class HRNet(TheseusLayer): model: nn.Layer. Specific HRNet model depends on args. """ - def __init__(self, width=18, has_se=False, class_num=1000): + def __init__(self, width=18, has_se=False, class_num=1000, return_patterns=None): super().__init__() self.width = width @@ -456,8 +456,11 @@ class HRNet(TheseusLayer): 2048, class_num, weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) + if return_patterns is not None: + self.update_res(return_patterns) + self.register_forward_post_hook(self._return_dict_hook) - def forward(self, x, res_dict=None): + def forward(self, x): x = self.conv_layer1_1(x) x = self.conv_layer1_2(x) diff --git a/ppcls/arch/backbone/legendary_models/inception_v3.py b/ppcls/arch/backbone/legendary_models/inception_v3.py index b6403bbe6af0ffa15eab6a6548e30398a8054a2d..50fbcb4cbaef9bbc19ed65aeee181bfea6a599b5 100644 --- a/ppcls/arch/backbone/legendary_models/inception_v3.py +++ b/ppcls/arch/backbone/legendary_models/inception_v3.py @@ -454,7 +454,7 @@ class Inception_V3(TheseusLayer): model: nn.Layer. Specific Inception_V3 model depends on args. """ - def __init__(self, config, class_num=1000): + def __init__(self, config, class_num=1000, return_patterns=None): super().__init__() self.inception_a_list = config["inception_a"] @@ -496,6 +496,9 @@ class Inception_V3(TheseusLayer): class_num, weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), bias_attr=ParamAttr()) + if return_patterns is not None: + self.update_res(return_patterns) + self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): x = self.inception_stem(x) diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py index bacac5b530a51d89e6a2062a5f20bff89900c381..944bdb14610ee90444f02c2bfd2dd11fb7d48b48 100644 --- a/ppcls/arch/backbone/legendary_models/mobilenet_v1.py +++ b/ppcls/arch/backbone/legendary_models/mobilenet_v1.py @@ -102,7 +102,7 @@ class MobileNet(TheseusLayer): model: nn.Layer. Specific MobileNet model depends on args. """ - def __init__(self, scale=1.0, class_num=1000): + def __init__(self, scale=1.0, class_num=1000, return_patterns=None): super().__init__() self.scale = scale @@ -145,16 +145,16 @@ class MobileNet(TheseusLayer): int(1024 * scale), class_num, weight_attr=ParamAttr(initializer=KaimingNormal())) + if return_patterns is not None: + self.update_res(return_patterns) + self.register_forward_post_hook(self._return_dict_hook) - def forward(self, x, res_dict=None): + def forward(self, x): x = self.conv(x) x = self.blocks(x) x = self.avg_pool(x) x = self.flatten(x) x = self.fc(x) - if self.res_dict and res_dict is not None: - for res_key in list(self.res_dict): - res_dict[res_key] = self.res_dict.pop(res_key) return x diff --git a/ppcls/arch/backbone/legendary_models/mobilenet_v3.py b/ppcls/arch/backbone/legendary_models/mobilenet_v3.py index aff69bcae1d5a67d5c14bb95c39af3ecca6e48a3..f39b81567b4a68de2ada9815aebbf29c07da0b06 100644 --- a/ppcls/arch/backbone/legendary_models/mobilenet_v3.py +++ b/ppcls/arch/backbone/legendary_models/mobilenet_v3.py @@ -142,7 +142,8 @@ class MobileNetV3(TheseusLayer): inplanes=STEM_CONV_NUMBER, class_squeeze=LAST_SECOND_CONV_LARGE, class_expand=LAST_CONV, - dropout_prob=0.2): + dropout_prob=0.2, + return_patterns=None): super().__init__() self.cfg = config @@ -199,6 +200,9 @@ class MobileNetV3(TheseusLayer): self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) self.fc = Linear(self.class_expand, class_num) + if return_patterns is not None: + self.update_res(return_patterns) + self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): x = self.conv(x) diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index 5417e2d970e74a690ada1c5a4f3d7fdedf238d11..4f79c0d75f06e6b7e188f9311902fe3397823c8f 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -269,7 +269,8 @@ class ResNet(TheseusLayer): class_num=1000, lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], data_format="NCHW", - input_image_channel=3): + input_image_channel=3, + return_patterns=None): super().__init__() self.cfg = config @@ -337,6 +338,9 @@ class ResNet(TheseusLayer): weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) self.data_format = data_format + if return_patterns is not None: + self.update_res(return_patterns) + self.register_forward_post_hook(self._return_dict_hook) def forward(self, x): with paddle.static.amp.fp16_guard(): diff --git a/ppcls/arch/backbone/legendary_models/vgg.py b/ppcls/arch/backbone/legendary_models/vgg.py index a898092981ffd31e45e1317a880c3917246d6bba..9b1750d54c2973cc08a3eb54575b53f61ff9891e 100644 --- a/ppcls/arch/backbone/legendary_models/vgg.py +++ b/ppcls/arch/backbone/legendary_models/vgg.py @@ -111,7 +111,7 @@ class VGGNet(TheseusLayer): model: nn.Layer. Specific VGG model depends on args. """ - def __init__(self, config, stop_grad_layers=0, class_num=1000): + def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None): super().__init__() self.stop_grad_layers = stop_grad_layers @@ -137,8 +137,11 @@ class VGGNet(TheseusLayer): self.fc1 = Linear(7 * 7 * 512, 4096) self.fc2 = Linear(4096, 4096) self.fc3 = Linear(4096, class_num) + if return_patterns is not None: + self.update_res(return_patterns) + self.register_forward_post_hook(self._return_dict_hook) - def forward(self, inputs, res_dict=None): + def forward(self, inputs): x = self.conv_block_1(inputs) x = self.conv_block_2(x) x = self.conv_block_3(x) @@ -152,9 +155,6 @@ class VGGNet(TheseusLayer): x = self.relu(x) x = self.drop(x) x = self.fc3(x) - if self.res_dict and res_dict is not None: - for res_key in list(self.res_dict): - res_dict[res_key] = self.res_dict.pop(res_key) return x