# 'list' is needed to avoid error raised by popping self.res_dict
forres_keyinlist(self.res_dict):
# clear the res_dict because the forward process may change according to input
res_dict[res_key]=self.res_dict.pop(res_key)
returnres_dict
def_save_sub_res_hook(self,layer,input,output):
self.res_dict[self.res_name]=output
defreplace_sub(self,*args,**kwargs)->None:
msg="The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
logger.error(DeprecationWarning(msg))
raiseDeprecationWarning(msg)
defupgrade_sublayer(self,
layer_name_pattern:Union[str,List[str]],
handle_func:Callable[[nn.Layer,str],nn.Layer]
)->Dict[str,nn.Layer]:
"""use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
Args:
layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
Returns:
Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.
Examples:
from paddle import nn
import paddleclas
def rep_func(layer: nn.Layer, pattern: str):
new_layer = nn.Conv2D(
in_channels=layer._in_channels,
out_channels=layer._out_channels,
kernel_size=5,
padding=2
)
return new_layer
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
"""
ifnotisinstance(layer_name_pattern,list):
layer_name_pattern=[layer_name_pattern]
handle_res_dict={}
forpatterninlayer_name_pattern:
# parse pattern to find target layer and its parent