未验证 提交 b9dae026 编写于 作者: X xuezhong 提交者: GitHub

Merge pull request #108 from PaddlePaddle/fix_squad

Tiny fixes in squad
......@@ -241,7 +241,7 @@ def train(args):
data_path=args.train_file,
batch_size=args.batch_size,
phase='train',
shuffle=False,
shuffle=True,
dev_count=dev_count,
version_2_with_negative=args.version_2_with_negative,
epoch=args.epoch)
......@@ -396,7 +396,7 @@ def train(args):
total_cost, total_num_seqs = [], []
time_begin = time.time()
if steps % args.save_steps == 0:
if steps % args.save_steps == 0 or steps == max_train_steps:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册