From cf205e1379c0ffe52c20a6934b76976447ace1ca Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Mon, 20 Dec 2021 04:33:20 +0000 Subject: [PATCH] fix: fix result returned by _find_layers_handle --- ppcls/arch/backbone/base/theseus_layer.py | 80 +++++++++++++---------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 2b45745c..5d06385a 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -30,11 +30,13 @@ class TheseusLayer(nn.Layer): self.res_dict[self.res_name] = output def _find_layers_handle(self, patterns, handle_func): - sub_layers_dict = {} + handle_res_dict = {} for pattern in patterns: pattern_list = pattern.split(".") if not pattern_list: continue + + # find parent layer of sub-layer specified by pattern sub_layer_parent = self while len(pattern_list) > 1: if '[' in pattern_list[0]: @@ -52,7 +54,11 @@ class TheseusLayer(nn.Layer): sub_layer_parent = sub_layer_parent.sub_layer pattern_list = pattern_list[1:] if sub_layer_parent is None: + msg = f"Not found layer by name({pattern_list[0]}) specifed in pattern({pattern})." + logger.warning(msg) continue + + # find sub-layer specified by pattern if '[' in pattern_list[0]: sub_layer_name = pattern_list[0].split('[')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] @@ -68,13 +74,33 @@ class TheseusLayer(nn.Layer): sub_layer = wrap_theseus(sub_layer) setattr(sub_layer_parent, pattern_list[0], sub_layer) - sub_layers_dict[pattern] = sub_layer handle_res = handle_func(sub_layer, pattern) - return sub_layers_dict, handle_res + handle_res_dict[pattern] = handle_res + return handle_res_dict + + def _set_identity(self, layer, layer_name, layer_index=None): + stop_after = False + for sub_layer_name in layer._sub_layers: + if stop_after: + layer._sub_layers[sub_layer_name] = Identity() + continue + if sub_layer_name == layer_name: + stop_after = True + + if layer_index and stop_after: + stop_after = False + for sub_layer_index in layer._sub_layers[layer_name]._sub_layers: + if stop_after: + layer._sub_layers[layer_name][sub_layer_index] = Identity() + continue + if layer_index == sub_layer_index: + stop_after = True + + return stop_after def replace_sub(self, layer_name_pattern: Union[str, List[str]], - replace_function: Callable[[nn.Layer, str], Any]) -> bool: + replace_function: Callable[[nn.Layer, str], Any]) -> Any: """use 'replace_function' to modify the 'layer_name_pattern'. Args: @@ -83,24 +109,26 @@ class TheseusLayer(nn.Layer): Returns: bool: 'True' if successful, 'False' otherwise. - + Examples: + from paddle import nn import paddleclas - def replace_conv(origin_conv: nn.Conv2D): - new_conv = nn.Conv2D( - in_channels=origin_conv._in_channels, - out_channels=origin_conv._out_channels, - kernel_size=origin_conv._kernel_size, - stride=2 + def rep_func(warp_layer: nn.Layer, pattern: str): + sub_layer = warp_layer.sub_layer + new_layer = nn.Conv2D( + in_channels=sub_layer._in_channels, + out_channels=sub_layer._out_channels, + kernel_size=5 ) - return new_conv + warp_layer.sub_layer = new_layer + return True net = paddleclas.MobileNetV1() - tag = net.replace_sub(layer_name_pattern="conv", replace_function=replace_conv) - print(tag) - # True + res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], replace_function=rep_func) + print(res) + # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True} """ if not isinstance(layer_name_pattern, list): @@ -108,26 +136,6 @@ class TheseusLayer(nn.Layer): return self._find_layers_handle( layer_name_pattern, handle_func=replace_function) - def _set_identity(self, layer, layer_name, layer_index=None): - stop_after = False - for sub_layer_name in layer._sub_layers: - if stop_after: - layer._sub_layers[sub_layer_name] = Identity() - continue - if sub_layer_name == layer_name: - stop_after = True - - if layer_index and stop_after: - stop_after = False - for sub_layer_index in layer._sub_layers[layer_name]._sub_layers: - if stop_after: - layer._sub_layers[layer_name][sub_layer_index] = Identity() - continue - if layer_index == sub_layer_index: - stop_after = True - - return stop_after - # stop doesn't work when stop layer has a parallel branch. def stop_after(self, stop_layer_name: str) -> bool: """stop forward and backward after 'stop_layer_name'. @@ -153,7 +161,7 @@ class TheseusLayer(nn.Layer): sub_layer_index = None layer = getattr(layer, sub_layer_name, None) if layer is None: - msg = f"Not found layer by name({pattern_list[0]}) in stop_layer_name({stop_layer_name})." + msg = f"Not found layer by name({pattern_list[0]}) specifed in stop_layer_name({stop_layer_name})." logger.warning(msg) return False -- GitLab