提交 cf205e13 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix result returned by _find_layers_handle

上级 b0ae3a12
...@@ -30,11 +30,13 @@ class TheseusLayer(nn.Layer): ...@@ -30,11 +30,13 @@ class TheseusLayer(nn.Layer):
self.res_dict[self.res_name] = output self.res_dict[self.res_name] = output
def _find_layers_handle(self, patterns, handle_func): def _find_layers_handle(self, patterns, handle_func):
sub_layers_dict = {} handle_res_dict = {}
for pattern in patterns: for pattern in patterns:
pattern_list = pattern.split(".") pattern_list = pattern.split(".")
if not pattern_list: if not pattern_list:
continue continue
# find parent layer of sub-layer specified by pattern
sub_layer_parent = self sub_layer_parent = self
while len(pattern_list) > 1: while len(pattern_list) > 1:
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
...@@ -52,7 +54,11 @@ class TheseusLayer(nn.Layer): ...@@ -52,7 +54,11 @@ class TheseusLayer(nn.Layer):
sub_layer_parent = sub_layer_parent.sub_layer sub_layer_parent = sub_layer_parent.sub_layer
pattern_list = pattern_list[1:] pattern_list = pattern_list[1:]
if sub_layer_parent is None: if sub_layer_parent is None:
msg = f"Not found layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue continue
# find sub-layer specified by pattern
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
...@@ -68,13 +74,33 @@ class TheseusLayer(nn.Layer): ...@@ -68,13 +74,33 @@ class TheseusLayer(nn.Layer):
sub_layer = wrap_theseus(sub_layer) sub_layer = wrap_theseus(sub_layer)
setattr(sub_layer_parent, pattern_list[0], sub_layer) setattr(sub_layer_parent, pattern_list[0], sub_layer)
sub_layers_dict[pattern] = sub_layer
handle_res = handle_func(sub_layer, pattern) handle_res = handle_func(sub_layer, pattern)
return sub_layers_dict, handle_res handle_res_dict[pattern] = handle_res
return handle_res_dict
def _set_identity(self, layer, layer_name, layer_index=None):
    """Overwrite the sub-layers that follow a given one with ``Identity``.

    ``layer._sub_layers`` is walked in iteration order. Once the key equal
    to ``layer_name`` has been seen, every later entry is replaced by a new
    ``Identity()``. If ``layer_index`` is truthy and the name was found, the
    same walk is repeated over ``layer._sub_layers[layer_name]._sub_layers``,
    replacing (via item assignment on that sub-layer) every entry that
    follows the key equal to ``layer_index``.

    Args:
        layer: object exposing a ``_sub_layers`` mapping.
        layer_name: key of the sub-layer after which entries are replaced.
        layer_index: optional key inside the named sub-layer's own
            ``_sub_layers``; ignored when falsy.

    Returns:
        bool: ``True`` if ``layer_name`` was found and, when a truthy
        ``layer_index`` is given, that index was found as well; ``False``
        otherwise.
    """
    # NOTE(review): nesting was reconstructed from a view without
    # indentation; the index pass is assumed to run after the name pass.
    children = layer._sub_layers
    found = False
    for sub_layer_name in children:
        if found:
            children[sub_layer_name] = Identity()
        elif sub_layer_name == layer_name:
            found = True

    # Without an index (or without a name match) the name pass decides.
    if not (layer_index and found):
        return found

    # Second pass: search for the index inside the matched sub-layer.
    found = False
    for sub_layer_index in children[layer_name]._sub_layers:
        if found:
            children[layer_name][sub_layer_index] = Identity()
        elif layer_index == sub_layer_index:
            found = True
    return found
