提交 a75dc8c9 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix the error that containers nesting cannot be handled.

the error would be raised when when the pattern string represents nested, e.g., containing "[3][1]".
上级 a1baf3f4
......@@ -103,7 +103,7 @@ class TheseusLayer(nn.Layer):
return new_layer
net = paddleclas.MobileNetV1()
res = net.replace_sub(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
print(res)
# {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
"""
......@@ -117,18 +117,26 @@ class TheseusLayer(nn.Layer):
layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
if not layer_list:
continue
sub_layer_parent = layer_list[-2]["layer"] if len(
layer_list) > 1 else self
sub_layer = layer_list[-1]["layer"]
sub_layer_name = layer_list[-1]["name"]
sub_layer_index = layer_list[-1]["index"]
sub_layer_index_list = layer_list[-1]["index_list"]
new_sub_layer = handle_func(sub_layer, pattern)
if sub_layer_index:
getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index] = new_sub_layer
if sub_layer_index_list:
if len(sub_layer_index_list) > 1:
sub_layer_parent = getattr(
sub_layer_parent,
sub_layer_name)[sub_layer_index_list[0]]
for sub_layer_index in sub_layer_index_list[1:-1]:
sub_layer_parent = sub_layer_parent[sub_layer_index]
sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
else:
getattr(sub_layer_parent, sub_layer_name)[
sub_layer_index_list[0]] = new_sub_layer
else:
setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
......@@ -151,8 +159,8 @@ class TheseusLayer(nn.Layer):
parent_layer = self
for layer_dict in layer_list:
name, index = layer_dict["name"], layer_dict["index"]
if not set_identity(parent_layer, name, index):
name, index_list = layer_dict["name"], layer_dict["index_list"]
if not set_identity(parent_layer, name, index_list):
msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
logger.warning(msg)
return False
......@@ -208,13 +216,13 @@ def save_sub_res_hook(layer, input, output):
def set_identity(parent_layer: nn.Layer,
layer_name: str,
layer_index: str=None) -> bool:
"""set the layer specified by layer_name and layer_index to Indentity.
layer_index_list: str=None) -> bool:
"""set the layer specified by layer_name and layer_index_list to Indentity.
Args:
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index.
parent_layer (nn.Layer): The parent layer of target layer specified by layer_name and layer_index_list.
layer_name (str): The name of target layer to be set to Indentity.
layer_index (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
layer_index_list (str, optional): The index of target layer to be set to Indentity in parent_layer. Defaults to None.
Returns:
bool: True if successfully, False otherwise.
......@@ -228,16 +236,19 @@ def set_identity(parent_layer: nn.Layer,
if sub_layer_name == layer_name:
stop_after = True
if layer_index and stop_after:
stop_after = False
for sub_layer_index in parent_layer._sub_layers[
layer_name]._sub_layers:
if stop_after:
parent_layer._sub_layers[layer_name][
sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
if layer_index_list and stop_after:
layer_container = parent_layer._sub_layers[layer_name]
for num, layer_index in enumerate(layer_index_list):
stop_after = False
for i in range(num):
layer_container = layer_container[layer_index_list[i]]
for sub_layer_index in layer_container._sub_layers:
if stop_after:
parent_layer._sub_layers[layer_name][
sub_layer_index] = Identity()
continue
if layer_index == sub_layer_index:
stop_after = True
return stop_after
......@@ -269,10 +280,12 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
while len(pattern_list) > 0:
if '[' in pattern_list[0]:
target_layer_name = pattern_list[0].split('[')[0]
target_layer_index = pattern_list[0].split('[')[1].split(']')[0]
target_layer_index_list = list(
index.split(']')[0]
for index in pattern_list[0].split('[')[1:])
else:
target_layer_name = pattern_list[0]
target_layer_index = None
target_layer_index_list = None
target_layer = getattr(parent_layer, target_layer_name, None)
......@@ -281,21 +294,22 @@ def parse_pattern_str(pattern: str, parent_layer: nn.Layer) -> Union[
logger.warning(msg)
return None
if target_layer_index and target_layer:
if int(target_layer_index) < 0 or int(target_layer_index) >= len(
target_layer):
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
logger.warning(msg)
return None
target_layer = target_layer[target_layer_index]
if target_layer_index_list:
for target_layer_index in target_layer_index_list:
if int(target_layer_index) < 0 or int(
target_layer_index) >= len(target_layer):
msg = f"Not found layer by index('{target_layer_index}') specifed in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
logger.warning(msg)
return None
target_layer = target_layer[target_layer_index]
layer_list.append({
"layer": target_layer,
"name": target_layer_name,
"index": target_layer_index
"index_list": target_layer_index_list
})
pattern_list = pattern_list[1:]
parent_layer = target_layer
return layer_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册