提交 71cb728b 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: rename to upgrade_sublayer()

上级 1e696ac2
...@@ -45,15 +45,14 @@ class TheseusLayer(nn.Layer): ...@@ -45,15 +45,14 @@ class TheseusLayer(nn.Layer):
self.res_dict[self.res_name] = output self.res_dict[self.res_name] = output
def replace_sub(self, *args, **kwargs) -> None: def replace_sub(self, *args, **kwargs) -> None:
msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead." msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
logger.error(DeprecationWarning(msg)) logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg) raise DeprecationWarning(msg)
# TODO(gaotingquan): what is a good name? def upgrade_sublayer(self,
def layer_wrench(self, layer_name_pattern: Union[str, List[str]],
layer_name_pattern: Union[str, List[str]], handle_func: Callable[[nn.Layer, str], nn.Layer]
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[ ) -> Dict[str, nn.Layer]:
str, nn.Layer]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'. """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args: Args:
...@@ -88,7 +87,7 @@ class TheseusLayer(nn.Layer): ...@@ -88,7 +87,7 @@ class TheseusLayer(nn.Layer):
handle_res_dict = {} handle_res_dict = {}
for pattern in layer_name_pattern: for pattern in layer_name_pattern:
# parse pattern to find target layer and its parent # parse pattern to find target layer and its parent
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self) layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
if not layer_list: if not layer_list:
continue continue
...@@ -166,7 +165,8 @@ class TheseusLayer(nn.Layer): ...@@ -166,7 +165,8 @@ class TheseusLayer(nn.Layer):
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)
res_dict = self.layer_wrench(return_patterns, handle_func=handle_func) res_dict = self.upgrade_sublayer(
return_patterns, handle_func=handle_func)
if hasattr(self, "hook_remove_helper"): if hasattr(self, "hook_remove_helper"):
self.hook_remove_helper.remove() self.hook_remove_helper.remove()
...@@ -221,10 +221,10 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[ ...@@ -221,10 +221,10 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
parent_layer (nn.Layer): The root layer relative to the pattern. parent_layer (nn.Layer): The root layer relative to the pattern.
Returns: Returns:
Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order: Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
[ [
{"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist}, {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
{"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist}, {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
... ...
] ]
""" """
......
...@@ -5,7 +5,7 @@ __all__ = ["ResNet50_last_stage_stride1"] ...@@ -5,7 +5,7 @@ __all__ = ["ResNet50_last_stage_stride1"]
def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
def replace_function(conv): def replace_function(conv, pattern):
new_conv = Conv2D( new_conv = Conv2D(
in_channels=conv._in_channels, in_channels=conv._in_channels,
out_channels=conv._out_channels, out_channels=conv._out_channels,
...@@ -16,8 +16,8 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): ...@@ -16,8 +16,8 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
bias_attr=conv._bias_attr) bias_attr=conv._bias_attr)
return new_conv return new_conv
match_re = "conv2d_4[4|6]" pattern = ["blocks[13].conv1.conv", "blocks[13].short.conv"]
model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs) model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs)
model.replace_sub(match_re, replace_function, True) model.upgrade_sublayer(pattern, replace_function)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld) _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
return model return model
import paddle import paddle
from paddle.nn import Sigmoid from paddle.nn import Sigmoid
from ppcls.arch.backbone.legendary_models.vgg import VGG19 from ppcls.arch.backbone.legendary_models.vgg import VGG19
__all__ = ["VGG19Sigmoid"] __all__ = ["VGG19Sigmoid"]
class SigmoidSuffix(paddle.nn.Layer): class SigmoidSuffix(paddle.nn.Layer):
def __init__(self, origin_layer): def __init__(self, origin_layer):
super(SigmoidSuffix, self).__init__() super().__init__()
self.origin_layer = origin_layer self.origin_layer = origin_layer
self.sigmoid = Sigmoid() self.sigmoid = Sigmoid()
def forward(self, input, res_dict=None, **kwargs): def forward(self, input, res_dict=None, **kwargs):
x = self.origin_layer(input) x = self.origin_layer(input)
x = self.sigmoid(x) x = self.sigmoid(x)
return x return x
def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs): def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs):
def replace_function(origin_layer): def replace_function(origin_layer, pattern):
new_layer = SigmoidSuffix(origin_layer) new_layer = SigmoidSuffix(origin_layer)
return new_layer return new_layer
match_re = "linear_2" pattern = "fc2"
model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs) model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
model.replace_sub(match_re, replace_function, True) model.upgrade_sublayer(pattern, replace_function)
return model return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册