提交 b0ae3a12 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: unify the pattern of layer's name

上级 7da2a997
from abc import ABC from typing import List, Union, Callable, Any
from paddle import nn from paddle import nn
import re from ppcls.utils import logger
class Identity(nn.Layer): class Identity(nn.Layer):
...@@ -19,36 +19,33 @@ class TheseusLayer(nn.Layer): ...@@ -19,36 +19,33 @@ class TheseusLayer(nn.Layer):
self.pruner = None self.pruner = None
self.quanter = None self.quanter = None
# stop doesn't work when stop layer has a parallel branch. # TODO(gaotingquan): weishengyu
def stop_after(self, stop_layer_name: str): def _return_dict_hook(self, layer, input, output):
after_stop = False res_dict = {"output": output}
for layer_i in self._sub_layers: for res_key in list(self.res_dict):
if after_stop: res_dict[res_key] = self.res_dict.pop(res_key)
self._sub_layers[layer_i] = Identity() return res_dict
continue
layer_name = self._sub_layers[layer_i].full_name() def _save_sub_res_hook(self, layer, input, output):
if layer_name == stop_layer_name: self.res_dict[self.res_name] = output
after_stop = True
continue def _find_layers_handle(self, patterns, handle_func):
if isinstance(self._sub_layers[layer_i], TheseusLayer): sub_layers_dict = {}
after_stop = self._sub_layers[layer_i].stop_after( for pattern in patterns:
stop_layer_name) pattern_list = pattern.split(".")
return after_stop
def update_res(self, return_patterns):
for return_pattern in return_patterns:
pattern_list = return_pattern.split(".")
if not pattern_list: if not pattern_list:
continue continue
sub_layer_parent = self sub_layer_parent = self
while len(pattern_list) > 1: while len(pattern_list) > 1:
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] 0]
sub_layer_parent = getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index]
else: else:
sub_layer_parent = getattr(sub_layer_parent, pattern_list[0], sub_layer_parent = getattr(sub_layer_parent,
None) pattern_list[0], None)
if sub_layer_parent is None: if sub_layer_parent is None:
break break
if isinstance(sub_layer_parent, WrapLayer): if isinstance(sub_layer_parent, WrapLayer):
...@@ -59,59 +56,144 @@ class TheseusLayer(nn.Layer): ...@@ -59,59 +56,144 @@ class TheseusLayer(nn.Layer):
if '[' in pattern_list[0]: if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0] sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
sub_layer = getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] sub_layer = getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index]
if not isinstance(sub_layer, TheseusLayer): if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer) sub_layer = wrap_theseus(sub_layer)
getattr(sub_layer_parent, sub_layer_name)[sub_layer_index] = sub_layer getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = sub_layer
else: else:
sub_layer = getattr(sub_layer_parent, pattern_list[0]) sub_layer = getattr(sub_layer_parent, pattern_list[0])
if not isinstance(sub_layer, TheseusLayer): if not isinstance(sub_layer, TheseusLayer):
sub_layer = wrap_theseus(sub_layer) sub_layer = wrap_theseus(sub_layer)
setattr(sub_layer_parent, pattern_list[0], sub_layer) setattr(sub_layer_parent, pattern_list[0], sub_layer)
sub_layer.res_dict = self.res_dict sub_layers_dict[pattern] = sub_layer
sub_layer.res_name = return_pattern handle_res = handle_func(sub_layer, pattern)
sub_layer.register_forward_post_hook(sub_layer._save_sub_res_hook) return sub_layers_dict, handle_res
def replace_sub(self,
layer_name_pattern: Union[str, List[str]],
replace_function: Callable[[nn.Layer, str], Any]) -> bool:
"""use 'replace_function' to modify the 'layer_name_pattern'.
Args:
layer_name_pattern (str): The name of target layer variable.
replace_function (FunctionType): The function to modify target layer,
Returns:
bool: 'True' if successful, 'False' otherwise.
Examples:
import paddleclas
def replace_conv(origin_conv: nn.Conv2D):
new_conv = nn.Conv2D(
in_channels=origin_conv._in_channels,
out_channels=origin_conv._out_channels,
kernel_size=origin_conv._kernel_size,
stride=2
)
return new_conv
net = paddleclas.MobileNetV1()
tag = net.replace_sub(layer_name_pattern="conv", replace_function=replace_conv)
print(tag)
# True
"""
if not isinstance(layer_name_pattern, list):
layer_name_pattern = [layer_name_pattern]
return self._find_layers_handle(
layer_name_pattern, handle_func=replace_function)
def _set_identity(self, layer, layer_name, layer_index=None):
stop_after = False
for sub_layer_name in layer._sub_layers:
if stop_after:
layer._sub_layers[sub_layer_name] = Identity()
continue
if sub_layer_name == layer_name:
stop_after = True
def _save_sub_res_hook(self, layer, input, output): if layer_index and stop_after:
self.res_dict[self.res_name] = output stop_after = False
for sub_layer_index in layer._sub_layers[layer_name]._sub_layers:
if stop_after:
layer._sub_layers[layer_name][sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
def _return_dict_hook(self, layer, input, output): return stop_after
res_dict = {"output": output}
for res_key in list(self.res_dict): # stop doesn't work when stop layer has a parallel branch.
res_dict[res_key] = self.res_dict.pop(res_key) def stop_after(self, stop_layer_name: str) -> bool:
return res_dict """stop forward and backward after 'stop_layer_name'.
Args:
stop_layer_name (str): The name of target layer variable.
def replace_sub(self, layer_name_pattern, replace_function, Returns:
recursive=True): bool: 'True' if successful, 'False' otherwise.
for layer_i in self._sub_layers: """
layer_name = self._sub_layers[layer_i].full_name() pattern_list = stop_layer_name.split(".")
if re.match(layer_name_pattern, layer_name): to_identity_list = []
self._sub_layers[layer_i] = replace_function(self._sub_layers[
layer_i]) layer = self
if recursive: while len(pattern_list) > 0:
if isinstance(self._sub_layers[layer_i], TheseusLayer): layer_parent = layer
self._sub_layers[layer_i].replace_sub( if '[' in pattern_list[0]:
layer_name_pattern, replace_function, recursive) sub_layer_name = pattern_list[0].split('[')[0]
elif isinstance(self._sub_layers[layer_i], sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
(nn.Sequential, nn.LayerList)): layer = getattr(layer, sub_layer_name)[sub_layer_index]
for layer_j in self._sub_layers[layer_i]._sub_layers: else:
self._sub_layers[layer_i]._sub_layers[ sub_layer_name = pattern_list[0]
layer_j].replace_sub(layer_name_pattern, sub_layer_index = None
replace_function, recursive) layer = getattr(layer, sub_layer_name, None)
if layer is None:
''' msg = f"Not found layer by name({pattern_list[0]}) in stop_layer_name({stop_layer_name})."
example of replace function: logger.warning(msg)
def replace_conv(origin_conv: nn.Conv2D): return False
new_conv = nn.Conv2D(
in_channels=origin_conv._in_channels, to_identity_list.append(
out_channels=origin_conv._out_channels, (layer_parent, sub_layer_name, sub_layer_index))
kernel_size=origin_conv._kernel_size, pattern_list = pattern_list[1:]
stride=2
) for to_identity_layer in to_identity_list:
return new_conv if not self._set_identity(*to_identity_layer):
msg = "Failed to set the layers that after stop_layer_name to IdentityLayer."
''' logger.warning(msg)
return False
return True
def update_res(self, return_patterns: Union[str, List[str]]) -> bool:
"""update the results needed returned.
Args:
return_patterns (Union[str, List[str]]): The layer(s)' name to be retruened.
Returns:
bool: 'True' if successful, 'False' otherwise.
"""
class Handler(object):
def __init__(self, res_dict):
self.res_dict = res_dict
def __call__(self, layer, pattern):
layer.res_dict = self.res_dict
layer.res_name = pattern
layer.register_forward_post_hook(layer._save_sub_res_hook)
handle_func = Handler(self.res_dict)
if not isinstance(return_patterns, list):
return_patterns = [return_patterns]
return self._find_layers_handle(
return_patterns, handle_func=handle_func)
class WrapLayer(TheseusLayer): class WrapLayer(TheseusLayer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册