Commit 71cb728b authored by gaotingquan, committed by Tingquan Gao

fix: rename to upgrade_sublayer()

Parent 1e696ac2
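A minimal usage sketch of the renamed method, assuming a TheseusLayer-based backbone such as PaddleClas's ResNet50; the handler below is illustrative and not part of this commit:

```python
from paddle import nn
from ppcls.arch.backbone.legendary_models.resnet import ResNet50

def handle_func(layer: nn.Layer, pattern: str) -> nn.Layer:
    # Called once per matched sub-layer; must return the layer to put back
    # at that position (returned unchanged here, so this is a no-op handler).
    print(f"matched '{pattern}': {type(layer).__name__}")
    return layer

model = ResNet50(pretrained=False)
# New name introduced by this commit; the deprecated replace_sub() now
# logs and raises DeprecationWarning instead of doing the replacement.
model.upgrade_sublayer("blocks[13].conv1.conv", handle_func=handle_func)
```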
@@ -45,15 +45,14 @@ class TheseusLayer(nn.Layer):
         self.res_dict[self.res_name] = output
 
     def replace_sub(self, *args, **kwargs) -> None:
-        msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead."
+        msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
         logger.error(DeprecationWarning(msg))
         raise DeprecationWarning(msg)
 
-    # TODO(gaotingquan): what is a good name?
-    def layer_wrench(self,
-                     layer_name_pattern: Union[str, List[str]],
-                     handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
-                         str, nn.Layer]:
+    def upgrade_sublayer(self,
+                         layer_name_pattern: Union[str, List[str]],
+                         handle_func: Callable[[nn.Layer, str], nn.Layer]
+                         ) -> Dict[str, nn.Layer]:
         """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
 
         Args:
@@ -88,7 +87,7 @@ class TheseusLayer(nn.Layer):
         handle_res_dict = {}
         for pattern in layer_name_pattern:
             # parse pattern to find target layer and its parent
             layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
             if not layer_list:
                 continue
@@ -166,7 +165,8 @@ class TheseusLayer(nn.Layer):
         handle_func = Handler(self.res_dict)
 
-        res_dict = self.layer_wrench(return_patterns, handle_func=handle_func)
+        res_dict = self.upgrade_sublayer(
+            return_patterns, handle_func=handle_func)
 
         if hasattr(self, "hook_remove_helper"):
             self.hook_remove_helper.remove()
@@ -221,10 +221,10 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
         parent_layer (nn.Layer): The root layer relative to the pattern.
 
     Returns:
         Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
             [
                 {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
                 {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
                 ...
             ]
     """
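To make the documented return structure concrete, a hypothetical sketch of what a pattern such as "blocks[13].conv1.conv" (used in the variant model below) would parse to, assuming the module path ppcls.arch.backbone.base.theseus_layer; the concrete field values are illustrative:

```python
from ppcls.arch.backbone.base.theseus_layer import parse_pattern_str
from ppcls.arch.backbone.legendary_models.resnet import ResNet50

model = ResNet50(pretrained=False)
layer_list = parse_pattern_str("blocks[13].conv1.conv", parent_layer=model)
# Expected shape of the result, following the docstring above (illustrative):
# [
#     {"layer": <blocks[13]>,            "name": "blocks", "index": "13"},
#     {"layer": <blocks[13].conv1>,      "name": "conv1",  "index": None},
#     {"layer": <blocks[13].conv1.conv>, "name": "conv",   "index": None},
# ]
```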
@@ -5,7 +5,7 @@ __all__ = ["ResNet50_last_stage_stride1"]
 def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
-    def replace_function(conv):
+    def replace_function(conv, pattern):
         new_conv = Conv2D(
             in_channels=conv._in_channels,
             out_channels=conv._out_channels,
@@ -16,8 +16,8 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
             bias_attr=conv._bias_attr)
         return new_conv
 
-    match_re = "conv2d_4[4|6]"
+    pattern = ["blocks[13].conv1.conv", "blocks[13].short.conv"]
 
     model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs)
-    model.replace_sub(match_re, replace_function, True)
+    model.upgrade_sublayer(pattern, replace_function)
     _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
     return model
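The regex match "conv2d_4[4|6]" is replaced by explicit layer paths, so the targeted convolutions can be inspected directly after construction. A small sketch, assuming the variant lives in ppcls.arch.backbone.variant_models.resnet_variant and that paddle.nn.Conv2D keeps its stride in the private _stride attribute (the other private attributes used above suggest it does):

```python
from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1

# Both convs named by the new pattern list should now have stride 1.
model = ResNet50_last_stage_stride1(pretrained=False)
print(model.blocks[13].conv1.conv._stride)   # expected: stride 1 (assumption)
print(model.blocks[13].short.conv._stride)   # expected: stride 1 (assumption)
```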
 import paddle
 from paddle.nn import Sigmoid
 from ppcls.arch.backbone.legendary_models.vgg import VGG19
 
 __all__ = ["VGG19Sigmoid"]
 
 
 class SigmoidSuffix(paddle.nn.Layer):
     def __init__(self, origin_layer):
-        super(SigmoidSuffix, self).__init__()
+        super().__init__()
         self.origin_layer = origin_layer
         self.sigmoid = Sigmoid()
 
     def forward(self, input, res_dict=None, **kwargs):
         x = self.origin_layer(input)
         x = self.sigmoid(x)
         return x
 
 
 def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs):
-    def replace_function(origin_layer):
+    def replace_function(origin_layer, pattern):
         new_layer = SigmoidSuffix(origin_layer)
         return new_layer
 
-    match_re = "linear_2"
+    pattern = "fc2"
     model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
-    model.replace_sub(match_re, replace_function, True)
+    model.upgrade_sublayer(pattern, replace_function)
     return model
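Both variant models now pass a handler with the (layer, pattern) signature instead of the old regex plus boolean flag. A short usage sketch of the updated VGG variant, assuming the module path ppcls.arch.backbone.variant_models.vgg_variant and the default 1000-class head:

```python
import paddle
from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid

model = VGG19Sigmoid(pretrained=False)
x = paddle.rand([1, 3, 224, 224])
y = model(x)      # the fc2 sub-layer is wrapped with SigmoidSuffix
print(y.shape)    # expected: [1, 1000] with the default head (assumption)
```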