未验证 提交 a900ca05 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ERNIE Zeus (#2127)

* update ERNIE Zeus

* update README

* update README
上级 088b37d6
......@@ -52,9 +52,12 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- ```bash
# 作文创作
# 请设置 '--ak' 和 '--sk' 参数
# 或者设置 'WENXIN_AK' 和 'WENXIN_SK' 环境变量
# 更多细节参考下方 API 说明
$ hub run ernie_zeus \
--task composition_generation \
--text '诚以养德,信以修身'
--text '诚以养德,信以修身'
```
- **参数**
......@@ -67,7 +70,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- ```python
import paddlehub as hub
# 加载模型
# 请设置 'ak' 'sk' 参数
# 或者设置 'WENXIN_AK' 'WENXIN_SK' 环境变量
# 更多细节参考下方 API 说明
model = hub.Module(name='ernie_zeus')
# 作文创作
......@@ -81,17 +86,17 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- ### 3. API
- ```python
def __init__(
api_key: str = '',
secret_key: str = ''
ak: Optional[str] = None,
sk: Optional[str] = None
) -> None
```
- 初始化 API
- **参数**
- api_key(str): API Key。(可选)
- secret_key(str): Secret Key。(可选)
- sk(Optional[str]): 文心 API AK,默认为 None,即从环境变量 'WENXIN_AK' 中获取;
- ak(Optional[str]): 文心 API SK,默认为 None,即从环境变量 'WENXIN_SK' 中获取。
- ```python
def custom_generation(
......@@ -107,9 +112,7 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
is_unidirectional: bool = False,
min_dec_penalty_text: str = '',
logits_bias: int = -10000,
mask_type: str = 'word',
api_key: str = '',
secret_key: str = ''
mask_type: str = 'word'
) -> str
```
- 自定义文本生成 API
......@@ -122,11 +125,11 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- penalty_score(float): 通过对已生成的 token 增加惩罚, 减少重复生成的现象。值越大表示惩罚越大。取值范围 [1.0, 2.0]。
- stop_token(str): 预测结果解析时使用的结束字符串, 碰到对应字符串则直接截断并返回。可以通过设置该值, 过滤掉 few-shot 等场景下模型重复的 cases。
- task_prompt(str): 指定预置的任务模板, 效果更好。
PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组;
Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错;
QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答;
QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配;
Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析;
PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组;
Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错;
QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答;
QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配;
Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析;
zuowen: 写作文; adtext: 写文案; couplet: 对对联; novel: 写小说; cloze: 文本补全; Misc: 其它任务。
- penalty_text(str): 模型会惩罚该字符串中的 token。通过设置该值, 可以减少某些冗余与异常字符的生成。
- choice_text(str): 模型只能生成该字符串中的 token 的组合。通过设置该值, 可以对某些抽取式任务进行定向调优。
......@@ -288,10 +291,14 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- text(str): 段落摘要。
## 四、更新历史
* 1.0.0
* 1.0.0
初始发布
* 1.1.0
移除默认 AK 和 SK
```shell
$ hub install ernie_zeus == 1.0.0
```
\ No newline at end of file
$ hub install ernie_zeus == 1.1.0
```
import json
import argparse
import json
import os
import requests
from paddlehub.module.module import moduleinfo, runnable
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
def get_access_token(ak: str = '', sk: str = '') -> str:
def get_access_token(ak: str = None, sk: str = None) -> str:
'''
Get Access Token
......@@ -16,15 +19,16 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
Return:
access_token(str): Access Token
'''
ak = ak if ak else os.getenv('WENXIN_AK')
sk = sk if sk else os.getenv('WENXIN_SK')
assert ak and sk, RuntimeError(
'Please go to the wenxin official website to apply for AK and SK and set the parameters “ak” and “sk” correctly, or set the environment variables “WENXIN_AK” and “WENXIN_SK”.'
)
url = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
datas = {
'grant_type': 'client_credentials',
'client_id': ak if ak != '' else 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE',
'client_secret': sk if sk != '' else 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
}
headers = {'Content-Type': 'application/x-www-form-urlencoded'}
datas = {'grant_type': 'client_credentials', 'client_id': ak, 'client_secret': sk}
responses = requests.post(url, datas, headers=headers)
......@@ -37,16 +41,15 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
return results['data']
@moduleinfo(
name='ernie_zeus',
type='nlp/text_generation',
author='paddlepaddle',
author_email='',
summary='ernie_zeus',
version='1.0.0'
)
@moduleinfo(name='ernie_zeus',
type='nlp/text_generation',
author='paddlepaddle',
author_email='',
summary='ernie_zeus',
version='1.1.0')
class ERNIEZeus:
def __init__(self, ak: str = '', sk: str = '') -> None:
def __init__(self, ak: str = None, sk: str = None) -> None:
self.access_token = get_access_token(ak, sk)
def custom_generation(self,
......@@ -74,11 +77,11 @@ class ERNIEZeus:
penalty_score(float): 通过对已生成的 token 增加惩罚, 减少重复生成的现象。值越大表示惩罚越大。取值范围 [1.0, 2.0]。
stop_token(str): 预测结果解析时使用的结束字符串, 碰到对应字符串则直接截断并返回。可以通过设置该值, 过滤掉 few-shot 等场景下模型重复的 cases。
task_prompt(str): 指定预置的任务模板, 效果更好。
PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组;
Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错;
QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答;
QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配;
Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析;
PARAGRAPH: 引导模型生成一段文章; SENT: 引导模型生成一句话; ENTITY: 引导模型生成词组;
Summarization: 摘要; MT: 翻译; Text2Annotation: 抽取; Correction: 纠错;
QA_MRC: 阅读理解; Dialogue: 对话; QA_Closed_book: 闭卷问答; QA_Multi_Choice: 多选问答;
QuestionGeneration: 问题生成; Paraphrasing: 复述; NLI: 文本蕴含识别; SemanticMatching: 匹配;
Text2SQL: 文本描述转SQL; TextClassification: 文本分类; SentimentClassification: 情感分析;
zuowen: 写作文; adtext: 写文案; couplet: 对对联; novel: 写小说; cloze: 文本补全; Misc: 其它任务。
penalty_text(str): 模型会惩罚该字符串中的 token。通过设置该值, 可以减少某些冗余与异常字符的生成。
choice_text(str): 模型只能生成该字符串中的 token 的组合。通过设置该值, 可以对某些抽取式任务进行定向调优。
......@@ -87,14 +90,12 @@ class ERNIEZeus:
logits_bias(int): 配合 penalty_text 使用, 对给定的 penalty_text 中的 token 增加一个 logits_bias, 可以通过设置该值屏蔽某些 token 生成的概率。
mask_type(str): 设置该值可以控制模型生成粒度。可选参数为 word, sentence, paragraph。
Return:
Return:
text(str): 生成的文本
'''
url = 'https://wenxin.baidu.com/moduleApi/portal/api/rest/1.0/ernie/3.0.28/zeus?from=paddlehub'
access_token = self.access_token
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
headers = {'Content-Type': 'application/x-www-form-urlencoded'}
datas = {
'access_token': access_token,
'text': text,
......@@ -131,21 +132,19 @@ class ERNIEZeus:
'''
文本生成
'''
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='PARAGRAPH',
penalty_text='[{[gEND]',
choice_text='',
is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]',
logits_bias=-10,
mask_type='paragraph'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='PARAGRAPH',
penalty_text='[{[gEND]',
choice_text='',
is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]',
logits_bias=-10,
mask_type='paragraph')
def text_summarization(self,
text: str,
......@@ -157,21 +156,19 @@ class ERNIEZeus:
摘要生成
'''
text = "文章:{} 摘要:".format(text)
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='Summarization',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='Summarization',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word')
def copywriting_generation(self,
text: str,
......@@ -183,21 +180,19 @@ class ERNIEZeus:
文案生成
'''
text = "标题:{} 文案:".format(text)
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='adtext',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='adtext',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word')
def novel_continuation(self,
text: str,
......@@ -209,21 +204,19 @@ class ERNIEZeus:
小说续写
'''
text = "上文:{} 下文:".format(text)
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='gPARAGRAPH',
penalty_text='',
choice_text='',
is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]',
logits_bias=-5,
mask_type='paragraph'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='gPARAGRAPH',
penalty_text='',
choice_text='',
is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]',
logits_bias=-5,
mask_type='paragraph')
def answer_generation(self,
text: str,
......@@ -235,21 +228,19 @@ class ERNIEZeus:
自由问答
'''
text = "问题:{} 回答:".format(text)
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='qa',
penalty_text='[gEND]',
choice_text='',
is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]',
logits_bias=-5,
mask_type='paragraph'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='qa',
penalty_text='[gEND]',
choice_text='',
is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]',
logits_bias=-5,
mask_type='paragraph')
def couplet_continuation(self,
text: str,
......@@ -261,21 +252,19 @@ class ERNIEZeus:
对联续写
'''
text = "上联:{} 下联:".format(text)
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='couplet',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='couplet',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word')
def composition_generation(self,
text: str,
......@@ -287,21 +276,19 @@ class ERNIEZeus:
作文创作
'''
text = "作文题目:{} 正文:".format(text)
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='zuowen',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='zuowen',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word')
def text_cloze(self,
text: str,
......@@ -312,29 +299,26 @@ class ERNIEZeus:
'''
完形填空
'''
return self.custom_generation(
text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='cloze',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word'
)
return self.custom_generation(text,
min_dec_len,
seq_len,
topp,
penalty_score,
stop_token='',
task_prompt='cloze',
penalty_text='',
choice_text='',
is_unidirectional=False,
min_dec_penalty_text='',
logits_bias=-10000,
mask_type='word')
@runnable
def cmd(self, argvs):
parser = argparse.ArgumentParser(
description="Run the {}".format(self.name),
prog="hub run {}".format(self.name),
usage='%(prog)s',
add_help=True)
parser = argparse.ArgumentParser(description="Run the {}".format(self.name),
prog="hub run {}".format(self.name),
usage='%(prog)s',
add_help=True)
parser.add_argument('--text', type=str, required=True)
parser.add_argument('--min_dec_len', type=int, default=1)
......@@ -370,12 +354,7 @@ class ERNIEZeus:
kwargs.pop('min_dec_penalty_text')
kwargs.pop('logits_bias')
kwargs.pop('mask_type')
default_kwargs = {
'min_dec_len': 1,
'seq_len': 128,
'topp': 1.0,
'penalty_score': 1.0
}
default_kwargs = {'min_dec_len': 1, 'seq_len': 128, 'topp': 1.0, 'penalty_score': 1.0}
else:
default_kwargs = {
'min_dec_len': 1,
......@@ -400,52 +379,3 @@ class ERNIEZeus:
kwargs.pop(k)
return func(**kwargs)
if __name__ == '__main__':
ernie_zeus = ERNIEZeus()
result = ernie_zeus.custom_generation(
'你好,'
)
print(result)
result = ernie_zeus.text_generation(
'给宠物猫起一些可爱的名字。名字:'
)
print(result)
result = ernie_zeus.text_summarization(
'在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。'
)
print(result)
result = ernie_zeus.copywriting_generation(
'芍药香氛的沐浴乳'
)
print(result)
result = ernie_zeus.novel_continuation(
'昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。'
)
print(result)
result = ernie_zeus.answer_generation(
'交朋友的原则是什么?'
)
print(result)
result = ernie_zeus.couplet_continuation(
'五湖四海皆春色'
)
print(result)
result = ernie_zeus.composition_generation(
'诚以养德,信以修身'
)
print(result)
result = ernie_zeus.text_cloze(
'她有着一双[MASK]的眼眸。'
)
print(result)
import unittest
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.module = hub.Module(name='ernie_zeus')
def test_custom_generation(self):
results = self.module.custom_generation('你好,')
self.assertIsInstance(results, str)
def test_text_generation(self):
results = self.module.text_generation('给宠物猫起一些可爱的名字。名字:')
self.assertIsInstance(results, str)
def test_text_summarization(self):
results = self.module.text_summarization(
'在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。'
)
self.assertIsInstance(results, str)
def test_copywriting_generation(self):
results = self.module.copywriting_generation('芍药香氛的沐浴乳')
self.assertIsInstance(results, str)
def test_modulenovel_continuation(self):
results = self.module.novel_continuation('昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。')
self.assertIsInstance(results, str)
def test_answer_generation(self):
results = self.module.answer_generation('交朋友的原则是什么?')
self.assertIsInstance(results, str)
def test_couplet_continuation(self):
results = self.module.couplet_continuation('五湖四海皆春色')
self.assertIsInstance(results, str)
def test_composition_generation(self):
results = self.module.composition_generation('诚以养德,信以修身')
self.assertIsInstance(results, str)
def test_text_cloze(self):
results = self.module.text_cloze('她有着一双[MASK]的眼眸。')
self.assertIsInstance(results, str)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册