From 2cb3bf66dd2d6dd1d612a98daa6710501f35de26 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Sun, 8 Aug 2021 15:55:34 +0800 Subject: [PATCH] update wrap theseus rule --- ppcls/arch/backbone/base/theseus_layer.py | 55 +++++++++++++---------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 2efe72ed..ae0075bb 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -33,21 +33,24 @@ class TheseusLayer(nn.Layer): return after_stop def update_res(self, return_patterns): - if not return_patterns: + if not return_patterns or isinstance(self, WrapLayer): return for layer_i in self._sub_layers: - if isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)): - self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], return_patterns) layer_name = self._sub_layers[layer_i].full_name() - for return_pattern in return_patterns: - if re.match(return_pattern, layer_name): - if not isinstance(self._sub_layers[layer_i], TheseusLayer) and not isinstance(self, WrapLayer): - self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], return_patterns) - self._sub_layers[layer_i].register_forward_post_hook( - self._sub_layers[layer_i]._save_sub_res_hook) - self._sub_layers[layer_i].res_dict = self.res_dict - if isinstance(self._sub_layers[layer_i], TheseusLayer): + if isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)): + self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i]) + self._sub_layers[layer_i].res_dict = self.res_dict self._sub_layers[layer_i].update_res(return_patterns) + else: + for return_pattern in return_patterns: + if re.match(return_pattern, layer_name): + if not isinstance(self._sub_layers[layer_i], TheseusLayer): + self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i]) + self._sub_layers[layer_i].register_forward_post_hook( + self._sub_layers[layer_i]._save_sub_res_hook) + self._sub_layers[layer_i].res_dict = self.res_dict + if isinstance(self._sub_layers[layer_i], TheseusLayer): + self._sub_layers[layer_i].update_res(return_patterns) def _save_sub_res_hook(self, layer, input, output): if self.res_dict is not None: @@ -93,18 +96,24 @@ class WrapLayer(TheseusLayer): def forward(self, *inputs, **kwargs): self.sub_layer(*inputs, **kwargs) + def update_res(self, return_patterns): + if not return_patterns or not isinstance(self.sub_layer, (nn.Sequential, nn.LayerList)): + return + for layer_i in self.sub_layer._sub_layers: + if isinstance(self.sub_layer._sub_layers[layer_i], (nn.Sequential, nn.LayerList)): + self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i]) + self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict + self.sub_layer._sub_layers[layer_i].update_res(return_patterns) + + layer_name = self.sub_layer._sub_layers[layer_i].full_name() + for return_pattern in return_patterns: + if re.match(return_pattern, layer_name): + self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict -def wrap_theseus(sub_layer, return_patterns): - if isinstance(sub_layer, (nn.Sequential, nn.LayerList)): - for layer_i in sub_layer._sub_layers: - if isinstance(sub_layer._sub_layers[layer_i], TheseusLayer): - continue - elif isinstance(sub_layer._sub_layers[layer_i], (nn.Sequential, nn.LayerList)): - wrap_theseus(sub_layer._sub_layers[layer_i], return_patterns) - elif isinstance(sub_layer._sub_layers[layer_i], nn.Layer): - layer_name = sub_layer._sub_layers[layer_i].full_name() - for return_pattern in return_patterns: - if re.match(return_pattern, layer_name): - wrap_theseus(sub_layer._sub_layers[layer_i], return_patterns) + if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer): + self.sub_layer._sub_layers[layer_i].update_res(return_patterns) + + +def wrap_theseus(sub_layer): wrapped_layer = WrapLayer(sub_layer) return wrapped_layer -- GitLab