diff --git a/PaddleNLP/language_representations_kit/BERT/train.py b/PaddleNLP/language_representations_kit/BERT/train.py
index f176638c8295007d4efaccd1f3925221fd142484..c99b3db2de494909bae500b235b725e3c08bf36f 100644
--- a/PaddleNLP/language_representations_kit/BERT/train.py
+++ b/PaddleNLP/language_representations_kit/BERT/train.py
@@ -322,9 +322,10 @@ def train(args):
     exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope

     build_strategy = fluid.BuildStrategy()
-    if sys.platform == "win32" and nccl2_num_trainers > 1:
+    if not sys.platform == "win32":
+        build_strategy.num_trainers = nccl2_num_trainers
+    elif nccl2_num_trainers > 1:
         raise ValueError("Windows platform doesn't support distributed training!")
-    build_strategy.num_trainers = nccl2_num_trainers
     build_strategy.trainer_id = nccl2_trainer_id
     # use_ngraph is for CPU only, please refer to README_ngraph.md for details
     use_ngraph = os.getenv('FLAGS_use_ngraph')
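
For context, a minimal sketch of the control flow this patch introduces is shown below. It is not the actual `train.py`; the helper name `make_build_strategy` is hypothetical and only isolates the patched guard: NCCL2 multi-trainer settings are applied on non-Windows platforms, while Windows with more than one trainer fails fast.

```python
import sys

import paddle.fluid as fluid  # PaddlePaddle 1.x import style used by this repo


def make_build_strategy(nccl2_num_trainers, nccl2_trainer_id):
    """Hypothetical helper mirroring the patched guard in train()."""
    build_strategy = fluid.BuildStrategy()
    if not sys.platform == "win32":
        # Non-Windows: multi-trainer (NCCL2) settings are supported.
        build_strategy.num_trainers = nccl2_num_trainers
    elif nccl2_num_trainers > 1:
        # Windows with more than one trainer: distributed training is unsupported.
        raise ValueError("Windows platform doesn't support distributed training!")
    # trainer_id is set unconditionally, as in the patched file.
    build_strategy.trainer_id = nccl2_trainer_id
    return build_strategy
```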