提交 18dec074 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix problems commented in reviewing

上级 41296972
......@@ -21,6 +21,7 @@ class TheseusLayer(nn.Layer):
def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output}
# 'list' is needed to avoid error raised by popping self.res_dict
for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict
......@@ -28,12 +29,44 @@ class TheseusLayer(nn.Layer):
def _save_sub_res_hook(self, layer, input, output):
self.res_dict[self.res_name] = output
def _find_layers_handle(self, patterns, handle_func):
def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
str, nn.Layer]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'.
Returns:
Dict[str, nn.Layer]: The key is the patter and corresponding value is the result returned by 'handle_func'.
Examples:
from paddle import nn
import paddleclas
def rep_func(sub_layer: nn.Layer, pattern: str):
new_layer = nn.Conv2D(
in_channels=sub_layer._in_channels,
out_channels=sub_layer._out_channels,
kernel_size=5,
padding=2
)
return new_layer
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
handle_res_dict = {}
for pattern in patterns:
for pattern in layer_name_pattern:
pattern_list = pattern.split(".")
if not pattern_list:
continue
# find parent layer of sub-layer specified by pattern
sub_layer_parent = self
......@@ -65,32 +98,30 @@ class TheseusLayer(nn.Layer):
sub_layer_name = pattern_list[0]
sub_layer_index = None
sub_layer = getattr(sub_layer_parent, sub_layer_name, False)
sub_layer = getattr(sub_layer_parent, sub_layer_name, None)
if sub_layer is False:
if not sub_layer:
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue
try:
sub_layer = sub_layer[
sub_layer_index] if sub_layer_index is not None else sub_layer
except KeyError as e:
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
logger.warning(msg)
continue
if sub_layer_index is not None:
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len(
sub_layer):
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
logger.warning(msg)
continue
sub_layer = sub_layer[sub_layer_index]
if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer)
new_sub_layer = handle_func(sub_layer, pattern)
if sub_layer_index:
getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = sub_layer
sub_layer_name)[sub_layer_index] = new_sub_layer
else:
setattr(sub_layer_parent, sub_layer_name, sub_layer)
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
handle_res = handle_func(sub_layer, pattern)
handle_res_dict[pattern] = handle_res
handle_res_dict[pattern] = new_sub_layer
return handle_res_dict
def _set_identity(self, layer, layer_name, layer_index=None):
......@@ -113,45 +144,6 @@ class TheseusLayer(nn.Layer):
return stop_after
def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
replace_function: Callable[[nn.Layer, str], Any]) -> Any:
"""use 'replace_function' to modify the 'layer_name_pattern'.
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'replace_function'.
replace_function (FunctionType): The function to modify target layer specified by 'layer_name_pattern'.
Returns:
bool: 'True' if successful, 'False' otherwise.
Examples:
from paddle import nn
import paddleclas
def rep_func(warp_layer: nn.Layer, pattern: str):
sub_layer = warp_layer.sub_layer
new_layer = nn.Conv2D(
in_channels=sub_layer._in_channels,
out_channels=sub_layer._out_channels,
kernel_size=5
)
warp_layer.sub_layer = new_layer
return True
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], replace_function=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function)
# TODO(weishengyu): stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'.
......@@ -210,15 +202,11 @@ class TheseusLayer(nn.Layer):
layer.res_dict = self.res_dict
layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook)
return True
return layer
handle_func = Handler(self.res_dict)
if not isinstance(return_patterns, list):
return_patterns = [return_patterns]
return self._find_layers_handle(
return_patterns, handle_func=handle_func)
return self.replace_sub(return_patterns, handle_func=handle_func)
class WrapLayer(TheseusLayer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册