diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index ad5c584f056b422fb113306a7cbebf2e56347a7b..bdcd4c4cc7c40542e5658442d6dc8ea26d12e4ce 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -51,6 +51,11 @@ class Engine(object): self.config = config self.eval_mode = self.config["Global"].get("eval_mode", "classification") + if "Head" in self.config["Arch"]: + self.is_rec = True + else: + self.is_rec = False + # init logger self.output_dir = self.config['Global']['output_dir'] log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 9e36a063e48009fdb27a36ea7f93a94e81e6de48..73f225087fe37b38d8274e43e7b901760101af6e 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): def forward(trainer, batch): - if trainer.eval_mode == "classification": + if not trainer.is_rec: return trainer.model(batch[0]) else: return trainer.model(batch[0], batch[1])