未验证 提交 ff9dd192 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #741 from weisy11/develop_reg

Develop reg
...@@ -15,5 +15,5 @@ ...@@ -15,5 +15,5 @@
from . import backbone from . import backbone
from .backbone import * from .backbone import *
from .loss import * from ppcls.arch.loss_metrics.loss import *
from .utils import * from .utils import *
...@@ -11,11 +11,11 @@ class Identity(nn.Layer): ...@@ -11,11 +11,11 @@ class Identity(nn.Layer):
return inputs return inputs
class TheseusLayer(nn.Layer, ABC): class TheseusLayer(nn.Layer):
def __init__(self, *args, return_patterns=None, **kwargs): def __init__(self, *args, return_patterns=None, **kwargs):
super(TheseusLayer, self).__init__() super(TheseusLayer, self).__init__()
self.res_dict = None self.res_dict = None
self.register_forward_post_hook(self._disconnect_res_dict_hook) # self.register_forward_post_hook(self._disconnect_res_dict_hook)
if return_patterns is not None: if return_patterns is not None:
self._update_res(return_patterns) self._update_res(return_patterns)
...@@ -45,11 +45,11 @@ class TheseusLayer(nn.Layer, ABC): ...@@ -45,11 +45,11 @@ class TheseusLayer(nn.Layer, ABC):
if return_layers is not None and re.match(return_pattern, layer_name): if return_layers is not None and re.match(return_pattern, layer_name):
self._sub_layers[layer_i].register_forward_post_hook(self._save_sub_res_hook) self._sub_layers[layer_i].register_forward_post_hook(self._save_sub_res_hook)
def _save_sub_res_hook(self, layer, input, output): # def _save_sub_res_hook(self, layer, input, output):
self.res_dict[layer.full_name()] = output # self.res_dict[layer.full_name()] = output
#
def _disconnect_res_dict_hook(self, input, output): # def _disconnect_res_dict_hook(self, input, output):
self.res_dict = None # self.res_dict = None
def replace_sub(self, layer_name_pattern, replace_function, recursive=True): def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
for layer_i in self._sub_layers: for layer_i in self._sub_layers:
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册