未验证 提交 2a4d5515 编写于 作者: R RuohengMa 提交者: GitHub

Fix seed passing issue of build_dataloader (#10614)

上级 681467d4
......@@ -41,7 +41,7 @@ import tools.program as program
dist.get_world_size()
def main(config, device, logger, vdl_writer):
def main(config, device, logger, vdl_writer, seed):
# init dist environment
if config['Global']['distributed']:
dist.init_parallel_env()
......@@ -50,7 +50,7 @@ def main(config, device, logger, vdl_writer):
# build dataloader
set_signal_handlers()
train_dataloader = build_dataloader(config, 'Train', device, logger)
train_dataloader = build_dataloader(config, 'Train', device, logger, seed)
if len(train_dataloader) == 0:
logger.error(
"No Images in train dataset, please ensure\n" +
......@@ -61,7 +61,7 @@ def main(config, device, logger, vdl_writer):
return
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
valid_dataloader = build_dataloader(config, 'Eval', device, logger, seed)
else:
valid_dataloader = None
......@@ -224,5 +224,5 @@ if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True)
seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
set_seed(seed)
main(config, device, logger, vdl_writer)
main(config, device, logger, vdl_writer, seed)
# test_reader(config, device, logger)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册