提交 b0ae3a12 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: unify the pattern of layer's name

上级 7da2a997
from abc import ABC from typing import List, Union, Callable, Any
from paddle import nn from paddle import nn
import re from ppcls.utils import logger
class Identity(nn.Layer): class Identity(nn.Layer):
...@@ -19,36 +19,33 @@ class TheseusLayer(nn.Layer): ...@@ -19,36 +19,33 @@ class TheseusLayer(nn.Layer):
self.pruner = None self.pruner = None
self.quanter = None self.quanter = None
# stop doesn't work when stop layer has a parallel branch. # TODO(gaotingquan): weishengyu
def stop_after(self, stop_layer_name: str): def _return_dict_hook(self, layer, input, output):
after_stop = False res_dict = {"output": output}
for layer_i in self._sub_layers: for res_key in list(self.res_dict):
if after_stop: res_dict[res_key] = self.res_dict.pop(res_key)
self._sub_layers[layer_i] = Identity() return res_dict
continue
layer_name = self._sub_layers[layer_i].full_name() def _save_sub_res_hook(self, layer, input, output):
if layer_name == stop_layer_name: self.res_dict[self.res_name] = output
after_stop = True
continue def _find_layers_handle(self, patterns, handle_func):
if isinstance(self._sub_layers[layer_i], TheseusLayer): sub_layers_dict = {}
after_stop = self._sub_layers[layer_i].stop_after( for pattern in patterns:
stop_layer_name) pattern_list = pattern.split(".")
return after_stop
def update_res(self, return_patterns):
for return_pattern in return_patterns:
pattern_list = return_pattern.split(".")
if not pattern_list: if not pattern_list:
continue continue
sub_layer_parent = self sub_layer_parent = self
while len(pattern_list) > 1: while len(pattern_list) > 1:
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] 0]
sub_layer_parent = getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index]
else: else:
sub_layer_parent = getattr(sub_layer_parent, pattern_list[0], sub_layer_parent = getattr(sub_layer_parent,
None) pattern_list[0], None)
if sub_layer_parent is None: if sub_layer_parent is None:
break break
if isinstance(sub_layer_parent, WrapLayer): if isinstance(sub_layer_parent, WrapLayer):
...@@ -59,49 +56,38 @@ class TheseusLayer(nn.Layer): ...@@ -59,49 +56,38 @@ class TheseusLayer(nn.Layer):
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
sub_layer = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] sub_layer = getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index]
if not isinstance(sub_layer, TheseusLayer): if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer) sub_layer = wrap_theseus(sub_layer)
getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] = sub_layer getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = sub_layer
else: else:
sub_layer = getattr(sub_layer_parent, pattern_list[0]) sub_layer = getattr(sub_layer_parent, pattern_list[0])
if not isinstance(sub_layer, TheseusLayer): if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer) sub_layer = wrap_theseus(sub_layer)
setattr(sub_layer_parent, pattern_list[0], sub_layer) setattr(sub_layer_parent, pattern_list[0], sub_layer)
sub_layer.res_dict = self.res_dict sub_layers_dict[pattern] = sub_layer
sub_layer.res_name = return_pattern handle_res = handle_func(sub_layer, pattern)
sub_layer.register_forward_post_hook(sub_layer._save_sub_res_hook) return sub_layers_dict, handle_res
def _save_sub_res_hook(self, layer, input, output): def replace_sub(self,
self.res_dict[self.res_name] = output layer_name_pattern: Union[str, List[str]],
replace_function: Callable[[nn.Layer, str], Any]) -> bool:
"""use 'replace_function' to modify the 'layer_name_pattern'.
def _return_dict_hook(self, layer, input, output): Args:
res_dict = {"output": output} layer_name_pattern (str): The name of target layer variable.
for res_key in list(self.res_dict): replace_function (FunctionType): The function to modify target layer,
res_dict[res_key] = self.res_dict.pop(res_key)
return res_dict Returns:
bool: 'True' if successful, 'False' otherwise.
Examples:
import paddleclas
def replace_sub(self, layer_name_pattern, replace_function,
recursive=True):
for layer_i in self._sub_layers:
layer_name = self._sub_layers[layer_i].full_name()
if re.match(layer_name_pattern, layer_name):
self._sub_layers[layer_i] = replace_function(self._sub_layers[
layer_i])
if recursive:
if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i].replace_sub(
layer_name_pattern, replace_function, recursive)
elif isinstance(self._sub_layers[layer_i],
(nn.Sequential, nn.LayerList)):
for layer_j in self._sub_layers[layer_i]._sub_layers:
self._sub_layers[layer_i]._sub_layers[
layer_j].replace_sub(layer_name_pattern,
replace_function, recursive)
'''
example of replace function:
def replace_conv(origin_conv: nn.Conv2D): def replace_conv(origin_conv: nn.Conv2D):
new_conv = nn.Conv2D( new_conv = nn.Conv2D(
in_channels=origin_conv._in_channels, in_channels=origin_conv._in_channels,
...@@ -111,7 +97,103 @@ class TheseusLayer(nn.Layer): ...@@ -111,7 +97,103 @@ class TheseusLayer(nn.Layer):
) )
return new_conv return new_conv
''' net = paddleclas.MobileNetV1()
tag = net.replace_sub(layer_name_pattern="conv", replace_function=replace_conv)
print(tag)
# True
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function)
def _set_identity(self, layer, layer_name, layer_index=None):
    """Replace the sub-layers that come after a target with ``Identity``.

    Walks ``layer._sub_layers`` in iteration order. Every entry that
    follows the one keyed ``layer_name`` is overwritten with a new
    ``Identity()``. When ``layer_index`` is truthy and ``layer_name`` was
    found, the same is then done inside ``layer._sub_layers[layer_name]``
    for every entry that follows the key ``layer_index``.

    Args:
        layer: Object exposing a ``_sub_layers`` mapping.
        layer_name: Key in ``layer._sub_layers`` of the target sub-layer.
        layer_index: Optional key inside the target sub-layer's own
            ``_sub_layers``; compared with ``==`` against those keys.

    Returns:
        bool: Whether the target was located. With a truthy
        ``layer_index`` this reflects whether the index was matched
        inside the target sub-layer.
    """
    found = False
    for name in layer._sub_layers:
        if found:
            # Past the target: neutralise this sibling.
            layer._sub_layers[name] = Identity()
        elif name == layer_name:
            found = True

    if not (layer_index and found):
        return found

    # The target is a container; repeat the procedure among its children.
    container = layer._sub_layers[layer_name]
    found = False
    for index in container._sub_layers:
        if found:
            container[index] = Identity()
        elif layer_index == index:
            found = True
    return found
# stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str) -> bool:
    """Stop forward and backward after 'stop_layer_name'.

    The dotted name is resolved one component at a time, starting from
    ``self``. A component may carry an index in brackets (``name[idx]``),
    in which case attribute ``name`` is looked up and then indexed with
    the string ``idx``. For each resolved component, ``_set_identity`` is
    called on its parent to replace the layers that follow it.

    Args:
        stop_layer_name (str): The name of target layer variable.

    Returns:
        bool: 'True' if successful, 'False' otherwise. A component that
        cannot be resolved (missing attribute or missing index) logs a
        warning and yields 'False' instead of raising.
    """
    to_identity_list = []
    layer = self

    for pattern in stop_layer_name.split("."):
        layer_parent = layer
        if '[' in pattern:
            sub_layer_name = pattern.split('[')[0]
            sub_layer_index = pattern.split('[')[1].split(']')[0]
            # Look up with a default, like the non-indexed branch, so a
            # missing container is reported instead of raising.
            container = getattr(layer, sub_layer_name, None)
            try:
                layer = (None if container is None else
                         container[sub_layer_index])
            except (KeyError, IndexError):
                layer = None
        else:
            sub_layer_name = pattern
            sub_layer_index = None
            layer = getattr(layer, sub_layer_name, None)

        if layer is None:
            msg = f"Not found layer by name({pattern}) in stop_layer_name({stop_layer_name})."
            logger.warning(msg)
            return False

        to_identity_list.append(
            (layer_parent, sub_layer_name, sub_layer_index))

    # Only start mutating once the whole path has been resolved.
    for to_identity_layer in to_identity_list:
        if not self._set_identity(*to_identity_layer):
            msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
            logger.warning(msg)
            return False

    return True
def update_res(self, return_patterns: Union[str, List[str]]) -> bool:
    """Register hooks that save the outputs of the named sub-layers.

    For every layer handed back by ``_find_layers_handle``, the handler
    points the layer's ``res_dict`` at this object's ``res_dict``, stores
    the pattern as ``res_name``, and registers the layer's
    ``_save_sub_res_hook`` as a forward post-hook.

    Args:
        return_patterns (Union[str, List[str]]): The name(s) of the
            layer(s) whose outputs should be returned. A value that is
            not a list is wrapped in a one-element list.

    Returns:
        Whatever ``self._find_layers_handle`` returns for these patterns.
        NOTE(review): annotated as ``bool``; confirm against that
        helper's actual return value.
    """
    shared_res = self.res_dict

    def attach_save_hook(layer, pattern):
        # Share one result dict so all hooked layers write to one place.
        layer.res_dict = shared_res
        layer.res_name = pattern
        layer.register_forward_post_hook(layer._save_sub_res_hook)

    if isinstance(return_patterns, list):
        patterns = return_patterns
    else:
        patterns = [return_patterns]

    return self._find_layers_handle(patterns, handle_func=attach_save_hook)
class WrapLayer(TheseusLayer): class WrapLayer(TheseusLayer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册