diff --git a/BERT/train.py b/BERT/train.py index afc2f91ddda027867d79314bd1528d60f51167ef..55d362b612f56bd4ff6596e4df875cb539805e69 100644 --- a/BERT/train.py +++ b/BERT/train.py @@ -335,7 +335,7 @@ def train(args): lm_cost = [] acc = [] time_begin = time.time() - while True: + while steps < args.num_train_steps: try: steps += nccl2_num_trainers skip_steps = args.skip_steps * nccl2_num_trainers