提交 6c164c9e 编写于 作者: W weishengyu

dbg

上级 6ebe7f09
......@@ -38,6 +38,7 @@ from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
......@@ -75,8 +76,10 @@ class Trainer(object):
self.is_rec = False
self.model = build_model(self.config["Arch"])
if "return_pattern" in self.config["Arch"]:
if "return_patterns" in self.config["Arch"] and isinstance(self.model, TheseusLayer):
self.return_inter = True
else:
self.return_inter = False
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
......@@ -395,13 +398,13 @@ class Trainer(object):
def forward(self, batch):
if self.return_inter:
return_dict = {}
res_dict = {}
else:
return_dict = None
res_dict = None
if not self.is_rec:
out = self.model(batch[0], return_dict=return_dict)
out = self.model(batch[0], res_dict=res_dict)
else:
out = self.model(batch[0], batch[1], return_dict=return_dict)
out = self.model(batch[0], batch[1], return_dict=res_dict)
return out
@paddle.no_grad()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册