Commit 5131956d authored by weishengyu

dbg theseus

Parent ce39aea9
@@ -38,17 +38,18 @@ class TheseusLayer(nn.Layer):
         for layer_i in self._sub_layers:
             layer_name = self._sub_layers[layer_i].full_name()
             if isinstance(self._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
-                self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i])
-                self._sub_layers[layer_i].res_dict = self.res_dict
+                self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
                 self._sub_layers[layer_i].update_res(return_patterns)
             else:
                 for return_pattern in return_patterns:
                     if re.match(return_pattern, layer_name):
                         if not isinstance(self._sub_layers[layer_i], TheseusLayer):
-                            self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i])
-                        else:
-                            self._sub_layers[layer_i].res_dict = self.res_dict
+                            self._sub_layers[layer_i] = wrap_theseus(self._sub_layers[layer_i], self.res_dict)
                         self._sub_layers[layer_i].register_forward_post_hook(
                             self._sub_layers[layer_i]._save_sub_res_hook)
+                        self._sub_layers[layer_i].res_dict = self.res_dict
             if isinstance(self._sub_layers[layer_i], TheseusLayer):
                 self._sub_layers[layer_i].res_dict = self.res_dict
                 self._sub_layers[layer_i].update_res(return_patterns)
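
Note: the hook registered above (_save_sub_res_hook) presumably stores each matched sub-layer's output into the shared res_dict. A minimal, self-contained sketch of that mechanism in plain Paddle (illustrative names, not the repository's code):

import paddle
import paddle.nn as nn

res_dict = {}  # shared dict, analogous to TheseusLayer.res_dict

def save_output_hook(layer, inputs, output):
    # Forward post-hook: record the sub-layer's output under its full name.
    res_dict[layer.full_name()] = output

fc1 = nn.Linear(4, 8)
model = nn.Sequential(fc1, nn.ReLU(), nn.Linear(8, 2))
fc1.register_forward_post_hook(save_output_hook)

model(paddle.rand([2, 4]))
print(list(res_dict))  # contains fc1's full name after the forward pass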
@@ -85,10 +86,12 @@ class TheseusLayer(nn.Layer):
 class WrapLayer(TheseusLayer):
-    def __init__(self, sub_layer):
+    def __init__(self, sub_layer, res_dict=None):
         super(WrapLayer, self).__init__()
         self.sub_layer = sub_layer
         self.name = sub_layer.full_name()
+        if res_dict is not None:
+            self.res_dict = res_dict
 
     def full_name(self):
         return self.name
@@ -101,14 +104,14 @@ class WrapLayer(TheseusLayer):
             return
         for layer_i in self.sub_layer._sub_layers:
             if isinstance(self.sub_layer._sub_layers[layer_i], (nn.Sequential, nn.LayerList)):
-                self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i])
-                self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
+                self.sub_layer._sub_layers[layer_i] = wrap_theseus(self.sub_layer._sub_layers[layer_i], self.res_dict)
                 self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
+            elif isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
+                self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
             layer_name = self.sub_layer._sub_layers[layer_i].full_name()
             for return_pattern in return_patterns:
                 if re.match(return_pattern, layer_name):
-                    self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
                     self.sub_layer._sub_layers[layer_i].register_forward_post_hook(
                         self._sub_layers[layer_i]._save_sub_res_hook)
@@ -116,6 +119,6 @@ class WrapLayer(TheseusLayer):
                     self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
 
-def wrap_theseus(sub_layer):
-    wrapped_layer = WrapLayer(sub_layer)
+def wrap_theseus(sub_layer, res_dict=None):
+    wrapped_layer = WrapLayer(sub_layer, res_dict)
     return wrapped_layer
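
The reason res_dict is now threaded through wrap_theseus is so that every wrapper writes into one shared dict instead of creating its own. A small sketch of that design choice with hypothetical stand-in classes (not the repository's code):

class Collector:
    # Stand-in for TheseusLayer: owns the dict, or adopts a shared one.
    def __init__(self, res_dict=None):
        self.res_dict = {} if res_dict is None else res_dict

class Wrapper(Collector):
    # Stand-in for WrapLayer: reuses the parent's dict when one is passed in.
    def __init__(self, inner, res_dict=None):
        super().__init__(res_dict)
        self.inner = inner

root = Collector()
child = Wrapper(object(), root.res_dict)
child.res_dict["stage1"] = "feature"
assert root.res_dict["stage1"] == "feature"  # both reference the same dict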
@@ -146,12 +146,15 @@ class MobileNet(TheseusLayer):
             class_num,
             weight_attr=ParamAttr(initializer=KaimingNormal()))
 
-    def forward(self, x):
+    def forward(self, x, res_dict=None):
         x = self.conv(x)
         x = self.blocks(x)
         x = self.avg_pool(x)
         x = self.flatten(x)
         x = self.fc(x)
+        if self.res_dict and res_dict is not None:
+            for res_key in list(self.res_dict):
+                res_dict[res_key] = self.res_dict.pop(res_key)
         return x
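
The three lines added to forward hand whatever the hooks have collected over to a dict supplied by the caller, emptying the model-level res_dict as they go. A self-contained sketch of that hand-off with a toy model (not the PaddleClas backbone):

import paddle
import paddle.nn as nn

class TinyModel(nn.Layer):
    def __init__(self):
        super().__init__()
        self.res_dict = {}
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)
        # The hook fills the model-level res_dict, as the Theseus hooks do.
        self.fc1.register_forward_post_hook(
            lambda layer, inp, out: self.res_dict.update({layer.full_name(): out}))

    def forward(self, x, res_dict=None):
        x = self.fc2(self.fc1(x))
        # Same hand-off as in the diff: move collected entries to the caller's dict.
        if self.res_dict and res_dict is not None:
            for res_key in list(self.res_dict):
                res_dict[res_key] = self.res_dict.pop(res_key)
        return x

feats = {}
TinyModel()(paddle.rand([2, 4]), res_dict=feats)
print(list(feats))  # the captured fc1 output, keyed by its full name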
@@ -111,7 +111,7 @@ class VGGNet(TheseusLayer):
         model: nn.Layer. Specific VGG model depends on args.
     """
 
-    def __init__(self, config, stop_grad_layers=0, class_num=1000, return_patterns=None):
+    def __init__(self, config, stop_grad_layers=0, class_num=1000):
         super().__init__()
         self.stop_grad_layers = stop_grad_layers