提交 c07758b3 编写于 作者: H HydrogenSulfate

fix engine.py

上级 fa52acd6
......@@ -144,22 +144,21 @@ class Engine(object):
self.config["Global"]["eval_during_train"]):
if self.eval_mode in ["classification", "adaface"]:
self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device,
self.use_dali)
self.config["DataLoader"], "Eval", self.device, False)
elif self.eval_mode == "retrieval":
self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
key = list(self.config["DataLoader"]["Eval"].keys())[0]
self.gallery_query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], key, self.device,
self.use_dali)
False)
else:
self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery",
self.device, self.use_dali)
self.device, False)
self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query",
self.device, self.use_dali)
self.device, False)
# build loss
if self.mode == "train":
......@@ -339,11 +338,11 @@ class Engine(object):
)
self.config["Global"]["seed"] = seed = 42
logger.info(
f"Set random seed to ({seed} + $PADDLE_TRAINER_ID) for different trainer"
f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
)
paddle.seed(seed + dist.get_rank())
np.random.seed(seed + dist.get_rank())
random.seed(seed + dist.get_rank())
paddle.seed(int(seed) + dist.get_rank())
np.random.seed(int(seed) + dist.get_rank())
random.seed(int(seed) + dist.get_rank())
# build postprocess for infer
if self.mode == 'infer':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册