diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 8d27f2db554031a89c34403b123a269b9c4532e3..381f139cd81e11fc47b4794ff90cab251afb5a59 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -60,17 +60,17 @@ class Engine(object): # build model self.model = build_model(self.config, self.mode) + # load_pretrain + self._init_pretrained() + + self._init_amp() + # init train_func and eval_func self.eval = build_eval_func( self.config, mode=self.mode, model=self.model) self.train = build_train_func( self.config, mode=self.mode, model=self.model, eval_func=self.eval) - # load_pretrain - self._init_pretrained() - - self._init_amp() - # for distributed self._init_dist() @@ -197,11 +197,11 @@ class Engine(object): if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"].startswith("http"): load_dygraph_pretrain_from_url( - [self.model, getattr(self.train, "loss_func", None)], + [self.model, getattr(self, 'train_loss_func', None)], self.config["Global"]["pretrained_model"]) else: load_dygraph_pretrain( - [self.model, getattr(self.train, "loss_func", None)], + [self.model, getattr(self, 'train_loss_func', None)], self.config["Global"]["pretrained_model"]) def _init_amp(self): @@ -257,10 +257,10 @@ class Engine(object): if self.config["Global"]["distributed"]: dist.init_parallel_env() self.model = paddle.DataParallel(self.model) - if self.mode == 'train' and len(self.train.loss_func.parameters( + if self.mode == 'train' and len(self.train_loss_func.parameters( )) > 0: - self.train.loss_func = paddle.DataParallel( - self.train.loss_func) + self.train_loss_func = paddle.DataParallel( + self.train_loss_func) class ExportModel(TheseusLayer): diff --git a/ppcls/engine/evaluation/__init__.py b/ppcls/engine/evaluation/__init__.py index 17ed928775059dc2a3ba4e3523bcebe185e4711b..43cacda97722d2dfff6cfe053ca341af0a1d3625 100644 --- a/ppcls/engine/evaluation/__init__.py +++ b/ppcls/engine/evaluation/__init__.py @@ -20,7 +20,7 @@ from .adaface import adaface_eval def build_eval_func(config, mode, model): if mode not in ["eval", "train"]: return None - task = config["Global"].get("eval_mode", "classification") + task = config["Global"].get("task", "classification") if task == "classification": return ClassEval(config, mode, model) elif task == "retrieval":