提交 b4678fd3 编写于 作者: W weishengyu

revert trainer

上级 98f38fa0
...@@ -38,7 +38,6 @@ from ppcls.utils.config import print_config ...@@ -38,7 +38,6 @@ from ppcls.utils.config import print_config
from ppcls.data import build_dataloader from ppcls.data import build_dataloader
from ppcls.arch import build_model from ppcls.arch import build_model
from ppcls.arch import apply_to_static from ppcls.arch import apply_to_static
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.loss import build_loss from ppcls.loss import build_loss
from ppcls.metric import build_metrics from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer from ppcls.optimizer import build_optimizer
...@@ -76,11 +75,6 @@ class Trainer(object): ...@@ -76,11 +75,6 @@ class Trainer(object):
self.is_rec = False self.is_rec = False
self.model = build_model(self.config["Arch"]) self.model = build_model(self.config["Arch"])
if "return_patterns" in self.config["Arch"] and isinstance(self.model, TheseusLayer):
self.model.update_res(self.config["Arch"]["return_patterns"])
self.return_inter = True
else:
self.return_inter = False
# set @to_static for benchmark, skip this by default. # set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model) apply_to_static(self.config, self.model)
...@@ -397,11 +391,11 @@ class Trainer(object): ...@@ -397,11 +391,11 @@ class Trainer(object):
self.model.train() self.model.train()
return eval_result return eval_result
def forward(self, batch, res_dict=None): def forward(self, batch):
if not self.is_rec: if not self.is_rec:
out = self.model(batch[0], res_dict=res_dict) out = self.model(batch[0])
else: else:
out = self.model(batch[0], batch[1], return_dict=res_dict) out = self.model(batch[0], batch[1])
return out return out
@paddle.no_grad() @paddle.no_grad()
...@@ -659,11 +653,7 @@ class Trainer(object): ...@@ -659,11 +653,7 @@ class Trainer(object):
image_file_list.append(image_file) image_file_list.append(image_file)
if len(batch_data) >= batch_size or idx == len(image_list) - 1: if len(batch_data) >= batch_size or idx == len(image_list) - 1:
batch_tensor = paddle.to_tensor(batch_data) batch_tensor = paddle.to_tensor(batch_data)
if self.return_inter: out = self.forward([batch_tensor])
res_dict = {}
else:
res_dict = None
out = self.forward([batch_tensor], res_dict)
if isinstance(out, list): if isinstance(out, list):
out = out[0] out = out[0]
result = postprocess_func(out, image_file_list) result = postprocess_func(out, image_file_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册