layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
replace_function,recursive)
handle_func (Callable[[nn.Layer, str], nn.Layer]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Layer) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
'''
Returns:
example of replace function:
Dict[str, nn.Layer]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.
def replace_conv(origin_conv: nn.Conv2D):
new_conv = nn.Conv2D(
Examples:
in_channels=origin_conv._in_channels,
out_channels=origin_conv._out_channels,
from paddle import nn
kernel_size=origin_conv._kernel_size,
import paddleclas
stride=2
def rep_func(layer: nn.Layer, pattern: str):
new_layer = nn.Conv2D(
in_channels=layer._in_channels,
out_channels=layer._out_channels,
kernel_size=5,
padding=2
)
)
return new_conv
return new_layer
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
"""
ifnotisinstance(layer_name_pattern,list):
layer_name_pattern=[layer_name_pattern]
handle_res_dict={}
forpatterninlayer_name_pattern:
# parse pattern to find target layer and its parent