未验证 提交 681ef5d1 编写于 作者: W Walter 提交者: GitHub

Merge pull request #1200 from RainFrost1/develop

fix rec forward bug
...@@ -51,6 +51,11 @@ class Engine(object): ...@@ -51,6 +51,11 @@ class Engine(object):
self.config = config self.config = config
self.eval_mode = self.config["Global"].get("eval_mode", self.eval_mode = self.config["Global"].get("eval_mode",
"classification") "classification")
if "Head" in self.config["Arch"]:
self.is_rec = True
else:
self.is_rec = False
# init logger # init logger
self.output_dir = self.config['Global']['output_dir'] self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
......
...@@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): ...@@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
def forward(trainer, batch): def forward(trainer, batch):
if trainer.eval_mode == "classification": if not trainer.is_rec:
return trainer.model(batch[0]) return trainer.model(batch[0])
else: else:
return trainer.model(batch[0], batch[1]) return trainer.model(batch[0], batch[1])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册