提交 b3ab418e 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: move the net's init_res call into super().__init__()

Move the calls to update_res(), stop_after(), and freeze_befor() into the constructor of the parent class. init_net() now also supports calling stop_after() via the Arch config.
上级 7513c0b5
......@@ -28,12 +28,14 @@ class Identity(nn.Layer):
class TheseusLayer(nn.Layer):
def __init__(self, *args, **kwargs):
super(TheseusLayer, self).__init__()
super().__init__()
self.res_dict = {}
self.res_name = self.full_name()
self.pruner = None
self.quanter = None
self.init_net(*args, **kwargs)
def _return_dict_hook(self, layer, input, output):
res_dict = {"logits": output}
# 'list' is needed to avoid error raised by popping self.res_dict
......@@ -43,36 +45,45 @@ class TheseusLayer(nn.Layer):
return res_dict
def init_net(self,
stages_pattern,
stages_pattern=None,
return_patterns=None,
return_stages=None,
stop_grad_pattern=None):
if return_patterns and return_stages:
msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
logger.warning(msg)
return_stages = None
if return_stages is True:
return_patterns = stages_pattern
# return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages]
if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min(
return_stages) < 0:
msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
freeze_befor=None,
stop_after=None):
# init the output of net
if return_patterns or return_stages:
if return_patterns and return_stages:
msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
logger.warning(msg)
return_stages = [
val for val in return_stages
if val >= 0 and val < len(stages_pattern)
]
return_patterns = [stages_pattern[i] for i in return_stages]
if return_patterns:
self.update_res(return_patterns)
if stop_grad_pattern is not None:
self.freeze_befor(stop_grad_pattern)
return_stages = None
if return_stages is True:
return_patterns = stages_pattern
# return_stages is int or bool
if type(return_stages) is int:
return_stages = [return_stages]
if isinstance(return_stages, list):
if max(return_stages) > len(stages_pattern) or min(
return_stages) < 0:
msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
logger.warning(msg)
return_stages = [
val for val in return_stages
if val >= 0 and val < len(stages_pattern)
]
return_patterns = [stages_pattern[i] for i in return_stages]
if return_patterns:
self.update_res(return_patterns)
# freeze subnet
if freeze_befor is not None:
self.freeze_befor(freeze_befor)
# set subnet to Identity
if stop_after is not None:
self.stop_after(stop_after)
def init_res(self,
stages_pattern,
......
......@@ -184,7 +184,6 @@ class SEModule(TheseusLayer):
class PPLCNet(TheseusLayer):
def __init__(self,
stages_pattern,
scale=1.0,
class_num=1000,
dropout_prob=0.2,
......@@ -192,10 +191,8 @@ class PPLCNet(TheseusLayer):
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
stride_list=[2, 2, 2, 2, 2],
use_last_conv=True,
return_patterns=None,
return_stages=None,
**kwargs):
super().__init__()
super().__init__(**kwargs)
self.scale = scale
self.class_expand = class_expand
self.lr_mult_list = lr_mult_list
......@@ -305,11 +302,6 @@ class PPLCNet(TheseusLayer):
make_divisible(self.net_config["blocks6"][-1][2] * scale),
class_num)
super().init_res(
stages_pattern,
return_patterns=return_patterns,
return_stages=return_stages)
def forward(self, x):
x = self.conv1(x)
......
......@@ -255,8 +255,9 @@ class PPLCNetV2(TheseusLayer):
class_num=1000,
dropout_prob=0,
use_last_conv=True,
class_expand=1280):
super().__init__()
class_expand=1280,
**kwargs):
super().__init__(**kwargs)
self.scale = scale
self.use_last_conv = use_last_conv
self.class_expand = class_expand
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册