diff --git a/tools/train.py b/tools/train.py index ffe56dfc13863c20913cb8799aebe73c11883d52..85c98eaddfe69c08e0e29921edcb1d26539b871f 100755 --- a/tools/train.py +++ b/tools/train.py @@ -41,7 +41,7 @@ import tools.program as program dist.get_world_size() -def main(config, device, logger, vdl_writer): +def main(config, device, logger, vdl_writer, seed): # init dist environment if config['Global']['distributed']: dist.init_parallel_env() @@ -50,7 +50,7 @@ def main(config, device, logger, vdl_writer): # build dataloader set_signal_handlers() - train_dataloader = build_dataloader(config, 'Train', device, logger) + train_dataloader = build_dataloader(config, 'Train', device, logger, seed) if len(train_dataloader) == 0: logger.error( "No Images in train dataset, please ensure\n" + @@ -61,7 +61,7 @@ def main(config, device, logger, vdl_writer): return if config['Eval']: - valid_dataloader = build_dataloader(config, 'Eval', device, logger) + valid_dataloader = build_dataloader(config, 'Eval', device, logger, seed) else: valid_dataloader = None @@ -224,5 +224,5 @@ if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess(is_train=True) seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024 set_seed(seed) - main(config, device, logger, vdl_writer) + main(config, device, logger, vdl_writer, seed) # test_reader(config, device, logger)