diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 40f5d317140036f21c265681b0f64a4e687b0bfb..fb06c183b82deef21157aaab52ea52386d6d8e8b 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -29,10 +29,16 @@ class TheseusLayer(nn.Layer): def _save_sub_res_hook(self, layer, input, output): self.res_dict[self.res_name] = output - def replace_sub(self, - layer_name_pattern: Union[str, List[str]], - handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[ - str, nn.Layer]: + def replace_sub(self, *args, **kwargs) -> None: + msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead." + logger.error(DeprecationWarning(msg)) + raise DeprecationWarning(msg) + + # TODO(gaotingquan): what is a good name? + def layer_wrench(self, + layer_name_pattern: Union[str, List[str]], + handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[ + str, nn.Layer]: """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'. Args: @@ -59,7 +65,7 @@ class TheseusLayer(nn.Layer): net = paddleclas.MobileNetV1() res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func) print(res) - # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True} + # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer} """ if not isinstance(layer_name_pattern, list):