diff --git a/modules/text/text_generation/ernie_zeus/README.md b/modules/text/text_generation/ernie_zeus/README.md index 920d55293168a1a42bc7fd15738177181f432ac0..c59a01fbb4362e8f9e4ecfe13acbdd5ab6565030 100644 --- a/modules/text/text_generation/ernie_zeus/README.md +++ b/modules/text/text_generation/ernie_zeus/README.md @@ -52,9 +52,12 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 - ```bash # 作文创作 + # 请设置 '--ak' 和 '--sk' 参数 + # 或者设置 'WENXIN_AK' 和 'WENXIN_SK' 环境变量 + # 更多细节参考下方 API 说明 $ hub run ernie_zeus \ --task composition_generation \ - --text '诚以养德,信以修身' + --text '诚以养德,信以修身' ``` - **参数** @@ -67,7 +70,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 - ```python import paddlehub as hub - # 加载模型 + # 请设置 'ak' 和 'sk' 参数 + # 或者设置 'WENXIN_AK' 和 'WENXIN_SK' 环境变量 + # 更多细节参考下方 API 说明 model = hub.Module(name='ernie_zeus') # 作文创作 @@ -81,17 +86,17 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 - ### 3. API - ```python def __init__( - api_key: str = '', - secret_key: str = '' + ak: Optional[str] = None, + sk: Optional[str] = None ) -> None ``` - + - 初始化 API - **参数** - - api_key(str): API Key。(可选) - - secret_key(str): Secret Key。(可选) + - sk(Optional[str]): 文心 API AK,默认为 None,即从环境变量 'WENXIN_AK' 中获取; + - ak(Optional[str]): 文心 API SK,默认为 None,即从环境变量 'WENXIN_SK' 中获取。 - ```python def custom_generation( @@ -107,9 +112,7 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 is_unidirectional: bool = False, min_dec_penalty_text: str = '', logits_bias: int = -10000, - mask_type: str = 'word', - api_key: str = '', - secret_key: str = '' + mask_type: str = 'word' ) -> str ``` - 自定义文本生成 API @@ -122,11 +125,11 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 - penalty_score(float): 通过对已生成的 token 增加惩罚, 减少重复生成的现象。值越大表示惩罚越大。取值范围 [1.0, 2.0]。 - stop_token(str): 预测结果解析时使用的结束字符串, 碰到对应字符串则直接截断并返回。可以通过设置该值, 过滤掉 few-shot 等场景下模型重复的 cases。 - task_prompt(str): 指定预置的任务模板, 效果更好。 - PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组; - Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错; - QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答; - QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配; - Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析; + PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组; + Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错; + QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答; + QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配; + Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析; zuowen: 写作文; adtext: 写文案; couplet: 对对联; novel: 写小说; cloze: 文本补全; Misc: 其它任务。 - penalty_text(str): 模型会惩罚该字符串中的 token。通过设置该值, 可以减少某些冗余与异常字符的生成。 - choice_text(str): 模型只能生成该字符串中的 token 的组合。通过设置该值, 可以对某些抽取式任务进行定向调优。 @@ -288,10 +291,14 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 - text(str): 段落摘要。 ## 四、更新历史 -* 1.0.0 +* 1.0.0 初始发布 +* 1.1.0 + + 移除默认 AK 和 SK + ```shell - $ hub install ernie_zeus == 1.0.0 - ``` \ No newline at end of file + $ hub install ernie_zeus == 1.1.0 + ``` diff --git a/modules/text/text_generation/ernie_zeus/__init__.py b/modules/text/text_generation/ernie_zeus/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/text/text_generation/ernie_zeus/module.py b/modules/text/text_generation/ernie_zeus/module.py index 169c6e81186a69a849427c5f19d585eeebee1470..81921f3b642bd96e9cb87ffe8bf38a077745ddae 100644 --- a/modules/text/text_generation/ernie_zeus/module.py +++ b/modules/text/text_generation/ernie_zeus/module.py @@ -1,11 +1,14 @@ -import json import argparse +import json +import os import requests -from paddlehub.module.module import moduleinfo, runnable + +from paddlehub.module.module import moduleinfo +from paddlehub.module.module import runnable -def get_access_token(ak: str = '', sk: str = '') -> str: +def get_access_token(ak: str = None, sk: str = None) -> str: ''' Get Access Token @@ -16,15 +19,16 @@ def get_access_token(ak: str = '', sk: str = '') -> str: Return: access_token(str): Access Token ''' + ak = ak if ak else os.getenv('WENXIN_AK') + sk = sk if sk else os.getenv('WENXIN_SK') + + assert ak and sk, RuntimeError( + 'Please go to the wenxin official website to apply for AK and SK and set the parameters “ak” and “sk” correctly, or set the environment variables “WENXIN_AK” and “WENXIN_SK”.' + ) + url = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - datas = { - 'grant_type': 'client_credentials', - 'client_id': ak if ak != '' else 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE', - 'client_secret': sk if sk != '' else 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs' - } + headers = {'Content-Type': 'application/x-www-form-urlencoded'} + datas = {'grant_type': 'client_credentials', 'client_id': ak, 'client_secret': sk} responses = requests.post(url, datas, headers=headers) @@ -37,16 +41,15 @@ def get_access_token(ak: str = '', sk: str = '') -> str: return results['data'] -@moduleinfo( - name='ernie_zeus', - type='nlp/text_generation', - author='paddlepaddle', - author_email='', - summary='ernie_zeus', - version='1.0.0' -) +@moduleinfo(name='ernie_zeus', + type='nlp/text_generation', + author='paddlepaddle', + author_email='', + summary='ernie_zeus', + version='1.1.0') class ERNIEZeus: - def __init__(self, ak: str = '', sk: str = '') -> None: + + def __init__(self, ak: str = None, sk: str = None) -> None: self.access_token = get_access_token(ak, sk) def custom_generation(self, @@ -74,11 +77,11 @@ class ERNIEZeus: penalty_score(float): 通过对已生成的 token 增加惩罚, 减少重复生成的现象。值越大表示惩罚越大。取值范围 [1.0, 2.0]。 stop_token(str): 预测结果解析时使用的结束字符串, 碰到对应字符串则直接截断并返回。可以通过设置该值, 过滤掉 few-shot 等场景下模型重复的 cases。 task_prompt(str): 指定预置的任务模板, 效果更好。 - PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组; - Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错; - QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答; - QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配; - Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析; + PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组; + Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错; + QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答; + QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配; + Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析; zuowen: 写作文; adtext: 写文案; couplet: 对对联; novel: 写小说; cloze: 文本补全; Misc: 其它任务。 penalty_text(str): 模型会惩罚该字符串中的 token。通过设置该值, 可以减少某些冗余与异常字符的生成。 choice_text(str): 模型只能生成该字符串中的 token 的组合。通过设置该值, 可以对某些抽取式任务进行定向调优。 @@ -87,14 +90,12 @@ class ERNIEZeus: logits_bias(int): 配合 penalty_text 使用, 对给定的 penalty_text 中的 token 增加一个 logits_bias, 可以通过设置该值屏蔽某些 token 生成的概率。 mask_type(str): 设置该值可以控制模型生成粒度。可选参数为 word, sentence, paragraph。 - Return: + Return: text(str): 生成的文本 ''' url = 'https://wenxin.baidu.com/moduleApi/portal/api/rest/1.0/ernie/3.0.28/zeus?from=paddlehub' access_token = self.access_token - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } + headers = {'Content-Type': 'application/x-www-form-urlencoded'} datas = { 'access_token': access_token, 'text': text, @@ -131,21 +132,19 @@ class ERNIEZeus: ''' 文本生成 ''' - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='PARAGRAPH', - penalty_text='[{[gEND]', - choice_text='', - is_unidirectional=True, - min_dec_penalty_text='。?:![]', - logits_bias=-10, - mask_type='paragraph' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='PARAGRAPH', + penalty_text='[{[gEND]', + choice_text='', + is_unidirectional=True, + min_dec_penalty_text='。?:![]', + logits_bias=-10, + mask_type='paragraph') def text_summarization(self, text: str, @@ -157,21 +156,19 @@ class ERNIEZeus: 摘要生成 ''' text = "文章:{} 摘要:".format(text) - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='Summarization', - penalty_text='', - choice_text='', - is_unidirectional=False, - min_dec_penalty_text='', - logits_bias=-10000, - mask_type='word' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='Summarization', + penalty_text='', + choice_text='', + is_unidirectional=False, + min_dec_penalty_text='', + logits_bias=-10000, + mask_type='word') def copywriting_generation(self, text: str, @@ -183,21 +180,19 @@ class ERNIEZeus: 文案生成 ''' text = "标题:{} 文案:".format(text) - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='adtext', - penalty_text='', - choice_text='', - is_unidirectional=False, - min_dec_penalty_text='', - logits_bias=-10000, - mask_type='word' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='adtext', + penalty_text='', + choice_text='', + is_unidirectional=False, + min_dec_penalty_text='', + logits_bias=-10000, + mask_type='word') def novel_continuation(self, text: str, @@ -209,21 +204,19 @@ class ERNIEZeus: 小说续写 ''' text = "上文:{} 下文:".format(text) - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='gPARAGRAPH', - penalty_text='', - choice_text='', - is_unidirectional=True, - min_dec_penalty_text='。?:![]', - logits_bias=-5, - mask_type='paragraph' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='gPARAGRAPH', + penalty_text='', + choice_text='', + is_unidirectional=True, + min_dec_penalty_text='。?:![]', + logits_bias=-5, + mask_type='paragraph') def answer_generation(self, text: str, @@ -235,21 +228,19 @@ class ERNIEZeus: 自由问答 ''' text = "问题:{} 回答:".format(text) - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='qa', - penalty_text='[gEND]', - choice_text='', - is_unidirectional=True, - min_dec_penalty_text='。?:![]', - logits_bias=-5, - mask_type='paragraph' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='qa', + penalty_text='[gEND]', + choice_text='', + is_unidirectional=True, + min_dec_penalty_text='。?:![]', + logits_bias=-5, + mask_type='paragraph') def couplet_continuation(self, text: str, @@ -261,21 +252,19 @@ class ERNIEZeus: 对联续写 ''' text = "上联:{} 下联:".format(text) - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='couplet', - penalty_text='', - choice_text='', - is_unidirectional=False, - min_dec_penalty_text='', - logits_bias=-10000, - mask_type='word' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='couplet', + penalty_text='', + choice_text='', + is_unidirectional=False, + min_dec_penalty_text='', + logits_bias=-10000, + mask_type='word') def composition_generation(self, text: str, @@ -287,21 +276,19 @@ class ERNIEZeus: 作文创作 ''' text = "作文题目:{} 正文:".format(text) - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='zuowen', - penalty_text='', - choice_text='', - is_unidirectional=False, - min_dec_penalty_text='', - logits_bias=-10000, - mask_type='word' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='zuowen', + penalty_text='', + choice_text='', + is_unidirectional=False, + min_dec_penalty_text='', + logits_bias=-10000, + mask_type='word') def text_cloze(self, text: str, @@ -312,29 +299,26 @@ class ERNIEZeus: ''' 完形填空 ''' - return self.custom_generation( - text, - min_dec_len, - seq_len, - topp, - penalty_score, - stop_token='', - task_prompt='cloze', - penalty_text='', - choice_text='', - is_unidirectional=False, - min_dec_penalty_text='', - logits_bias=-10000, - mask_type='word' - ) + return self.custom_generation(text, + min_dec_len, + seq_len, + topp, + penalty_score, + stop_token='', + task_prompt='cloze', + penalty_text='', + choice_text='', + is_unidirectional=False, + min_dec_penalty_text='', + logits_bias=-10000, + mask_type='word') @runnable def cmd(self, argvs): - parser = argparse.ArgumentParser( - description="Run the {}".format(self.name), - prog="hub run {}".format(self.name), - usage='%(prog)s', - add_help=True) + parser = argparse.ArgumentParser(description="Run the {}".format(self.name), + prog="hub run {}".format(self.name), + usage='%(prog)s', + add_help=True) parser.add_argument('--text', type=str, required=True) parser.add_argument('--min_dec_len', type=int, default=1) @@ -370,12 +354,7 @@ class ERNIEZeus: kwargs.pop('min_dec_penalty_text') kwargs.pop('logits_bias') kwargs.pop('mask_type') - default_kwargs = { - 'min_dec_len': 1, - 'seq_len': 128, - 'topp': 1.0, - 'penalty_score': 1.0 - } + default_kwargs = {'min_dec_len': 1, 'seq_len': 128, 'topp': 1.0, 'penalty_score': 1.0} else: default_kwargs = { 'min_dec_len': 1, @@ -400,52 +379,3 @@ class ERNIEZeus: kwargs.pop(k) return func(**kwargs) - - -if __name__ == '__main__': - ernie_zeus = ERNIEZeus() - - result = ernie_zeus.custom_generation( - '你好,' - ) - print(result) - - result = ernie_zeus.text_generation( - '给宠物猫起一些可爱的名字。名字:' - ) - print(result) - - result = ernie_zeus.text_summarization( - '在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。' - ) - print(result) - - result = ernie_zeus.copywriting_generation( - '芍药香氛的沐浴乳' - ) - print(result) - - result = ernie_zeus.novel_continuation( - '昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。' - ) - print(result) - - result = ernie_zeus.answer_generation( - '交朋友的原则是什么?' - ) - print(result) - - result = ernie_zeus.couplet_continuation( - '五湖四海皆春色' - ) - print(result) - - result = ernie_zeus.composition_generation( - '诚以养德,信以修身' - ) - print(result) - - result = ernie_zeus.text_cloze( - '她有着一双[MASK]的眼眸。' - ) - print(result) diff --git a/modules/text/text_generation/ernie_zeus/test.py b/modules/text/text_generation/ernie_zeus/test.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8495c1562a83adfdf6681dc2282b0d85c7a312 --- /dev/null +++ b/modules/text/text_generation/ernie_zeus/test.py @@ -0,0 +1,52 @@ +import unittest + +import paddlehub as hub + + +class TestHubModule(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.module = hub.Module(name='ernie_zeus') + + def test_custom_generation(self): + results = self.module.custom_generation('你好,') + self.assertIsInstance(results, str) + + def test_text_generation(self): + results = self.module.text_generation('给宠物猫起一些可爱的名字。名字:') + self.assertIsInstance(results, str) + + def test_text_summarization(self): + results = self.module.text_summarization( + '在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。' + ) + self.assertIsInstance(results, str) + + def test_copywriting_generation(self): + results = self.module.copywriting_generation('芍药香氛的沐浴乳') + self.assertIsInstance(results, str) + + def test_modulenovel_continuation(self): + results = self.module.novel_continuation('昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。') + self.assertIsInstance(results, str) + + def test_answer_generation(self): + results = self.module.answer_generation('交朋友的原则是什么?') + self.assertIsInstance(results, str) + + def test_couplet_continuation(self): + results = self.module.couplet_continuation('五湖四海皆春色') + self.assertIsInstance(results, str) + + def test_composition_generation(self): + results = self.module.composition_generation('诚以养德,信以修身') + self.assertIsInstance(results, str) + + def test_text_cloze(self): + results = self.module.text_cloze('她有着一双[MASK]的眼眸。') + self.assertIsInstance(results, str) + + +if __name__ == "__main__": + unittest.main()