diff --git a/ppcls/static/train.py b/ppcls/static/train.py index eed68d38a0e42fecd761abf156ae64d573b3ce39..53566267e143e5ff35f7f705bfee69cb72147ae7 100755 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -171,7 +171,10 @@ def main(args): compiled_train_prog = train_prog if eval_dataloader is not None: - compiled_eval_prog = program.compile(config, eval_prog) + if not global_config.get("is_distributed", True): + compiled_eval_prog = program.compile(config, eval_prog) + else: + compiled_eval_prog = eval_prog for epoch_id in range(global_config["epochs"]): # 1. train with train dataset