提交 0055ca2f 编写于 作者: T Tingquan Gao

Revert "debug"

This reverts commit 9e683d0d.
上级 753270ab
......@@ -60,17 +60,17 @@ class Engine(object):
# build model
self.model = build_model(self.config, self.mode)
# load_pretrain
self._init_pretrained()
self._init_amp()
# init train_func and eval_func
self.eval = build_eval_func(
self.config, mode=self.mode, model=self.model)
self.train = build_train_func(
self.config, mode=self.mode, model=self.model, eval_func=self.eval)
# load_pretrain
self._init_pretrained()
self._init_amp()
# for distributed
self._init_dist()
......@@ -197,11 +197,11 @@ class Engine(object):
if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"):
load_dygraph_pretrain_from_url(
[self.model, getattr(self.train, "loss_func", None)],
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
else:
load_dygraph_pretrain(
[self.model, getattr(self.train, "loss_func", None)],
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])
def _init_amp(self):
......@@ -257,10 +257,10 @@ class Engine(object):
if self.config["Global"]["distributed"]:
dist.init_parallel_env()
self.model = paddle.DataParallel(self.model)
if self.mode == 'train' and len(self.train.loss_func.parameters(
if self.mode == 'train' and len(self.train_loss_func.parameters(
)) > 0:
self.train.loss_func = paddle.DataParallel(
self.train.loss_func)
self.train_loss_func = paddle.DataParallel(
self.train_loss_func)
class ExportModel(TheseusLayer):
......
......@@ -20,7 +20,7 @@ from .adaface import adaface_eval
def build_eval_func(config, mode, model):
if mode not in ["eval", "train"]:
return None
task = config["Global"].get("eval_mode", "classification")
task = config["Global"].get("task", "classification")
if task == "classification":
return ClassEval(config, mode, model)
elif task == "retrieval":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册