提交 8d0b0d4b 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: extract _parse_pattern_str() func

上级 18dec074
......@@ -66,53 +66,22 @@ class TheseusLayer(nn.Layer):
handle_res_dict = {}
for pattern in layer_name_pattern:
pattern_list = pattern.split(".")
# pattern_list = pattern.split(".")
# find parent layer of sub-layer specified by pattern
sub_layer_parent = self
while len(pattern_list) > 1:
if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[
0]
sub_layer_parent = getattr(sub_layer_parent,
sub_layer_name)[sub_layer_index]
else:
sub_layer_parent = getattr(sub_layer_parent,
pattern_list[0], None)
if sub_layer_parent is None:
break
if isinstance(sub_layer_parent, WrapLayer):
sub_layer_parent = sub_layer_parent.sub_layer
pattern_list = pattern_list[1:]
if sub_layer_parent is None:
msg = f"Not found parent layer of sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
sub_layer_parent, _, _ = parse_pattern_str(
pattern=pattern, idx=(0, -1), sub_layer_parent=self)
if not sub_layer_parent:
continue
# find sub-layer specified by pattern
if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
else:
sub_layer_name = pattern_list[0]
sub_layer_index = None
sub_layer = getattr(sub_layer_parent, sub_layer_name, None)
sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str(
pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent)
if not sub_layer:
msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})."
logger.warning(msg)
continue
if sub_layer_index is not None:
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len(
sub_layer):
msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})."
logger.warning(msg)
continue
sub_layer = sub_layer[sub_layer_index]
new_sub_layer = handle_func(sub_layer, pattern)
if sub_layer_index:
......@@ -156,6 +125,7 @@ class TheseusLayer(nn.Layer):
pattern_list = stop_layer_name.split(".")
to_identity_list = []
# TODO(gaotingquan): replace code by self._parse_pattern_str()
layer = self
while len(pattern_list) > 0:
layer_parent = layer
......@@ -219,5 +189,67 @@ class WrapLayer(TheseusLayer):
def wrap_theseus(sub_layer):
wrapped_layer = WrapLayer(sub_layer)
return wrapped_layer
return WrapLayer(sub_layer)
def unwrap_theseus(sub_layer):
if isinstance(sub_layer, WrapLayer):
sub_layer = sub_layer.sub_layer
return sub_layer
def slice_pattern(pattern, idx):
pattern_list = pattern.split(".")
if idx:
if isinstance(idx, tuple):
if len(idx) == 1:
return pattern_list[idx[0]]
elif len(idx) == 2:
return pattern_list[idx[0]:idx[1]]
else:
msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a tuple."
logger.warning(msg)
return None
elif isinstance(idx, int):
return [pattern_list[idx]]
else:
msg = f"Only support type of 'idx' is int or tuple."
logger.warning(msg)
return None
return pattern_list
def parse_pattern_str(pattern, sub_layer_parent, idx=None):
pattern_list = slice_pattern(pattern, idx)
if not pattern_list:
return None, None, None
while len(pattern_list) > 0:
if '[' in pattern_list[0]:
sub_layer_name = pattern_list[0].split('[')[0]
sub_layer_index = pattern_list[0].split('[')[1].split(']')[0]
else:
sub_layer_name = pattern_list[0]
sub_layer_index = None
sub_layer_parent = getattr(sub_layer_parent, sub_layer_name, None)
sub_layer_parent = unwrap_theseus(sub_layer_parent)
if sub_layer_parent is None:
msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})."
logger.warning(msg)
return None, sub_layer_name, sub_layer_index
if sub_layer_index and sub_layer_parent:
if int(sub_layer_index) < 0 or int(sub_layer_index) >= len(
sub_layer_parent):
msg = f"Not found layer by index({sub_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(sub_layer_parent)}' and > '0'."
logger.warning(msg)
return None, sub_layer_name, sub_layer_index
sub_layer_parent = sub_layer_parent[sub_layer_index]
sub_layer_parent = unwrap_theseus(sub_layer_parent)
pattern_list = pattern_list[1:]
return sub_layer_parent, sub_layer_name, sub_layer_index
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册