提交 98f38fa0 编写于 作者: W weishengyu

dbg

上级 f41b09ef
...@@ -50,11 +50,11 @@ class TheseusLayer(nn.Layer): ...@@ -50,11 +50,11 @@ class TheseusLayer(nn.Layer):
self._sub_layers[layer_i]._save_sub_res_hook) self._sub_layers[layer_i]._save_sub_res_hook)
self._sub_layers[layer_i].res_dict = self.res_dict self._sub_layers[layer_i].res_dict = self.res_dict
if isinstance(self._sub_layers[layer_i], TheseusLayer): if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i].res_dict = self.res_dict
self._sub_layers[layer_i].update_res(return_patterns) self._sub_layers[layer_i].update_res(return_patterns)
def _save_sub_res_hook(self, layer, input, output): def _save_sub_res_hook(self, layer, input, output):
if self.res_dict is not None: self.res_dict[layer.full_name()] = output
self.res_dict[layer.full_name()] = output
def replace_sub(self, layer_name_pattern, replace_function, recursive=True): def replace_sub(self, layer_name_pattern, replace_function, recursive=True):
for layer_i in self._sub_layers: for layer_i in self._sub_layers:
...@@ -109,7 +109,7 @@ class WrapLayer(TheseusLayer): ...@@ -109,7 +109,7 @@ class WrapLayer(TheseusLayer):
for return_pattern in return_patterns: for return_pattern in return_patterns:
if re.match(return_pattern, layer_name): if re.match(return_pattern, layer_name):
self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
self._sub_layers[layer_i].register_forward_post_hook( self.sub_layer._sub_layers[layer_i].register_forward_post_hook(
self._sub_layers[layer_i]._save_sub_res_hook) self._sub_layers[layer_i]._save_sub_res_hook)
if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer): if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
......
...@@ -152,7 +152,7 @@ class VGGNet(TheseusLayer): ...@@ -152,7 +152,7 @@ class VGGNet(TheseusLayer):
x = self.relu(x) x = self.relu(x)
x = self.drop(x) x = self.drop(x)
x = self.fc3(x) x = self.fc3(x)
if self.res_dict: if self.res_dict and res_dict is not None:
for res_key in list(self.res_dict): for res_key in list(self.res_dict):
res_dict[res_key] = self.res_dict.pop(res_key) res_dict[res_key] = self.res_dict.pop(res_key)
return x return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册