提交 ef13f8c6 编写于 作者: W weishengyu

modify code

上级 6dbbf8cc
......@@ -32,7 +32,7 @@ class TheseusLayer(nn.Layer):
stop_layer_name)
return after_stop
def _update_res(self, return_patterns):
def update_res(self, return_patterns):
if not return_patterns:
return
for layer_i in self._sub_layers:
......@@ -47,7 +47,7 @@ class TheseusLayer(nn.Layer):
self._sub_layers[layer_i]._save_sub_res_hook)
self._sub_layers[layer_i].res_dict = self.res_dict
if isinstance(self._sub_layers[layer_i], TheseusLayer):
self._sub_layers[layer_i]._update_res(return_patterns)
self._sub_layers[layer_i].update_res(return_patterns)
def _save_sub_res_hook(self, layer, input, output):
if self.res_dict is not None:
......
......@@ -137,7 +137,7 @@ class VGGNet(TheseusLayer):
self.fc1 = Linear(7 * 7 * 512, 4096)
self.fc2 = Linear(4096, 4096)
self.fc3 = Linear(4096, class_num)
self._update_res(return_patterns)
self.update_res(return_patterns)
def forward(self, inputs, res_dict=None):
x = self.conv_block_1(inputs)
......
......@@ -396,11 +396,7 @@ class Trainer(object):
self.model.train()
return eval_result
def forward(self, batch):
if self.return_inter:
res_dict = {}
else:
res_dict = None
def forward(self, batch, res_dict=None):
if not self.is_rec:
out = self.model(batch[0], res_dict=res_dict)
else:
......@@ -662,7 +658,11 @@ class Trainer(object):
image_file_list.append(image_file)
if len(batch_data) >= batch_size or idx == len(image_list) - 1:
batch_tensor = paddle.to_tensor(batch_data)
out = self.forward([batch_tensor])
if self.return_inter:
res_dict = {}
else:
res_dict = None
out = self.forward([batch_tensor], res_dict)
if isinstance(out, list):
out = out[0]
result = postprocess_func(out, image_file_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册