From c5b3422b7ca362b95f7340af1ac580799d8d1ecc Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Thu, 19 Nov 2020 14:42:00 +0800 Subject: [PATCH] add dali v3 (#412) add dali --- tools/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/train.py b/tools/train.py index 3238df79..c441c698 100644 --- a/tools/train.py +++ b/tools/train.py @@ -112,14 +112,14 @@ def main(args): train_reader = Reader(config, 'train')() train_dataloader.set_sample_list_generator(train_reader, place) if config.validate: - valid_reader = Reader(config, 'valid')() - valid_dataloader.set_sample_list_generator(valid_reader, place) + if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: + valid_reader = Reader(config, 'valid')() + valid_dataloader.set_sample_list_generator(valid_reader, place) compiled_valid_prog = program.compile(config, valid_prog) - else: import dali train_dataloader = dali.train(config) - if config.validate and int(os.getenv("PADDLE_TRAINER_ID", 0)): + if config.validate: if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: valid_dataloader = dali.val(config) compiled_valid_prog = program.compile(config, valid_prog) -- GitLab