未验证 提交 b9dae026 编写于 作者: X xuezhong 提交者: GitHub

Merge pull request #108 from PaddlePaddle/fix_squad

Tiny fixes in squad
...@@ -241,7 +241,7 @@ def train(args): ...@@ -241,7 +241,7 @@ def train(args):
data_path=args.train_file, data_path=args.train_file,
batch_size=args.batch_size, batch_size=args.batch_size,
phase='train', phase='train',
shuffle=False, shuffle=True,
dev_count=dev_count, dev_count=dev_count,
version_2_with_negative=args.version_2_with_negative, version_2_with_negative=args.version_2_with_negative,
epoch=args.epoch) epoch=args.epoch)
...@@ -396,7 +396,7 @@ def train(args): ...@@ -396,7 +396,7 @@ def train(args):
total_cost, total_num_seqs = [], [] total_cost, total_num_seqs = [], []
time_begin = time.time() time_begin = time.time()
if steps % args.save_steps == 0: if steps % args.save_steps == 0 or steps == max_train_steps:
save_path = os.path.join(args.checkpoints, save_path = os.path.join(args.checkpoints,
"step_" + str(steps)) "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册