提交 36bbc8e7 编写于 作者: W wangxiao1021

fix bugs

上级 cae26c4c
......@@ -90,10 +90,10 @@ if __name__ == '__main__':
# step 5: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, cls_pred_head)
# step 6: load pretrained model
# step 6: load checkpoint
# model_path = './outputs/ckpt.step'+str(save_steps)
model_path = './outputs/ckpt.step'+str(11980)
pred_ckpt = trainer.load_ckpt(model_path)
trainer.load_ckpt(model_path)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_cls_reader, phase='predict')
......
......@@ -5,7 +5,7 @@ This task is a sentence pair matching task. The following sections detail model
#### Download Pre-trained Model
The pre-training model of this mission is: [ernie-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).
The pre-training model of this mission is: [ERNIE-v2-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).
Make sure you have downloaded the required pre-training model in the current folder.
......
......@@ -93,8 +93,8 @@ if __name__ == '__main__':
# step 5: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, match_pred_head)
# step 6: load pretrained model
pred_ckpt = trainer.load_ckpt(pred_model_path)
# step 6: load checkpoint
trainer.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_match_reader, phase='predict')
......
......@@ -89,7 +89,7 @@ if __name__ == '__main__':
# step 5: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, mrc_pred_head)
# step 6: load pretrained model
# step 6: load checkpoint
pred_model_path = './outputs/ckpt.step'+str(12160)
trainer.load_ckpt(pred_model_path)
......
......@@ -5,7 +5,7 @@ This task is a slot filling task. During training, the task uses intent determin
#### Pre-trianed Model
The pre-training model of this mission is: [ernie-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).
The pre-training model of this mission is: [ERNIE-v2-en-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api).
Make sure you have downloaded the required pre-training model in the current folder.
......
......@@ -45,9 +45,9 @@ if __name__ == '__main__':
# step 5-2: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, cls_pred_head)
# step 6: load pretrained model
# step 6: load checkpoint
pred_model_path = './outputs/ckpt.step4641'
pred_ckpt = trainer.load_ckpt(pred_model_path)
trainer.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_cls_reader, phase='predict')
......
......@@ -46,9 +46,9 @@ if __name__ == '__main__':
# step 5-2: build forward graph with backbone and task head
trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head)
# step 6: load pretrained model
# step 6: load checkpoint
pred_model_path = './outputs/ckpt.step4641'
pred_ckpt = trainer_seq_label.load_ckpt(pred_model_path)
trainer_seq_label.load_ckpt(pred_model_path)
# step 7: fit prepared reader and data
trainer_seq_label.fit_reader(predict_seq_label_reader, phase='predict')
......
......@@ -42,8 +42,8 @@ if __name__ == '__main__':
# step 5-2: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, cls_pred_head)
# step 6: load pretrained model
pred_model = trainer.load_predict_model(pre_params)
# step 6: load checkpoint
trainer.load_predict_model(pre_params)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_cls_reader, phase='predict')
......
......@@ -93,7 +93,7 @@ if __name__ == '__main__':
# step 5: build forward graph with backbone and task head
trainer.build_predict_forward(pred_ernie, seq_label_pred_head)
# step 6: load pretrained model
# step 6: load checkpoint
pred_model_path = './outputs/ckpt.step' + str(save_steps)
trainer.load_ckpt(pred_model_path)
......
......@@ -22,14 +22,7 @@ import math
import six
import paddlepalm.tokenizer.ernie_tokenizer as tokenization
import json
import sys
import io
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding('utf-8')
else:
import importlib
importlib.reload(sys)
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册