提交 56911b57 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: rename replace_sub() func to ?

上级 0f126b75
......@@ -29,10 +29,16 @@ class TheseusLayer(nn.Layer):
def _save_sub_res_hook(self, layer, input, output):
self.res_dict[self.res_name] = output
def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
str, nn.Layer]:
def replace_sub(self, *args, **kwargs) -> None:
msg = "\"replace_sub\" is deprecated, please use \"layer_wrench\" instead."
logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg)
# TODO(gaotingquan): what is a good name?
def layer_wrench(self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
str, nn.Layer]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args:
......@@ -59,7 +65,7 @@ class TheseusLayer(nn.Layer):
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
"""
if not isinstance(layer_name_pattern, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册