Commit 7513c0b5 authored by gaotingquan, committed by Tingquan Gao

feat: support freeze subnet

freeze the subnet up to and including the layer matched by the specified layer name, by setting stop_gradient=True on that layer's output tensor
Parent d670f2e4
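For context, a minimal usage sketch of the new freezing hook. This is a hedged illustration only: the backbone name, the import path, and the "blocks[2]" pattern are assumptions for the example, not part of this commit.

from ppcls.arch.backbone import MobileNetV3_small_x0_35  # assumed import path

model = MobileNetV3_small_x0_35()
# Freeze the sub-layer matched by the pattern together with everything
# before it: its output tensor gets stop_gradient=True, so no gradients
# flow back into the frozen prefix during fine-tuning.
ok = model.freeze_befor("blocks[2]")  # "blocks[2]" is an assumed pattern
assert ok, "the pattern did not match any sub-layer"

The same effect can be reached through init_net by passing the pattern as the new stop_grad_pattern argument.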
@@ -42,10 +42,45 @@ class TheseusLayer(nn.Layer):
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict
    def init_net(self,
                 stages_pattern,
                 return_patterns=None,
                 return_stages=None,
                 stop_grad_pattern=None):
        if return_patterns and return_stages:
            msg = "The 'return_patterns' will be ignored when 'return_stages' is set."
            logger.warning(msg)
            return_stages = None

        if return_stages is True:
            return_patterns = stages_pattern

        # return_stages may be an int or a bool; a single int is normalized to a list
        if type(return_stages) is int:
            return_stages = [return_stages]
        if isinstance(return_stages, list):
            if max(return_stages) >= len(stages_pattern) or min(
                    return_stages) < 0:
                msg = f"Invalid 'return_stages': illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
                logger.warning(msg)
                return_stages = [
                    val for val in return_stages
                    if val >= 0 and val < len(stages_pattern)
                ]
            return_patterns = [stages_pattern[i] for i in return_stages]

        if return_patterns:
            self.update_res(return_patterns)

        if stop_grad_pattern is not None:
            self.freeze_befor(stop_grad_pattern)
    def init_res(self,
                 stages_pattern,
                 return_patterns=None,
                 return_stages=None):
        msg = "\"init_res\" will be deprecated, please use \"init_net\" instead."
        logger.warning(DeprecationWarning(msg))

        if return_patterns and return_stages:
            msg = "The 'return_patterns' will be ignored when 'return_stages' is set."
            logger.warning(msg)
@@ -168,6 +203,37 @@ class TheseusLayer(nn.Layer):
        return True
    def freeze_befor(self, layer_name: str) -> bool:
        """Freeze the layer named layer_name and all the layers before it.

        Args:
            layer_name (str): The name of the layer to be frozen.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """

        def stop_grad(layer, pattern):
            class StopGradLayer(nn.Layer):
                def __init__(self):
                    super().__init__()
                    self.layer = layer

                def forward(self, x):
                    x = self.layer(x)
                    # Cutting the gradient here freezes this layer and every
                    # layer that feeds into it.
                    x.stop_gradient = True
                    return x

            new_layer = StopGradLayer()
            return new_layer

        res = self.upgrade_sublayer(layer_name, stop_grad)

        if len(res) == 0:
            msg = f"Failed to stop the gradient before the layer named '{layer_name}'."
            logger.warning(msg)
            return False

        return True
    def update_res(
            self,
            return_patterns: Union[str, List[str]]) -> Dict[str, nn.Layer]:
...
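As a note on the mechanism: StopGradLayer relies on Paddle's dynamic-graph behavior that setting stop_gradient=True on an intermediate tensor blocks backpropagation into everything that produced it. A small self-contained sketch of that effect, independent of PaddleClas:

import paddle
import paddle.nn as nn

frozen = nn.Linear(4, 4)   # stands in for the wrapped sub-layer
head = nn.Linear(4, 2)     # stands in for the still-trainable remainder

x = paddle.randn([1, 4])
feat = frozen(x)
feat.stop_gradient = True  # the same flag StopGradLayer sets on its output
loss = head(feat).sum()
loss.backward()

print(frozen.weight.grad)  # None: no gradient reaches the frozen prefix
print(head.weight.grad)    # a tensor: the unfrozen head still trains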