未验证 提交 04c08043 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #62 from wangxiao1021/api

update mrc
...@@ -103,8 +103,8 @@ You can easily re-produce following competitive results with minor codes, which ...@@ -103,8 +103,8 @@ You can easily re-produce following competitive results with minor codes, which
<td>94.9</td> <td>94.9</td>
<td>94.5</td> <td>94.5</td>
<td>94.7</td> <td>94.7</td>
<td>96.3</td> <td>64.3</td>
<td>84.0</td> <td>85.2</td>
</tr> </tr>
</tbody> </tbody>
......
...@@ -94,5 +94,5 @@ The evaluation results are as follows: ...@@ -94,5 +94,5 @@ The evaluation results are as follows:
``` ```
data_num: 3219 data_num: 3219
em_sroce: 0.963031997515, f1: 83.9865402973 em_sroce: 64.3367505436, f1: 85.1781896843
``` ```
...@@ -9,7 +9,7 @@ if __name__ == '__main__': ...@@ -9,7 +9,7 @@ if __name__ == '__main__':
# configs # configs
max_seqlen = 512 max_seqlen = 512
batch_size = 8 batch_size = 8
num_epochs = 8 num_epochs = 2
lr = 3e-5 lr = 3e-5
doc_stride = 128 doc_stride = 128
max_query_len = 64 max_query_len = 64
...@@ -64,8 +64,7 @@ if __name__ == '__main__': ...@@ -64,8 +64,7 @@ if __name__ == '__main__':
# step 8-1*: load pretrained parameters # step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params) trainer.load_pretrain(pre_params)
# step 8-2*: set saver to save model # step 8-2*: set saver to save model
# save_steps = (n_steps-8) // 4 save_steps = 3040
save_steps = 1520
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training # step 8-3: start training
trainer.train(print_steps=print_steps) trainer.train(print_steps=print_steps)
...@@ -90,7 +89,7 @@ if __name__ == '__main__': ...@@ -90,7 +89,7 @@ if __name__ == '__main__':
trainer.build_predict_forward(pred_ernie, mrc_pred_head) trainer.build_predict_forward(pred_ernie, mrc_pred_head)
# step 6: load checkpoint # step 6: load checkpoint
pred_model_path = './outputs/ckpt.step'+str(12160) pred_model_path = './outputs/ckpt.step'+str(3040)
trainer.load_ckpt(pred_model_path) trainer.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data # step 7: fit prepared reader and data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册