From 56911b573b1b580f009d7232d78b58cf699cd6c0 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Tue, 21 Dec 2021 14:12:43 +0000 Subject: [PATCH] refactor: rename replace_sub() func to ? --- ppcls/arch/backbone/base/theseus_layer.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 40f5d317..fb06c183 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -29,10 +29,16 @@ class TheseusLayer(nn.Layer): def _save_sub_res_hook(self, layer, input, output): self.res_dict[self.res_name] = output - def replace_sub(self, - layer_name_pattern: Union[str, List[str]], - handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[ - str, nn.Layer]: + def replace_sub(self, *args, **kwargs) -> None: + msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead." + logger.error(DeprecationWarning(msg)) + raise DeprecationWarning(msg) + + # TODO(gaotingquan): what is a good name? + def layer_wrench(self, + layer_name_pattern: Union[str, List[str]], + handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[ + str, nn.Layer]: """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'. Args: @@ -59,7 +65,7 @@ class TheseusLayer(nn.Layer): net = paddleclas.MobileNetV1() res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func) print(res) - # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True} + # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer} """ if not isinstance(layer_name_pattern, list): -- GitLab