From d49657ad083b4c0ceeb16a517c42a77129cdedbc Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Thu, 2 Sep 2021 07:42:22 +0000 Subject: [PATCH] fix rec forward bug --- ppcls/engine/engine.py | 5 +++++ ppcls/engine/train/train.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index ad5c584f..bdcd4c4c 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -51,6 +51,11 @@ class Engine(object): self.config = config self.eval_mode = self.config["Global"].get("eval_mode", "classification") + if "Head" in self.config["Arch"]: + self.is_rec = True + else: + self.is_rec = False + # init logger self.output_dir = self.config['Global']['output_dir'] log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 9e36a063..73f22508 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): def forward(trainer, batch): - if trainer.eval_mode == "classification": + if not trainer.is_rec: return trainer.model(batch[0]) else: return trainer.model(batch[0], batch[1]) -- GitLab