diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 8ef7913f89a55586698ffca985dc3eb61ed4ff24..f328aed6f0a5109b9bb5e7e5ccc37438ff734d51 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -12,15 +12,9 @@ class Identity(nn.Layer): class TheseusLayer(nn.Layer): - def __init__(self, *args, return_patterns=None, **kwargs): + def __init__(self, *args, **kwargs): super(TheseusLayer, self).__init__() - self.res_dict = None - if return_patterns is not None: - self._update_res(return_patterns) - - def forward(self, *input, res_dict=None, **kwargs): - if res_dict is not None: - self.res_dict = res_dict + self.res_dict = {} # stop doesn't work when stop layer has a parallel branch. def stop_after(self, stop_layer_name: str): @@ -38,33 +32,43 @@ class TheseusLayer(nn.Layer): stop_layer_name) return after_stop - def _update_res(self, return_layers): + def update_res(self, return_patterns): + if not return_patterns or isinstance(self, WrapLayer): + return + for layer_i in self._sub_layers: + layer_name = self._sub_layers[layer_i].full_name() + if isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)): + self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i]) + self._sub_layers[layer_i].res_dict = self.res_dict + self._sub_layers[layer_i].update_res(return_patterns) + else: + for return_pattern in return_patterns: + if re.match(return_pattern, layer_name): + if not isinstance(self._sub_layers[layer_i], TheseusLayer): + self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i]) + self._sub_layers[layer_i].register_forward_post_hook( + self._sub_layers[layer_i]._save_sub_res_hook) + self._sub_layers[layer_i].res_dict = self.res_dict + if isinstance(self._sub_layers[layer_i], TheseusLayer): + self._sub_layers[layer_i].res_dict = self.res_dict + self._sub_layers[layer_i].update_res(return_patterns) + + def _save_sub_res_hook(self, layer, input, 
output): + self.res_dict[layer.full_name()] = output + + def replace_sub(self, layer_name_pattern, replace_function, recursive=True): for layer_i in self._sub_layers: layer_name = self._sub_layers[layer_i].full_name() - for return_pattern in return_layers: - if return_layers is not None and re.match(return_pattern, - layer_name): - self._sub_layers[layer_i].register_forward_post_hook( - self._save_sub_res_hook) - - def replace_sub(self, layer_name_pattern, replace_function, - recursive=True): - for k in self._sub_layers.keys(): - layer_name = self._sub_layers[k].full_name() if re.match(layer_name_pattern, layer_name): - self._sub_layers[k] = replace_function(self._sub_layers[k]) + self._sub_layers[layer_i] = replace_function(self._sub_layers[layer_i]) if recursive: - if isinstance(self._sub_layers[k], TheseusLayer): - self._sub_layers[k].replace_sub( + if isinstance(self._sub_layers[layer_i], TheseusLayer): + self._sub_layers[layer_i].replace_sub( layer_name_pattern, replace_function, recursive) - elif isinstance(self._sub_layers[k], - nn.Sequential) or isinstance( - self._sub_layers[k], nn.LayerList): - for kk in self._sub_layers[k]._sub_layers.keys(): - self._sub_layers[k]._sub_layers[kk].replace_sub( + elif isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)): + for layer_j in self._sub_layers[layer_i]._sub_layers: + self._sub_layers[layer_i]._sub_layers[layer_j].replace_sub( layer_name_pattern, replace_function, recursive) - else: - pass ''' example of replace function: @@ -78,3 +82,40 @@ class TheseusLayer(nn.Layer): return new_conv ''' + + +class WrapLayer(TheseusLayer): + def __init__(self, sub_layer): + super(WrapLayer, self).__init__() + self.sub_layer = sub_layer + self.name = sub_layer.full_name() + + def full_name(self): + return self.name + + def forward(self, *inputs, **kwargs): + return self.sub_layer(*inputs, **kwargs) + + def update_res(self, return_patterns): + if not return_patterns or not isinstance(self.sub_layer, 
(nn.Sequential, nn.LayerList)):
+            return
+        for layer_i in self.sub_layer._sub_layers:
+            if isinstance(self.sub_layer._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
+                self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i])
+                self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
+                self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
+
+            layer_name = self.sub_layer._sub_layers[layer_i].full_name()
+            for return_pattern in return_patterns:
+                if re.match(return_pattern, layer_name):
+                    self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
+                    self.sub_layer._sub_layers[layer_i].register_forward_post_hook(
+                        self.sub_layer._sub_layers[layer_i]._save_sub_res_hook)
+
+            if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
+                self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
+
+
+def wrap_theseus(sub_layer):
+    wrapped_layer = WrapLayer(sub_layer)
+    return wrapped_layer
diff --git a/ppcls/arch/backbone/legendary_models/vgg.py b/ppcls/arch/backbone/legendary_models/vgg.py
index 7868b51eafce4f0bd383ad66199e50f2a05c1832..a45637d2a6a89bad081ca4452497e540872dfafe 100644
--- a/ppcls/arch/backbone/legendary_models/vgg.py
+++ b/ppcls/arch/backbone/legendary_models/vgg.py
@@ -111,7 +111,7 @@ class VGGNet(TheseusLayer):
         model: nn.Layer. Specific VGG model depends on args. 
""" - def __init__(self, config, stop_grad_layers=0, class_num=1000): + def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None): super().__init__() self.stop_grad_layers = stop_grad_layers @@ -138,7 +138,7 @@ class VGGNet(TheseusLayer): self.fc2 = Linear(4096, 4096) self.fc3 = Linear(4096, class_num) - def forward(self, inputs): + def forward(self, inputs, res_dict=None): x = self.conv_block_1(inputs) x = self.conv_block_2(x) x = self.conv_block_3(x) @@ -152,6 +152,9 @@ class VGGNet(TheseusLayer): x = self.relu(x) x = self.drop(x) x = self.fc3(x) + if self.res_dict and res_dict is not None: + for res_key in list(self.res_dict): + res_dict[res_key] = self.res_dict.pop(res_key) return x diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index 451531c1d1e6ca59e2addc1add752649e05f1e67..14b7547ddca0922aa915f4dbdb7fff9e2ad1c4a6 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -588,7 +588,7 @@ class Trainer(object): if len(batch) == 3: has_unique_id = True batch[2] = batch[2].reshape([-1, 1]).astype("int64") - out = self.model(batch[0], batch[1]) + out = self.forward(batch) batch_feas = out["features"] # do norm @@ -653,7 +653,7 @@ class Trainer(object): image_file_list.append(image_file) if len(batch_data) >= batch_size or idx == len(image_list) - 1: batch_tensor = paddle.to_tensor(batch_data) - out = self.model(batch_tensor) + out = self.forward([batch_tensor]) if isinstance(out, list): out = out[0] result = postprocess_func(out, image_file_list)