diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 4b1914816f331f797aeb563ab7488abe536145b2..12a7a2729dfc682a5a58d96207b7e8a0f3ab4002 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -66,53 +66,22 @@ class TheseusLayer(nn.Layer): handle_res_dict = {} for pattern in layer_name_pattern: - pattern_list = pattern.split(".") + # pattern_list = pattern.split(".") # find parent layer of sub-layer specified by pattern - sub_layer_parent = self - while len(pattern_list) > 1: - if '[' in pattern_list[0]: - sub_layer_name = pattern_list[0].split('[')[0] - sub_layer_index = pattern_list[0].split('[')[1].split(']')[ - 0] - sub_layer_parent = getattr(sub_layer_parent, - sub_layer_name)[sub_layer_index] - else: - sub_layer_parent = getattr(sub_layer_parent, - pattern_list[0], None) - if sub_layer_parent is None: - break - if isinstance(sub_layer_parent, WrapLayer): - sub_layer_parent = sub_layer_parent.sub_layer - pattern_list = pattern_list[1:] - if sub_layer_parent is None: - msg = f"Not found parent layer of sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})." - logger.warning(msg) + sub_layer_parent, _, _ = parse_pattern_str( + pattern=pattern, idx=(0, -1), sub_layer_parent=self) + + if not sub_layer_parent: continue # find sub-layer specified by pattern - if '[' in pattern_list[0]: - sub_layer_name = pattern_list[0].split('[')[0] - sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] - else: - sub_layer_name = pattern_list[0] - sub_layer_index = None - - sub_layer = getattr(sub_layer_parent, sub_layer_name, None) + sub_layer, sub_layer_name, sub_layer_index = parse_pattern_str( + pattern=pattern, idx=-1, sub_layer_parent=sub_layer_parent) if not sub_layer: - msg = f"Not found sub-layer by name({pattern_list[0]}) specifed in pattern({pattern})." - logger.warning(msg) continue - if sub_layer_index is not None: - if int(sub_layer_index) < 0 or int(sub_layer_index) >= len( - sub_layer): - msg = f"Not found sub-layer by index({sub_layer_index}) specifed in pattern({pattern})." - logger.warning(msg) - continue - sub_layer = sub_layer[sub_layer_index] - new_sub_layer = handle_func(sub_layer, pattern) if sub_layer_index: @@ -156,6 +125,7 @@ class TheseusLayer(nn.Layer): pattern_list = stop_layer_name.split(".") to_identity_list = [] + # TODO(gaotingquan): replace code by self._parse_pattern_str() layer = self while len(pattern_list) > 0: layer_parent = layer @@ -219,5 +189,67 @@ class WrapLayer(TheseusLayer): def wrap_theseus(sub_layer): - wrapped_layer = WrapLayer(sub_layer) - return wrapped_layer + return WrapLayer(sub_layer) + + +def unwrap_theseus(sub_layer): + if isinstance(sub_layer, WrapLayer): + sub_layer = sub_layer.sub_layer + return sub_layer + + +def slice_pattern(pattern, idx): + pattern_list = pattern.split(".") + if idx: + if isinstance(idx, tuple): + if len(idx) == 1: + return pattern_list[idx[0]] + elif len(idx) == 2: + return pattern_list[idx[0]:idx[1]] + else: + msg = f"Only support length of 'idx' is 1 or 2 when 'idx' is a tuple." + logger.warning(msg) + return None + elif isinstance(idx, int): + return [pattern_list[idx]] + else: + msg = f"Only support type of 'idx' is int or tuple." + logger.warning(msg) + return None + + return pattern_list + + +def parse_pattern_str(pattern, sub_layer_parent, idx=None): + pattern_list = slice_pattern(pattern, idx) + if not pattern_list: + return None, None, None + + while len(pattern_list) > 0: + if '[' in pattern_list[0]: + sub_layer_name = pattern_list[0].split('[')[0] + sub_layer_index = pattern_list[0].split('[')[1].split(']')[0] + else: + sub_layer_name = pattern_list[0] + sub_layer_index = None + + sub_layer_parent = getattr(sub_layer_parent, sub_layer_name, None) + sub_layer_parent = unwrap_theseus(sub_layer_parent) + + if sub_layer_parent is None: + msg = f"Not found layer named({sub_layer_name}) specifed in pattern({pattern})." + logger.warning(msg) + return None, sub_layer_name, sub_layer_index + + if sub_layer_index and sub_layer_parent: + if int(sub_layer_index) < 0 or int(sub_layer_index) >= len( + sub_layer_parent): + msg = f"Not found layer by index({sub_layer_index}) specifed in pattern({pattern}). The lenght of sub_layer's parent layer is < '{len(sub_layer_parent)}' and > '0'." + logger.warning(msg) + return None, sub_layer_name, sub_layer_index + sub_layer_parent = sub_layer_parent[sub_layer_index] + sub_layer_parent = unwrap_theseus(sub_layer_parent) + + pattern_list = pattern_list[1:] + + return sub_layer_parent, sub_layer_name, sub_layer_index