def replace_sub(self, def replace_sub(self,
layer_name_pattern: Union[str, List[str]], layer_name_pattern: Union[str, List[str]],
replace_function: Callable[[nn.Layer, str], Any]) -> bool: replace_function: Callable[[nn.Layer, str], Any]) -> Any:
"""use 'replace_function' to modify the 'layer_name_pattern'. """use 'replace_function' to modify the 'layer_name_pattern'.
Args: Args:
...@@ -83,24 +109,26 @@ class TheseusLayer(nn.Layer): ...@@ -83,24 +109,26 @@ class TheseusLayer(nn.Layer):
Returns: Returns:
bool: 'True' if successful, 'False' otherwise. bool: 'True' if successful, 'False' otherwise.
Examples: Examples:
from paddle import nn
import paddleclas import paddleclas
def replace_conv(origin_conv: nn.Conv2D): def rep_func(warp_layer: nn.Layer, pattern: str):
new_conv = nn.Conv2D( sub_layer = warp_layer.sub_layer
in_channels=origin_conv._in_channels, new_layer = nn.Conv2D(
out_channels=origin_conv._out_channels, in_channels=sub_layer._in_channels,
kernel_size=origin_conv._kernel_size, out_channels=sub_layer._out_channels,
stride=2 kernel_size=5
) )
return new_conv warp_layer.sub_layer = new_layer
return True
net = paddleclas.MobileNetV1() net = paddleclas.MobileNetV1()
tag = net.replace_sub(layer_name_pattern="conv", replace_function=replace_conv) res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], replace_function=rep_func)
print(tag) print(res)
# True # {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
""" """
if not isinstance(layer_name_pattern, list): if not isinstance(layer_name_pattern, list):
...@@ -108,26 +136,6 @@ class TheseusLayer(nn.Layer): ...@@ -108,26 +136,6 @@ class TheseusLayer(nn.Layer):
return self._find_layers_handle( return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function) layer_name_pattern, handle_func=replace_function)
def _set_identity(self, layer, layer_name, layer_index=None):
    """Overwrite the sub-layers that follow a given one with ``Identity``.

    ``layer._sub_layers`` is walked in iteration order. Once the key equal
    to ``layer_name`` has been seen, every later entry is replaced by a new
    ``Identity()``. If ``layer_index`` is truthy and the name was found, the
    same walk is repeated over ``layer._sub_layers[layer_name]._sub_layers``,
    replacing (via item assignment on that sub-layer) every entry that
    follows the key equal to ``layer_index``.

    Args:
        layer: object exposing a ``_sub_layers`` mapping.
        layer_name: key of the sub-layer after which entries are replaced.
        layer_index: optional key inside the named sub-layer's own
            ``_sub_layers``; ignored when falsy.

    Returns:
        bool: ``True`` if ``layer_name`` was found and, when a truthy
        ``layer_index`` is given, that index was found as well; ``False``
        otherwise.
    """
    # NOTE(review): nesting was reconstructed from a view without
    # indentation; the index pass is assumed to run after the name pass.
    children = layer._sub_layers
    found = False
    for sub_layer_name in children:
        if found:
            children[sub_layer_name] = Identity()
        elif sub_layer_name == layer_name:
            found = True

    # Without an index (or without a name match) the name pass decides.
    if not (layer_index and found):
        return found

    # Second pass: search for the index inside the matched sub-layer.
    found = False
    for sub_layer_index in children[layer_name]._sub_layers:
        if found:
            children[layer_name][sub_layer_index] = Identity()
        elif layer_index == sub_layer_index:
            found = True
    return found
# stop doesn't work when stop layer has a parallel branch. # stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool: def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'. """stop forward and backward after 'stop_layer_name'.
...@@ -153,7 +161,7 @@ class TheseusLayer(nn.Layer): ...@@ -153,7 +161,7 @@ class TheseusLayer(nn.Layer):
sub_layer_index = None sub_layer_index = None
layer = getattr(layer, sub_layer_name, None) layer = getattr(layer, sub_layer_name, None)
if layer is None: if layer is None:
msg = f"Not found layer by name({pattern_list[0]}) in stop_layer_name({stop_layer_name})." msg = f"Not found layer by name({pattern_list[0]}) specifed in stop_layer_name({stop_layer_name})."
logger.warning(msg) logger.warning(msg)
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册