提交 cf205e13 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix result returned by _find_layers_handle

上级 b0ae3a12
......@@ -30,11 +30,13 @@ class TheseusLayer(nn.Layer):
self.res_dict[self.res_name] = output
def _find_layers_handle(self, patterns, handle_func):
sub_layers_dict = {}
handle_res_dict = {}
for pattern in patterns:
pattern_list = pattern.split(".")
if not pattern_list:
continue
# find parent layer of sub-layer specified by pattern
sub_layer_parent = self
while len(pattern_list) > 1:
if '[' in pattern_list[0]:
......@@ -52,7 +54,11 @@ class TheseusLayer(nn.Layer):
sub_layer_parent = sub_layer_parent.sub_layer
pattern_list = pattern_list[1:]
if sub_layer_parent is None:
msg = f"Not found layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue
# find sub-layer specified by pattern
if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
......@@ -68,13 +74,33 @@ class TheseusLayer(nn.Layer):
sub_layer = wrap_theseus(sub_layer)
setattr(sub_layer_parent, pattern_list[0], sub_layer)
sub_layers_dict[pattern] = sub_layer
handle_res = handle_func(sub_layer, pattern)
return sub_layers_dict, handle_res
handle_res_dict[pattern] = handle_res
return handle_res_dict
def _set_identity(self, layer, layer_name, layer_index=None):
stop_after = False
for sub_layer_name in layer._sub_layers:
if stop_after:
layer._sub_layers[sub_layer_name] = Identity()
continue
if sub_layer_name == layer_name:
stop_after = True
if layer_index and stop_after:
stop_after = False
for sub_layer_index in layer._sub_layers[layer_name]._sub_layers:
if stop_after:
layer._sub_layers[layer_name][sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
return stop_after
def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
replace_function: Callable[[nn.Layer, str], Any]) -> bool:
replace_function: Callable[[nn.Layer, str], Any]) -> Any:
"""use 'replace_function' to modify the 'layer_name_pattern'.
Args:
......@@ -86,21 +112,23 @@ class TheseusLayer(nn.Layer):
Examples:
from paddle import nn
import paddleclas
def replace_conv(origin_conv: nn.Conv2D):
new_conv = nn.Conv2D(
in_channels=origin_conv._in_channels,
out_channels=origin_conv._out_channels,
kernel_size=origin_conv._kernel_size,
stride=2
def rep_func(warp_layer: nn.Layer, pattern: str):
sub_layer = warp_layer.sub_layer
new_layer = nn.Conv2D(
in_channels=sub_layer._in_channels,
out_channels=sub_layer._out_channels,
kernel_size=5
)
return new_conv
warp_layer.sub_layer = new_layer
return True
net = paddleclas.MobileNetV1()
tag = net.replace_sub(layer_name_pattern="conv", replace_function=replace_conv)
print(tag)
# True
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], replace_function=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
"""
if not isinstance(layer_name_pattern, list):
......@@ -108,26 +136,6 @@ class TheseusLayer(nn.Layer):
return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function)
def _set_identity(self, layer, layer_name, layer_index=None):
stop_after = False
for sub_layer_name in layer._sub_layers:
if stop_after:
layer._sub_layers[sub_layer_name] = Identity()
continue
if sub_layer_name == layer_name:
stop_after = True
if layer_index and stop_after:
stop_after = False
for sub_layer_index in layer._sub_layers[layer_name]._sub_layers:
if stop_after:
layer._sub_layers[layer_name][sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
return stop_after
# stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'.
......@@ -153,7 +161,7 @@ class TheseusLayer(nn.Layer):
sub_layer_index = None
layer = getattr(layer, sub_layer_name, None)
if layer is None:
msg = f"Not found layer by name({pattern_list[0]}) in stop_layer_name({stop_layer_name})."
msg = f"Not found layer by name({pattern_list[0]}) specifed in stop_layer_name({stop_layer_name})."
logger.warning(msg)
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册