提交 f41b09ef 编写于 作者: W weishengyu

dbg

上级 88a72b67
......@@ -109,6 +109,8 @@ class WrapLayer(TheseusLayer):
for return_pattern in return_patterns:
if re.match(return_pattern, layer_name):
self.sub_layer._sub_layers[layer_i].res_dict = self.res_dict
self._sub_layers[layer_i].register_forward_post_hook(
self._sub_layers[layer_i]._save_sub_res_hook)
if isinstance(self.sub_layer._sub_layers[layer_i], TheseusLayer):
self.sub_layer._sub_layers[layer_i].update_res(return_patterns)
......
......@@ -137,7 +137,6 @@ class VGGNet(TheseusLayer):
self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num)
self.update_res(return_patterns)
def forward(self, inputs, res_dict=None):
x = self.conv_block_1(inputs)
......
......@@ -77,6 +77,7 @@ class Trainer(object):
self.model = build_model(self.config["Arch"])
if "return_patterns" in self.config["Arch"] and isinstance(self.model, TheseusLayer):
self.model.update_res(self.config["Arch"]["return_patterns"])
self.return_inter = True
else:
self.return_inter = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册