diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index bf53467c11a9f6ba9fe87fea4da15fbfeaf80cf8..30c5a9f8b5894ae0273351ee175236d6eb1a661b 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -77,7 +77,11 @@ class TheseusLayer(nn.Layer): return_patterns = [stages_pattern[i] for i in return_stages] if return_patterns: - self.update_res(return_patterns) + # call update_res function after the __init__ of the object has completed execution, that is, the contructing of layer or model has been completed. + def update_res_hook(layer, input): + self.update_res(return_patterns) + + self.register_forward_pre_hook(update_res_hook) # freeze subnet if freeze_befor is not None: