提交 71cb728b 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: rename to upgrade_sublayer()

上级 1e696ac2
...@@ -45,15 +45,14 @@ class TheseusLayer(nn.Layer): ...@@ -45,15 +45,14 @@ class TheseusLayer(nn.Layer):
self.res_dict[self.res_name] = output self.res_dict[self.res_name] = output
def replace_sub(self, *args, **kwargs) -> None: def replace_sub(self, *args, **kwargs) -> None:
msg = "The function 'replace_sub()' is deprecated, please use 'layer_wrench()' instead." msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
logger.error(DeprecationWarning(msg)) logger.error(DeprecationWarning(msg))
raise DeprecationWarning(msg) raise DeprecationWarning(msg)
# TODO(gaotingquan): what is a good name? def upgrade_sublayer(self,
def layer_wrench(self,
layer_name_pattern: Union[str, List[str]], layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[ handle_func: Callable[[nn.Layer, str], nn.Layer]
str, nn.Layer]: ) -> Dict[str, nn.Layer]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'. """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args: Args:
...@@ -166,7 +165,8 @@ class TheseusLayer(nn.Layer): ...@@ -166,7 +165,8 @@ class TheseusLayer(nn.Layer):
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)
res_dict = self.layer_wrench(return_patterns, handle_func=handle_func) res_dict = self.upgrade_sublayer(
return_patterns, handle_func=handle_func)
if hasattr(self, "hook_remove_helper"): if hasattr(self, "hook_remove_helper"):
self.hook_remove_helper.remove() self.hook_remove_helper.remove()
......
...@@ -5,7 +5,7 @@ __all__ = ["ResNet50_last_stage_stride1"] ...@@ -5,7 +5,7 @@ __all__ = ["ResNet50_last_stage_stride1"]
def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
def replace_function(conv): def replace_function(conv, pattern):
new_conv = Conv2D( new_conv = Conv2D(
in_channels=conv._in_channels, in_channels=conv._in_channels,
out_channels=conv._out_channels, out_channels=conv._out_channels,
...@@ -16,8 +16,8 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs): ...@@ -16,8 +16,8 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
bias_attr=conv._bias_attr) bias_attr=conv._bias_attr)
return new_conv return new_conv
match_re = "conv2d_4[4|6]" pattern = ["blocks[13].conv1.conv", "blocks[13].short.conv"]
model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs) model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs)
model.replace_sub(match_re, replace_function, True) model.upgrade_sublayer(pattern, replace_function)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld) _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
return model return model
...@@ -7,7 +7,7 @@ __all__ = ["VGG19Sigmoid"] ...@@ -7,7 +7,7 @@ __all__ = ["VGG19Sigmoid"]
class SigmoidSuffix(paddle.nn.Layer): class SigmoidSuffix(paddle.nn.Layer):
def __init__(self, origin_layer): def __init__(self, origin_layer):
super(SigmoidSuffix, self).__init__() super().__init__()
self.origin_layer = origin_layer self.origin_layer = origin_layer
self.sigmoid = Sigmoid() self.sigmoid = Sigmoid()
...@@ -18,11 +18,11 @@ class SigmoidSuffix(paddle.nn.Layer): ...@@ -18,11 +18,11 @@ class SigmoidSuffix(paddle.nn.Layer):
def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs): def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs):
def replace_function(origin_layer): def replace_function(origin_layer, pattern):
new_layer = SigmoidSuffix(origin_layer) new_layer = SigmoidSuffix(origin_layer)
return new_layer return new_layer
match_re = "linear_2" pattern = "fc2"
model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs) model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
model.replace_sub(match_re, replace_function, True) model.upgrade_sublayer(pattern, replace_function)
return model return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册