From 0055ca2ffe3f9b292f4da42bef90447ede94d992 Mon Sep 17 00:00:00 2001 From: Tingquan Gao <35441050@qq.com> Date: Tue, 14 Mar 2023 16:16:40 +0800 Subject: [PATCH] Revert "debug" This reverts commit 9e683d0d6934cae517bd73a0d3c4260cbef98e0c. --- ppcls/engine/engine.py | 20 ++++++++++---------- ppcls/engine/evaluation/__init__.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 8d27f2db..381f139c 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -60,17 +60,17 @@ class Engine(object): # build model self.model = build_model(self.config, self.mode) + # load_pretrain + self._init_pretrained() + + self._init_amp() + # init train_func and eval_func self.eval = build_eval_func( self.config, mode=self.mode, model=self.model) self.train = build_train_func( self.config, mode=self.mode, model=self.model, eval_func=self.eval) - # load_pretrain - self._init_pretrained() - - self._init_amp() - # for distributed self._init_dist() @@ -197,11 +197,11 @@ class Engine(object): if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"].startswith("http"): load_dygraph_pretrain_from_url( - [self.model, getattr(self.train, "loss_func", None)], + [self.model, getattr(self, 'train_loss_func', None)], self.config["Global"]["pretrained_model"]) else: load_dygraph_pretrain( - [self.model, getattr(self.train, "loss_func", None)], + [self.model, getattr(self, 'train_loss_func', None)], self.config["Global"]["pretrained_model"]) def _init_amp(self): @@ -257,10 +257,10 @@ class Engine(object): if self.config["Global"]["distributed"]: dist.init_parallel_env() self.model = paddle.DataParallel(self.model) - if self.mode == 'train' and len(self.train.loss_func.parameters( + if self.mode == 'train' and len(self.train_loss_func.parameters( )) > 0: - self.train.loss_func = paddle.DataParallel( - self.train.loss_func) + self.train_loss_func = paddle.DataParallel( + self.train_loss_func) class ExportModel(TheseusLayer): diff --git a/ppcls/engine/evaluation/__init__.py b/ppcls/engine/evaluation/__init__.py index 17ed9287..43cacda9 100644 --- a/ppcls/engine/evaluation/__init__.py +++ b/ppcls/engine/evaluation/__init__.py @@ -20,7 +20,7 @@ from .adaface import adaface_eval def build_eval_func(config, mode, model): if mode not in ["eval", "train"]: return None - task = config["Global"].get("eval_mode", "classification") + task = config["Global"].get("task", "classification") if task == "classification": return ClassEval(config, mode, model) elif task == "retrieval": -- GitLab