From c1e8179feee92919044f7a25d2bb7458dc7f2319 Mon Sep 17 00:00:00 2001 From: wangxiao Date: Wed, 18 Mar 2020 00:31:23 +0800 Subject: [PATCH] fix #68 #71 and other bugs --- README_zh.md | 8 ++++---- examples/classification/run.py | 1 - examples/matching/run.py | 1 - examples/mrc/run.py | 1 - examples/multi-task/run.py | 3 +-- examples/predict/run.py | 1 - examples/tagging/run.py | 5 ++--- paddlepalm/head/cls.py | 2 +- paddlepalm/head/match.py | 2 +- paddlepalm/reader/utils/reader4ernie.py | 2 +- 10 files changed, 10 insertions(+), 16 deletions(-) diff --git a/README_zh.md b/README_zh.md index 06da939..30ecab9 100644 --- a/README_zh.md +++ b/README_zh.md @@ -4,7 +4,7 @@ PaddlePALM (PArallel Learning from Multi-tasks) 是一个灵活,通用且易于使用的NLP大规模预训练和多任务学习框架。 PALM是一个旨在**快速开发高性能NLP模型**的上层框架。 -使用PaddlePALM,可以非常轻松灵活的探索具有多种任务辅助训练的“高鲁棒性”阅读理解模型,基于PALM训练的模型[D-Net](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/MRQA2019-D-NET)在[EMNLP2019国际阅读理解评测](mrqa .github.io)中夺得冠军。 +使用PaddlePALM,可以非常轻松灵活的探索具有多种任务辅助训练的“高鲁棒性”阅读理解模型,基于PALM训练的模型[D-Net](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/MRQA2019-D-NET)在[EMNLP2019国际阅读理解评测](https://mrqa.github.io/)中夺得冠军。

Sample @@ -197,9 +197,9 @@ Available pretrain items: 更多实现细节请见示例: - [Sentiment Classification](https://github.com/PaddlePaddle/PALM/tree/master/examples/classification) -- [Quora Question Pairs matching](https://github.com/PaddlePaddle/PALM/tree/master/examples/matching) -- [Tagging](https://github.com/PaddlePaddle/PALM/tree/master/examples/tagging) -- [SQuAD machine Reading Comprehension](https://github.com/PaddlePaddle/PALM/tree/master/examples/mrc). +- [Question Pairs matching](https://github.com/PaddlePaddle/PALM/tree/master/examples/matching) +- [Named Entity Recognition](https://github.com/PaddlePaddle/PALM/tree/master/examples/tagging) +- [SQuAD-like Machine Reading Comprehension](https://github.com/PaddlePaddle/PALM/tree/master/examples/mrc). #### 多任务学习 diff --git a/examples/classification/run.py b/examples/classification/run.py index 35692e1..cd6ad86 100644 --- a/examples/classification/run.py +++ b/examples/classification/run.py @@ -1,7 +1,6 @@ # coding=utf-8 import paddlepalm as palm import json -from paddlepalm.distribute import gpu_dev_count if __name__ == '__main__': diff --git a/examples/matching/run.py b/examples/matching/run.py index ff551df..cfb6994 100644 --- a/examples/matching/run.py +++ b/examples/matching/run.py @@ -1,7 +1,6 @@ # coding=utf-8 import paddlepalm as palm import json -from paddlepalm.distribute import gpu_dev_count if __name__ == '__main__': diff --git a/examples/mrc/run.py b/examples/mrc/run.py index fc3ee79..4b57bb8 100644 --- a/examples/mrc/run.py +++ b/examples/mrc/run.py @@ -1,7 +1,6 @@ # coding=utf-8 import paddlepalm as palm import json -from paddlepalm.distribute import gpu_dev_count if __name__ == '__main__': diff --git a/examples/multi-task/run.py b/examples/multi-task/run.py index 1876058..fb76f3d 100644 --- a/examples/multi-task/run.py +++ b/examples/multi-task/run.py @@ -1,7 +1,6 @@ # coding=utf-8 import paddlepalm as palm import json -from paddlepalm.distribute import gpu_dev_count if __name__ == '__main__': @@ -80,4 +79,4 @@ if __name__ == '__main__': # save_steps = 10 trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) # step 8-3: start training - trainer.train(print_steps=print_steps) \ No newline at end of file + trainer.train(print_steps=print_steps) diff --git a/examples/predict/run.py b/examples/predict/run.py index 1b0bc84..dec9749 100644 --- a/examples/predict/run.py +++ b/examples/predict/run.py @@ -1,7 +1,6 @@ # coding=utf-8 import paddlepalm as palm import json -from paddlepalm.distribute import gpu_dev_count if __name__ == '__main__': diff --git a/examples/tagging/run.py b/examples/tagging/run.py index 6228304..e4887b6 100644 --- a/examples/tagging/run.py +++ b/examples/tagging/run.py @@ -1,7 +1,6 @@ # coding=utf-8 import paddlepalm as palm import json -from paddlepalm.distribute import gpu_dev_count if __name__ == '__main__': @@ -64,9 +63,9 @@ if __name__ == '__main__': # step 7: fit prepared reader and data trainer.fit_reader(seq_label_reader) - # # step 8-1*: load pretrained parameters + # step 8-1*: load pretrained parameters trainer.load_pretrain(pre_params) - # # step 8-2*: set saver to save model + # step 8-2*: set saver to save model save_steps = 1951 # print('save_steps: {}'.format(save_steps)) trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) diff --git a/paddlepalm/head/cls.py b/paddlepalm/head/cls.py index 66117ac..4da3580 100644 --- a/paddlepalm/head/cls.py +++ b/paddlepalm/head/cls.py @@ -98,7 +98,7 @@ class Classify(Head): raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): - label = np.argmax(np.array(self._preds[i])) + label = int(np.argmax(np.array(self._preds[i]))) result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = json.dumps(result) writer.write(result+'\n') diff --git a/paddlepalm/head/match.py b/paddlepalm/head/match.py index 9df4a1a..38cf1b2 100644 --- a/paddlepalm/head/match.py +++ b/paddlepalm/head/match.py @@ -179,7 +179,7 @@ class Match(Head): with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): if self._learning_strategy == 'pointwise': - label = np.argmax(np.array(self._preds[i])) + label = int(np.argmax(np.array(self._preds[i]))) result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} elif self._learning_strategy == 'pairwise': result = {'index': i, 'probs': self._preds[i][0]} diff --git a/paddlepalm/reader/utils/reader4ernie.py b/paddlepalm/reader/utils/reader4ernie.py index d3f7eb2..c2b0874 100644 --- a/paddlepalm/reader/utils/reader4ernie.py +++ b/paddlepalm/reader/utils/reader4ernie.py @@ -37,7 +37,7 @@ from paddlepalm.reader.utils.mlm_batching import prepare_batch_data log = logging.getLogger(__name__) -if six.PY3: +if six.PY3 and hasattr(sys.stdout, 'buffer'): import io sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') -- GitLab