diff --git a/BERT/run_squad.py b/BERT/run_squad.py index 9607a8be0a79aaf77b4f1f4df9fe91b3f1ff28ce..f93c7bfe69770f46b105d5c15f3fc7b9ca2afe52 100644 --- a/BERT/run_squad.py +++ b/BERT/run_squad.py @@ -241,7 +241,7 @@ def train(args): data_path=args.train_file, batch_size=args.batch_size, phase='train', - shuffle=False, + shuffle=True, dev_count=dev_count, version_2_with_negative=args.version_2_with_negative, epoch=args.epoch) @@ -396,7 +396,7 @@ def train(args): total_cost, total_num_seqs = [], [] time_begin = time.time() - if steps % args.save_steps == 0: + if steps % args.save_steps == 0 or steps == max_train_steps: save_path = os.path.join(args.checkpoints, "step_" + str(steps)) fluid.io.save_persistables(exe, save_path, train_program)