Commit b3ab418e authored by gaotingquan, committed by Tingquan Gao

refactor: mv the init_res of net to super().__init__()

Move the calls to update_res(), stop_after() and freeze_befor() into the constructor of the parent class, and let init_net() support calling stop_after() via the Arch config.
Parent 7513c0b5
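From a caller's perspective, the network-surgery options that each backbone used to wire up through an explicit init_res() call are now ordinary keyword arguments, consumed by TheseusLayer.__init__() via init_net(). Below is a minimal sketch of the post-refactor usage; the import path and the "blocks3"…"blocks6" stage names are assumptions based on the existing PaddleClas layout, not shown in this diff.

```python
# Sketch only: PPLCNet's import path and the stage names below are
# assumptions based on the existing PaddleClas layout.
from ppcls.arch.backbone import PPLCNet

model = PPLCNet(
    scale=1.0,
    # forwarded through **kwargs -> super().__init__() -> init_net()
    stages_pattern=["blocks3", "blocks4", "blocks5", "blocks6"],
    return_stages=[2, 3],   # expose the outputs of blocks5 and blocks6
    stop_after="blocks6")   # replace everything after blocks6 with Identity
```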
```diff
@@ -28,12 +28,14 @@ class Identity(nn.Layer):
 class TheseusLayer(nn.Layer):
     def __init__(self, *args, **kwargs):
-        super(TheseusLayer, self).__init__()
+        super().__init__()
         self.res_dict = {}
         self.res_name = self.full_name()
         self.pruner = None
         self.quanter = None
+        self.init_net(*args, **kwargs)

     def _return_dict_hook(self, layer, input, output):
         res_dict = {"logits": output}
         # 'list' is needed to avoid error raised by popping self.res_dict
@@ -43,36 +45,45 @@ class TheseusLayer(nn.Layer):
         return res_dict

     def init_net(self,
-                 stages_pattern,
+                 stages_pattern=None,
                  return_patterns=None,
                  return_stages=None,
-                 stop_grad_pattern=None):
-        if return_patterns and return_stages:
-            msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
-            logger.warning(msg)
-            return_stages = None
-
-        if return_stages is True:
-            return_patterns = stages_pattern
-        # return_stages is int or bool
-        if type(return_stages) is int:
-            return_stages = [return_stages]
-        if isinstance(return_stages, list):
-            if max(return_stages) > len(stages_pattern) or min(
-                    return_stages) < 0:
-                msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
-                logger.warning(msg)
-                return_stages = [
-                    val for val in return_stages
-                    if val >= 0 and val < len(stages_pattern)
-                ]
-            return_patterns = [stages_pattern[i] for i in return_stages]
-
-        if return_patterns:
-            self.update_res(return_patterns)
-
-        if stop_grad_pattern is not None:
-            self.freeze_befor(stop_grad_pattern)
+                 freeze_befor=None,
+                 stop_after=None):
+        # init the output of net
+        if return_patterns or return_stages:
+            if return_patterns and return_stages:
+                msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
+                logger.warning(msg)
+                return_stages = None
+
+            if return_stages is True:
+                return_patterns = stages_pattern
+
+            # return_stages is int or bool
+            if type(return_stages) is int:
+                return_stages = [return_stages]
+            if isinstance(return_stages, list):
+                if max(return_stages) > len(stages_pattern) or min(
+                        return_stages) < 0:
+                    msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
+                    logger.warning(msg)
+                    return_stages = [
+                        val for val in return_stages
+                        if val >= 0 and val < len(stages_pattern)
+                    ]
+                return_patterns = [stages_pattern[i] for i in return_stages]
+
+            if return_patterns:
+                self.update_res(return_patterns)
+
+        # freeze subnet
+        if freeze_befor is not None:
+            self.freeze_befor(freeze_befor)
+
+        # set subnet to Identity
+        if stop_after is not None:
+            self.stop_after(stop_after)

     def init_res(self,
                  stages_pattern,
...
```
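The index handling in the new init_net() is unchanged, just gated behind the `return_patterns or return_stages` check: `True` selects every stage, an int becomes a one-element list, and out-of-range indices are dropped with a warning before being mapped onto stages_pattern. A standalone sketch of that normalization with a toy pattern list (mirrors the diff above, not imported from ppcls):

```python
# Mirrors the normalization in init_net() above, minus the logger warnings.
def resolve_return_patterns(stages_pattern, return_stages):
    if return_stages is True:
        return stages_pattern          # bool True means "all stages"
    if type(return_stages) is int:     # a bare index becomes a list
        return_stages = [return_stages]
    # drop illegal indices, as the warning branch in init_net() does
    return_stages = [i for i in return_stages if 0 <= i < len(stages_pattern)]
    return [stages_pattern[i] for i in return_stages]


print(resolve_return_patterns(["blocks3", "blocks4", "blocks5"], 1))
# ['blocks4']
print(resolve_return_patterns(["blocks3", "blocks4", "blocks5"], [0, 7]))
# ['blocks3']
```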
...@@ -184,7 +184,6 @@ class SEModule(TheseusLayer): ...@@ -184,7 +184,6 @@ class SEModule(TheseusLayer):
class PPLCNet(TheseusLayer): class PPLCNet(TheseusLayer):
def __init__(self, def __init__(self,
stages_pattern,
scale=1.0, scale=1.0,
class_num=1000, class_num=1000,
dropout_prob=0.2, dropout_prob=0.2,
...@@ -192,10 +191,8 @@ class PPLCNet(TheseusLayer): ...@@ -192,10 +191,8 @@ class PPLCNet(TheseusLayer):
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
stride_list=[2, 2, 2, 2, 2], stride_list=[2, 2, 2, 2, 2],
use_last_conv=True, use_last_conv=True,
return_patterns=None,
return_stages=None,
**kwargs): **kwargs):
super().__init__() super().__init__(**kwargs)
self.scale = scale self.scale = scale
self.class_expand = class_expand self.class_expand = class_expand
self.lr_mult_list = lr_mult_list self.lr_mult_list = lr_mult_list
...@@ -305,11 +302,6 @@ class PPLCNet(TheseusLayer): ...@@ -305,11 +302,6 @@ class PPLCNet(TheseusLayer):
make_divisible(self.net_config["blocks6"][-1][2] * scale), make_divisible(self.net_config["blocks6"][-1][2] * scale),
class_num) class_num)
super().init_res(
stages_pattern,
return_patterns=return_patterns,
return_stages=return_stages)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
......
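With stages_pattern removed from PPLCNet's signature, the per-scale factory functions presumably pass it through **kwargs instead, so it reaches init_net() via super().__init__(**kwargs). A hedged sketch of what such a factory would look like; MODEL_STAGES_PATTERN, MODEL_URLS and _load_pretrained are assumed from the existing PPLCNet module and are not shown in this diff:

```python
# Assumed shape of a factory after the refactor: stages_pattern now travels
# through **kwargs rather than as a positional constructor argument.
# MODEL_STAGES_PATTERN, MODEL_URLS and _load_pretrained are module-level
# names assumed from the existing pp_lcnet code, not defined here.
def PPLCNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
    model = PPLCNet(
        scale=1.0,
        stages_pattern=MODEL_STAGES_PATTERN["PPLCNet"],
        **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["PPLCNet_x1_0"], use_ssld)
    return model
```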
```diff
@@ -255,8 +255,9 @@ class PPLCNetV2(TheseusLayer):
                  class_num=1000,
                  dropout_prob=0,
                  use_last_conv=True,
-                 class_expand=1280):
-        super().__init__()
+                 class_expand=1280,
+                 **kwargs):
+        super().__init__(**kwargs)
         self.scale = scale
         self.use_last_conv = use_last_conv
         self.class_expand = class_expand
...
```
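The "config of Arch" in the commit message refers to the training configuration: build_model() instantiates the backbone with the remaining keys of the Arch section as keyword arguments, so after this change stop_after (and likewise freeze_befor and return_stages) can be driven from configuration alone. A hypothetical example, written as the kwargs dict the constructor would receive; the model name and layer name are placeholders, not taken from a real config:

```python
# Hypothetical Arch section, expressed as the kwargs dict that would be
# forwarded to the backbone constructor; "blocks12" is a placeholder
# sublayer name, not verified against a real PPLCNetV2 config.
arch_config = {
    "name": "PPLCNetV2_base",
    "class_num": 1000,
    "stop_after": "blocks12",  # truncates the net at this sublayer
}
```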