提交 c1e8179f 编写于 作者: W wangxiao

fix #68 #71 and other bugs

上级 20b8873c
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
PaddlePALM (PArallel Learning from Multi-tasks) 是一个灵活,通用且易于使用的NLP大规模预训练和多任务学习框架。 PALM是一个旨在**快速开发高性能NLP模型**的上层框架。 PaddlePALM (PArallel Learning from Multi-tasks) 是一个灵活,通用且易于使用的NLP大规模预训练和多任务学习框架。 PALM是一个旨在**快速开发高性能NLP模型**的上层框架。
使用PaddlePALM,可以非常轻松灵活的探索具有多种任务辅助训练的“高鲁棒性”阅读理解模型,基于PALM训练的模型[D-Net](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/MRQA2019-D-NET)[EMNLP2019国际阅读理解评测](mrqa .github.io)中夺得冠军。 使用PaddlePALM,可以非常轻松灵活的探索具有多种任务辅助训练的“高鲁棒性”阅读理解模型,基于PALM训练的模型[D-Net](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/MRQA2019-D-NET)[EMNLP2019国际阅读理解评测](https://mrqa.github.io/)中夺得冠军。
<p align="center"> <p align="center">
<img src="https://tva1.sinaimg.cn/large/006tNbRwly1gbjkuuwrmlj30hs0hzdh2.jpg" alt="Sample" width="300" height="333"> <img src="https://tva1.sinaimg.cn/large/006tNbRwly1gbjkuuwrmlj30hs0hzdh2.jpg" alt="Sample" width="300" height="333">
...@@ -197,9 +197,9 @@ Available pretrain items: ...@@ -197,9 +197,9 @@ Available pretrain items:
更多实现细节请见示例: 更多实现细节请见示例:
- [Sentiment Classification](https://github.com/PaddlePaddle/PALM/tree/master/examples/classification) - [Sentiment Classification](https://github.com/PaddlePaddle/PALM/tree/master/examples/classification)
- [Quora Question Pairs matching](https://github.com/PaddlePaddle/PALM/tree/master/examples/matching) - [Question Pairs matching](https://github.com/PaddlePaddle/PALM/tree/master/examples/matching)
- [Tagging](https://github.com/PaddlePaddle/PALM/tree/master/examples/tagging) - [Named Entity Recognition](https://github.com/PaddlePaddle/PALM/tree/master/examples/tagging)
- [SQuAD machine Reading Comprehension](https://github.com/PaddlePaddle/PALM/tree/master/examples/mrc). - [SQuAD-like Machine Reading Comprehension](https://github.com/PaddlePaddle/PALM/tree/master/examples/mrc).
#### 多任务学习 #### 多任务学习
......
# coding=utf-8 # coding=utf-8
import paddlepalm as palm import paddlepalm as palm
import json import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__': if __name__ == '__main__':
......
# coding=utf-8 # coding=utf-8
import paddlepalm as palm import paddlepalm as palm
import json import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__': if __name__ == '__main__':
......
# coding=utf-8 # coding=utf-8
import paddlepalm as palm import paddlepalm as palm
import json import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__': if __name__ == '__main__':
......
# coding=utf-8 # coding=utf-8
import paddlepalm as palm import paddlepalm as palm
import json import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__': if __name__ == '__main__':
...@@ -80,4 +79,4 @@ if __name__ == '__main__': ...@@ -80,4 +79,4 @@ if __name__ == '__main__':
# save_steps = 10 # save_steps = 10
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training # step 8-3: start training
trainer.train(print_steps=print_steps) trainer.train(print_steps=print_steps)
\ No newline at end of file
# coding=utf-8 # coding=utf-8
import paddlepalm as palm import paddlepalm as palm
import json import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__': if __name__ == '__main__':
......
# coding=utf-8 # coding=utf-8
import paddlepalm as palm import paddlepalm as palm
import json import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__': if __name__ == '__main__':
...@@ -64,9 +63,9 @@ if __name__ == '__main__': ...@@ -64,9 +63,9 @@ if __name__ == '__main__':
# step 7: fit prepared reader and data # step 7: fit prepared reader and data
trainer.fit_reader(seq_label_reader) trainer.fit_reader(seq_label_reader)
# # step 8-1*: load pretrained parameters # step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params) trainer.load_pretrain(pre_params)
# # step 8-2*: set saver to save model # step 8-2*: set saver to save model
save_steps = 1951 save_steps = 1951
# print('save_steps: {}'.format(save_steps)) # print('save_steps: {}'.format(save_steps))
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
......
...@@ -98,7 +98,7 @@ class Classify(Head): ...@@ -98,7 +98,7 @@ class Classify(Head):
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
label = np.argmax(np.array(self._preds[i])) label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
result = json.dumps(result) result = json.dumps(result)
writer.write(result+'\n') writer.write(result+'\n')
......
...@@ -179,7 +179,7 @@ class Match(Head): ...@@ -179,7 +179,7 @@ class Match(Head):
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)): for i in range(len(self._preds)):
if self._learning_strategy == 'pointwise': if self._learning_strategy == 'pointwise':
label = np.argmax(np.array(self._preds[i])) label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
elif self._learning_strategy == 'pairwise': elif self._learning_strategy == 'pairwise':
result = {'index': i, 'probs': self._preds[i][0]} result = {'index': i, 'probs': self._preds[i][0]}
......
...@@ -37,7 +37,7 @@ from paddlepalm.reader.utils.mlm_batching import prepare_batch_data ...@@ -37,7 +37,7 @@ from paddlepalm.reader.utils.mlm_batching import prepare_batch_data
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
if six.PY3: if six.PY3 and hasattr(sys.stdout, 'buffer'):
import io import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册