From d9d160a2fcb3d7e257e78433f77d1bbe847f064d Mon Sep 17 00:00:00 2001 From: kinghuin Date: Thu, 13 Aug 2020 10:46:07 +0800 Subject: [PATCH] fix ernie_gen bug. and plato and ddparser config (#817) --- .../syntactic_analysis/DDParser/README.md | 24 ++++++++------- .../syntactic_analysis/DDParser/module.py | 30 +++++++++---------- .../ernie_gen_couplet/README.md | 4 +++ .../ernie_gen_couplet/module.py | 6 ++-- .../ernie_gen_poetry/README.md | 4 +++ .../ernie_gen_poetry/module.py | 6 ++-- .../text_generation/plato2_en_base/README.md | 2 +- .../text_generation/plato2_en_large/README.md | 4 ++- 8 files changed, 45 insertions(+), 35 deletions(-) diff --git a/hub_module/modules/text/syntactic_analysis/DDParser/README.md b/hub_module/modules/text/syntactic_analysis/DDParser/README.md index fed9e668..53209b3f 100644 --- a/hub_module/modules/text/syntactic_analysis/DDParser/README.md +++ b/hub_module/modules/text/syntactic_analysis/DDParser/README.md @@ -6,17 +6,18 @@ $ hub run ddparser --input_text="百度是一家高科技公司" # API -## parse(texts=[]) +## parse(texts=[], return\_visual=False) 依存分析接口,输入文本,输出依存关系。 **参数** -* texts(list[list[str] or list[str]]): 待预测数据。各元素可以是未分词的字符串,也可以是已分词的token列表。 +* texts(list\[list\[str\] or list\[str\]]): 待预测数据。各元素可以是未分词的字符串,也可以是已分词的token列表。 +* return\_visual(bool): 是否返回依存分析可视化结果。如果为True,返回结果中将包含'visual'字段。 **返回** -* results(list[dict]): 依存分析结果。每个元素都是dict类型,包含以下信息: +* results(list\[dict\]): 依存分析结果。每个元素都是dict类型,包含以下信息: ```python { 'word': list[str], 分词结果。 @@ -34,9 +35,9 @@ $ hub run ddparser --input_text="百度是一家高科技公司" **参数** -* word(list[list[str]): 分词信息。 -* head(list[int]): 当前成分其支配者的id。 -* deprel(list[str]): 当前成分与支配者的依存关系。 +* word(list\[list\[str\]\): 分词信息。 +* head(list\[int\]): 当前成分其支配者的id。 +* deprel(list\[str\]): 当前成分与支配者的依存关系。 **返回** @@ -55,11 +56,12 @@ results = module.parse(texts=test_text) print(results) test_tokens = [['百度', '是', '一家', '高科技', '公司']] -results = module.parse(texts=test_text) +results = module.parse(texts=test_text, return_visual = True) print(results) result = results[0] data = module.visualize(result['word'],result['head'],result['deprel']) +# or data = result['visual'] cv2.imwrite('test.jpg',data) ``` @@ -81,7 +83,7 @@ Loading ddparser successful. 这样就完成了服务化API的部署,默认端口号为8866。 -**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA_VISIBLE_DEVICES环境变量,否则不用设置。 +**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。 ## 第二步:发送预测请求 @@ -105,12 +107,12 @@ data = {"texts": text, "return_visual": return_visual} url = "http://0.0.0.0:8866/predict/ddparser" headers = {"Content-Type": "application/json"} r = requests.post(url=url, headers=headers, data=json.dumps(data)) -results, visuals = r.json()['results'] +results = r.json()['results'] for i in range(len(results)): - print(results[i]) + print(results[i]['word']) # 不同于本地调用parse接口,serving返回的图像是list类型的,需要先用numpy加载再显示或保存。 - cv2.imwrite('%s.jpg'%i, np.array(visuals[i])) + cv2.imwrite('%s.jpg'%i, np.array(results[i]['visual'])) ``` 关于PaddleHub Serving更多信息参考[服务部署](https://github.com/PaddlePaddle/PaddleHub/blob/release/v1.6/docs/tutorial/serving.md) diff --git a/hub_module/modules/text/syntactic_analysis/DDParser/module.py b/hub_module/modules/text/syntactic_analysis/DDParser/module.py index f9afa0bb..2d451703 100644 --- a/hub_module/modules/text/syntactic_analysis/DDParser/module.py +++ b/hub_module/modules/text/syntactic_analysis/DDParser/module.py @@ -32,15 +32,16 @@ class ddparser(hub.NLPPredictionModule): """ self.ddp = DDParserModel(prob=True, use_pos=True) self.font = font_manager.FontProperties( - fname=os.path.join(self.directory, "SimHei.ttf")) + fname=os.path.join(self.directory, "SourceHanSans-Regular.ttf")) @serving def serving_parse(self, texts=[], return_visual=False): - results, visuals = self.parse(texts, return_visual) - for i, visual in enumerate(visuals): - visuals[i] = visual.tolist() + results = self.parse(texts, return_visual) + if return_visual: + for i, result in enumerate(results): + result['visual'] = result['visual'].tolist() - return results, visuals + return results def parse(self, texts=[], return_visual=False): """ @@ -57,11 +58,9 @@ class ddparser(hub.NLPPredictionModule): 'head': list[int], the head ids. 'deprel': list[str], the dependency relation. 'prob': list[float], the prediction probility of the dependency relation. - 'postag': list[str], the POS tag. If the element of the texts is list, the key 'postag' will not be returned. + 'postag': list[str], the POS tag. If the element of the texts is list, the key 'postag' will not return. + 'visual' : list[numpy.array]: the dependency visualization. Use cv2.imshow to show or cv2.imwrite to save it. If return_visual=False, it will not return. } - - visuals : list[numpy.array]: the dependency visualization. Use cv2.imshow to show or cv2.imwrite to save it. If return_visual=False, it will not be empty. - """ if not texts: @@ -73,13 +72,11 @@ class ddparser(hub.NLPPredictionModule): else: raise ValueError("All of the elements should be string or list") results = do_parse(texts) - visuals = [] if return_visual: for result in results: - visuals.append( - self.visualize(result['word'], result['head'], - result['deprel'])) - return results, visuals + result['visual'] = self.visualize( + result['word'], result['head'], result['deprel']) + return results @runnable def run_cmd(self, argvs): @@ -194,10 +191,11 @@ if __name__ == "__main__": results = module.parse(texts=test_text) print(results) test_tokens = [['百度', '是', '一家', '高科技', '公司']] - results = module.parse(texts=test_text) + results = module.parse(texts=test_text, return_visual=True) print(results) result = results[0] data = module.visualize(result['word'], result['head'], result['deprel']) import cv2 import numpy as np - cv2.imwrite('test.jpg', np.array(data)) + cv2.imwrite('test1.jpg', data) + cv2.imwrite('test2.jpg', result['visual']) diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/README.md b/hub_module/modules/text/text_generation/ernie_gen_couplet/README.md index fc120e86..74d4f01d 100644 --- a/hub_module/modules/text/text_generation/ernie_gen_couplet/README.md +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/README.md @@ -97,3 +97,7 @@ paddlehub >= 1.7.0 * 1.0.0 初始发布 + +* 1.0.1 + + 修复windows中的编码问题 diff --git a/hub_module/modules/text/text_generation/ernie_gen_couplet/module.py b/hub_module/modules/text/text_generation/ernie_gen_couplet/module.py index 1d698412..341c44c4 100644 --- a/hub_module/modules/text/text_generation/ernie_gen_couplet/module.py +++ b/hub_module/modules/text/text_generation/ernie_gen_couplet/module.py @@ -35,7 +35,7 @@ from ernie_gen_couplet.model.modeling_ernie_gen import ErnieModelForGeneration @moduleinfo( name="ernie_gen_couplet", - version="1.0.0", + version="1.0.1", summary= "ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for couplet generation task.", author="baidu-nlp", @@ -50,10 +50,10 @@ class ErnieGen(hub.NLPPredictionModule): assets_path = os.path.join(self.directory, "assets") gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_couplet") ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json') - with open(ernie_cfg_path) as ernie_cfg_file: + with open(ernie_cfg_path, encoding='utf8') as ernie_cfg_file: ernie_cfg = dict(json.loads(ernie_cfg_file.read())) ernie_vocab_path = os.path.join(assets_path, 'vocab.txt') - with open(ernie_vocab_path) as ernie_vocab_file: + with open(ernie_vocab_path, encoding='utf8') as ernie_vocab_file: ernie_vocab = { j.strip().split('\t')[0]: i for i, j in enumerate(ernie_vocab_file.readlines()) diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/README.md b/hub_module/modules/text/text_generation/ernie_gen_poetry/README.md index 3c5a85d8..bb0b29c8 100644 --- a/hub_module/modules/text/text_generation/ernie_gen_poetry/README.md +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/README.md @@ -97,3 +97,7 @@ paddlehub >= 1.7.0 * 1.0.0 初始发布 + +* 1.0.1 + + 修复windows中的编码问题 diff --git a/hub_module/modules/text/text_generation/ernie_gen_poetry/module.py b/hub_module/modules/text/text_generation/ernie_gen_poetry/module.py index 33fc9e25..04e84853 100644 --- a/hub_module/modules/text/text_generation/ernie_gen_poetry/module.py +++ b/hub_module/modules/text/text_generation/ernie_gen_poetry/module.py @@ -35,7 +35,7 @@ from ernie_gen_poetry.model.modeling_ernie_gen import ErnieModelForGeneration @moduleinfo( name="ernie_gen_poetry", - version="1.0.0", + version="1.0.1", summary= "ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for poetry generation task.", author="baidu-nlp", @@ -50,10 +50,10 @@ class ErnieGen(hub.NLPPredictionModule): assets_path = os.path.join(self.directory, "assets") gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_poetry") ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json') - with open(ernie_cfg_path) as ernie_cfg_file: + with open(ernie_cfg_path, encoding='utf8') as ernie_cfg_file: ernie_cfg = dict(json.loads(ernie_cfg_file.read())) ernie_vocab_path = os.path.join(assets_path, 'vocab.txt') - with open(ernie_vocab_path) as ernie_vocab_file: + with open(ernie_vocab_path, encoding='utf8') as ernie_vocab_file: ernie_vocab = { j.strip().split('\t')[0]: i for i, j in enumerate(ernie_vocab_file.readlines()) diff --git a/hub_module/modules/text/text_generation/plato2_en_base/README.md b/hub_module/modules/text/text_generation/plato2_en_base/README.md index c886502d..46db0016 100644 --- a/hub_module/modules/text/text_generation/plato2_en_base/README.md +++ b/hub_module/modules/text/text_generation/plato2_en_base/README.md @@ -10,7 +10,7 @@ PLATO2是一个超大规模生成式对话系统模型。它承袭了PLATO隐变 ## 命令行预测 ```shell -$ hub run plato2_en_base --input_text="Hello, how are you" --use_gpu +$ hub run plato2_en_base --input_text="Hello, how are you" ``` ## API diff --git a/hub_module/modules/text/text_generation/plato2_en_large/README.md b/hub_module/modules/text/text_generation/plato2_en_large/README.md index 23502136..91c63323 100644 --- a/hub_module/modules/text/text_generation/plato2_en_large/README.md +++ b/hub_module/modules/text/text_generation/plato2_en_large/README.md @@ -7,10 +7,12 @@ PLATO2是一个超大规模生成式对话系统模型。它承袭了PLATO隐变 更多详情参考论文[PLATO-2: Towards Building an Open-Domain Chatbot via Curriculum Learning](https://arxiv.org/abs/2006.16779) +**注:plato2\_en\_large 模型大小12GB,下载时间较长,请耐心等候。运行此模型要求显存至少16GB。** + ## 命令行预测 ```shell -$ hub run plato2_en_large --input_text="Hello, how are you" --use_gpu +$ hub run plato2_en_large --input_text="Hello, how are you" ``` ## API -- GitLab