提交 c07758b3 编写于 作者: H HydrogenSulfate

fix engine.py

上级 fa52acd6
...@@ -144,22 +144,21 @@ class Engine(object): ...@@ -144,22 +144,21 @@ class Engine(object):
self.config["Global"]["eval_during_train"]): self.config["Global"]["eval_during_train"]):
if self.eval_mode in ["classification", "adaface"]: if self.eval_mode in ["classification", "adaface"]:
self.eval_dataloader = build_dataloader( self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device, self.config["DataLoader"], "Eval", self.device, False)
self.use_dali)
elif self.eval_mode == "retrieval": elif self.eval_mode == "retrieval":
self.gallery_query_dataloader = None self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1: if len(self.config["DataLoader"]["Eval"].keys()) == 1:
key = list(self.config["DataLoader"]["Eval"].keys())[0] key = list(self.config["DataLoader"]["Eval"].keys())[0]
self.gallery_query_dataloader = build_dataloader( self.gallery_query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], key, self.device, self.config["DataLoader"]["Eval"], key, self.device,
self.use_dali) False)
else: else:
self.gallery_dataloader = build_dataloader( self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery", self.config["DataLoader"]["Eval"], "Gallery",
self.device, self.use_dali) self.device, False)
self.query_dataloader = build_dataloader( self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query", self.config["DataLoader"]["Eval"], "Query",
self.device, self.use_dali) self.device, False)
# build loss # build loss
if self.mode == "train": if self.mode == "train":
...@@ -339,11 +338,11 @@ class Engine(object): ...@@ -339,11 +338,11 @@ class Engine(object):
) )
self.config["Global"]["seed"] = seed = 42 self.config["Global"]["seed"] = seed = 42
logger.info( logger.info(
f"Set random seed to ({seed} + $PADDLE_TRAINER_ID) for different trainer" f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
) )
paddle.seed(seed + dist.get_rank()) paddle.seed(int(seed) + dist.get_rank())
np.random.seed(seed + dist.get_rank()) np.random.seed(int(seed) + dist.get_rank())
random.seed(seed + dist.get_rank()) random.seed(int(seed) + dist.get_rank())
# build postprocess for infer # build postprocess for infer
if self.mode == 'infer': if self.mode == 'infer':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册