提交 ef13f8c6 编写于 作者: W weishengyu

modify code

上级 6dbbf8cc
...@@ -32,7 +32,7 @@ class TheseusLayer(nn.Layer): ...@@ -32,7 +32,7 @@ class TheseusLayer(nn.Layer):
stop_layer_name) stop_layer_name)
return after_stop return after_stop
def _update_res(self, return_patterns): def update_res(self, return_patterns):
if not return_patterns: if not return_patterns:
return return
for layer_i in self._sub_layers: for layer_i in self._sub_layers:
...@@ -47,7 +47,7 @@ class TheseusLayer(nn.Layer): ...@@ -47,7 +47,7 @@ class TheseusLayer(nn.Layer):
self._sub_layers[layer_i]._save_sub_res_hook) self._sub_layers[layer_i]._save_sub_res_hook)
self._sub_layers[layer_i].res_dict = self.res_dict self._sub_layers[layer_i].res_dict = self.res_dict
if isinstance(self._sub_layers[layer_i], TheseusLayer): if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i]._update_res(return_patterns) self._sub_layers[layer_i].update_res(return_patterns)
def _save_sub_res_hook(self, layer, input, output): def _save_sub_res_hook(self, layer, input, output):
if self.res_dict is not None: if self.res_dict is not None:
......
...@@ -137,7 +137,7 @@ class VGGNet(TheseusLayer): ...@@ -137,7 +137,7 @@ class VGGNet(TheseusLayer):
self.fc1 = Linear(7 * 7 * 512, 4096) self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096) self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num) self.fc3 = Linear(4096, class_num)
self._update_res(return_patterns) self.update_res(return_patterns)
def forward(self, inputs, res_dict=None): def forward(self, inputs, res_dict=None):
x = self.conv_block_1(inputs) x = self.conv_block_1(inputs)
......
...@@ -396,11 +396,7 @@ class Trainer(object): ...@@ -396,11 +396,7 @@ class Trainer(object):
self.model.train() self.model.train()
return eval_result return eval_result
def forward(self, batch): def forward(self, batch, res_dict=None):
if self.return_inter:
res_dict = {}
else:
res_dict = None
if not self.is_rec: if not self.is_rec:
out = self.model(batch[0], res_dict=res_dict) out = self.model(batch[0], res_dict=res_dict)
else: else:
...@@ -662,7 +658,11 @@ class Trainer(object): ...@@ -662,7 +658,11 @@ class Trainer(object):
image_file_list.append(image_file) image_file_list.append(image_file)
if len(batch_data) >= batch_size or idx == len(image_list) - 1: if len(batch_data) >= batch_size or idx == len(image_list) - 1:
batch_tensor = paddle.to_tensor(batch_data) batch_tensor = paddle.to_tensor(batch_data)
out = self.forward([batch_tensor]) if self.return_inter:
res_dict = {}
else:
res_dict = None
out = self.forward([batch_tensor], res_dict)
if isinstance(out, list): if isinstance(out, list):
out = out[0] out = out[0]
result = postprocess_func(out, image_file_list) result = postprocess_func(out, image_file_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册