提交 0055ca2f 编写于 作者: T Tingquan Gao

Revert "debug"

This reverts commit 9e683d0d.
上级 753270ab
...@@ -60,17 +60,17 @@ class Engine(object): ...@@ -60,17 +60,17 @@ class Engine(object):
# build model # build model
self.model = build_model(self.config, self.mode) self.model = build_model(self.config, self.mode)
# load_pretrain
self._init_pretrained()
self._init_amp()
# init train_func and eval_func # init train_func and eval_func
self.eval = build_eval_func( self.eval = build_eval_func(
self.config, mode=self.mode, model=self.model) self.config, mode=self.mode, model=self.model)
self.train = build_train_func( self.train = build_train_func(
self.config, mode=self.mode, model=self.model, eval_func=self.eval) self.config, mode=self.mode, model=self.model, eval_func=self.eval)
# load_pretrain
self._init_pretrained()
self._init_amp()
# for distributed # for distributed
self._init_dist() self._init_dist()
...@@ -197,11 +197,11 @@ class Engine(object): ...@@ -197,11 +197,11 @@ class Engine(object):
if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"): if self.config["Global"]["pretrained_model"].startswith("http"):
load_dygraph_pretrain_from_url( load_dygraph_pretrain_from_url(
[self.model, getattr(self.train, "loss_func", None)], [self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"]) self.config["Global"]["pretrained_model"])
else: else:
load_dygraph_pretrain( load_dygraph_pretrain(
[self.model, getattr(self.train, "loss_func", None)], [self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"]) self.config["Global"]["pretrained_model"])
def _init_amp(self): def _init_amp(self):
...@@ -257,10 +257,10 @@ class Engine(object): ...@@ -257,10 +257,10 @@ class Engine(object):
if self.config["Global"]["distributed"]: if self.config["Global"]["distributed"]:
dist.init_parallel_env() dist.init_parallel_env()
self.model = paddle.DataParallel(self.model) self.model = paddle.DataParallel(self.model)
if self.mode == 'train' and len(self.train.loss_func.parameters( if self.mode == 'train' and len(self.train_loss_func.parameters(
)) > 0: )) > 0:
self.train.loss_func = paddle.DataParallel( self.train_loss_func = paddle.DataParallel(
self.train.loss_func) self.train_loss_func)
class ExportModel(TheseusLayer): class ExportModel(TheseusLayer):
......
...@@ -20,7 +20,7 @@ from .adaface import adaface_eval ...@@ -20,7 +20,7 @@ from .adaface import adaface_eval
def build_eval_func(config, mode, model): def build_eval_func(config, mode, model):
if mode not in ["eval", "train"]: if mode not in ["eval", "train"]:
return None return None
task = config["Global"].get("eval_mode", "classification") task = config["Global"].get("task", "classification")
if task == "classification": if task == "classification":
return ClassEval(config, mode, model) return ClassEval(config, mode, model)
elif task == "retrieval": elif task == "retrieval":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册