提交 18dec074 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix problems commented in reviewing

上级 41296972
...@@ -21,6 +21,7 @@ class TheseusLayer(nn.Layer): ...@@ -21,6 +21,7 @@ class TheseusLayer(nn.Layer):
def _return_dict_hook(self, layer, input, output): def _return_dict_hook(self, layer, input, output):
res_dict = {"output": output} res_dict = {"output": output}
# 'list' is needed to avoid error raised by popping self.res_dict
for res_key in list(self.res_dict): for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key) res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict return res_dict
...@@ -28,12 +29,44 @@ class TheseusLayer(nn.Layer): ...@@ -28,12 +29,44 @@ class TheseusLayer(nn.Layer):
def _save_sub_res_hook(self, layer, input, output): def _save_sub_res_hook(self, layer, input, output):
self.res_dict[self.res_name] = output self.res_dict[self.res_name] = output
def _find_layers_handle(self, patterns, handle_func): def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
handle_func: Callable[[nn.Layer, str], nn.Layer]) -> Dict[
str, nn.Layer]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'.
Returns:
Dict[str, nn.Layer]: The key is the patter and corresponding value is the result returned by 'handle_func'.
Examples:
from paddle import nn
import paddleclas
def rep_func(sub_layer: nn.Layer, pattern: str):
new_layer = nn.Conv2D(
in_channels=sub_layer._in_channels,
out_channels=sub_layer._out_channels,
kernel_size=5,
padding=2
)
return new_layer
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
handle_res_dict = {} handle_res_dict = {}
for pattern in patterns: for pattern in layer_name_pattern:
pattern_list = pattern.split(".") pattern_list = pattern.split(".")
if not pattern_list:
continue
# find parent layer of sub-layer specified by pattern # find parent layer of sub-layer specified by pattern
sub_layer_parent = self sub_layer_parent = self
...@@ -65,32 +98,30 @@ class TheseusLayer(nn.Layer): ...@@ -65,32 +98,30 @@ class TheseusLayer(nn.Layer):
sub_layer_name = pattern_list[0] sub_layer_name = pattern_list[0]
sub_layer_index = None sub_layer_index = None
sub_layer = getattr(sub_layer_parent, sub_layer_name, False) sub_layer = getattr(sub_layer_parent, sub_layer_name, None)
if sub_layer is False: if not sub_layer:
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})." msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg) logger.warning(msg)
continue continue
try: if sub_layer_index is not None:
sub_layer = sub_layer[ if int(sub_layer_index) < 0 or int(sub_layer_index) >= len(
sub_layer_index] if sub_layer_index is not None else sub_layer sub_layer):
except KeyError as e: msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})." logger.warning(msg)
logger.warning(msg) continue
continue sub_layer = sub_layer[sub_layer_index]
if not isinstance(sub_layer, TheseusLayer): new_sub_layer = handle_func(sub_layer, pattern)
sub_layer = wrap_theseus(sub_layer)
if sub_layer_index: if sub_layer_index:
getattr(sub_layer_parent, getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = sub_layer sub_layer_name)[sub_layer_index] = new_sub_layer
else: else:
setattr(sub_layer_parent, sub_layer_name, sub_layer) setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
handle_res = handle_func(sub_layer, pattern) handle_res_dict[pattern] = new_sub_layer
handle_res_dict[pattern] = handle_res
return handle_res_dict return handle_res_dict
def _set_identity(self, layer, layer_name, layer_index=None): def _set_identity(self, layer, layer_name, layer_index=None):
...@@ -113,45 +144,6 @@ class TheseusLayer(nn.Layer): ...@@ -113,45 +144,6 @@ class TheseusLayer(nn.Layer):
return stop_after return stop_after
def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
replace_function: Callable[[nn.Layer, str], Any]) -> Any:
"""use 'replace_function' to modify the 'layer_name_pattern'.
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'replace_function'.
replace_function (FunctionType): The function to modify target layer specified by 'layer_name_pattern'.
Returns:
bool: 'True' if successful, 'False' otherwise.
Examples:
from paddle import nn
import paddleclas
def rep_func(warp_layer: nn.Layer, pattern: str):
sub_layer = warp_layer.sub_layer
new_layer = nn.Conv2D(
in_channels=sub_layer._in_channels,
out_channels=sub_layer._out_channels,
kernel_size=5
)
warp_layer.sub_layer = new_layer
return True
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], replace_function=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': True, 'blocks[12].depthwise_conv.conv': True}
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function)
# TODO(weishengyu): stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool: def stop_after(self, stop_layer_name: str) -> bool:
"""stop forward and backward after 'stop_layer_name'. """stop forward and backward after 'stop_layer_name'.
...@@ -210,15 +202,11 @@ class TheseusLayer(nn.Layer): ...@@ -210,15 +202,11 @@ class TheseusLayer(nn.Layer):
layer.res_dict = self.res_dict layer.res_dict = self.res_dict
layer.res_name = pattern layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook) layer.register_forward_post_hook(layer._save_sub_res_hook)
return True return layer
handle_func = Handler(self.res_dict) handle_func = Handler(self.res_dict)
if not isinstance(return_patterns, list): return self.replace_sub(return_patterns, handle_func=handle_func)
return_patterns = [return_patterns]
return self._find_layers_handle(
return_patterns, handle_func=handle_func)
class WrapLayer(TheseusLayer): class WrapLayer(TheseusLayer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册