未验证 提交 a900ca05 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ERNIE Zeus (#2127)

* update ERNIE Zeus

* update README

* update README
上级 088b37d6
...@@ -52,6 +52,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 ...@@ -52,6 +52,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- ```bash - ```bash
# 作文创作 # 作文创作
# 请设置 '--ak' 和 '--sk' 参数
# 或者设置 'WENXIN_AK' 和 'WENXIN_SK' 环境变量
# 更多细节参考下方 API 说明
$ hub run ernie_zeus \ $ hub run ernie_zeus \
--task composition_generation \ --task composition_generation \
--text '诚以养德,信以修身' --text '诚以养德,信以修身'
...@@ -67,7 +70,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 ...@@ -67,7 +70,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- ```python - ```python
import paddlehub as hub import paddlehub as hub
# 加载模型 # 请设置 'ak' 'sk' 参数
# 或者设置 'WENXIN_AK' 'WENXIN_SK' 环境变量
# 更多细节参考下方 API 说明
model = hub.Module(name='ernie_zeus') model = hub.Module(name='ernie_zeus')
# 作文创作 # 作文创作
...@@ -81,8 +86,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 ...@@ -81,8 +86,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- ### 3. API - ### 3. API
- ```python - ```python
def __init__( def __init__(
api_key: str = '', ak: Optional[str] = None,
secret_key: str = '' sk: Optional[str] = None
) -> None ) -> None
``` ```
...@@ -90,8 +95,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 ...@@ -90,8 +95,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- **参数** - **参数**
- api_key(str): API Key。(可选) - sk(Optional[str]): 文心 API AK,默认为 None,即从环境变量 'WENXIN_AK' 中获取;
- secret_key(str): Secret Key。(可选) - ak(Optional[str]): 文心 API SK,默认为 None,即从环境变量 'WENXIN_SK' 中获取。
- ```python - ```python
def custom_generation( def custom_generation(
...@@ -107,9 +112,7 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 ...@@ -107,9 +112,7 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
is_unidirectional: bool = False, is_unidirectional: bool = False,
min_dec_penalty_text: str = '', min_dec_penalty_text: str = '',
logits_bias: int = -10000, logits_bias: int = -10000,
mask_type: str = 'word', mask_type: str = 'word'
api_key: str = '',
secret_key: str = ''
) -> str ) -> str
``` ```
- 自定义文本生成 API - 自定义文本生成 API
...@@ -292,6 +295,10 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注 ...@@ -292,6 +295,10 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
初始发布 初始发布
* 1.1.0
移除默认 AK 和 SK
```shell ```shell
$ hub install ernie_zeus == 1.0.0 $ hub install ernie_zeus == 1.1.0
``` ```
import json
import argparse import argparse
import json
import os
import requests import requests
from paddlehub.module.module import moduleinfo, runnable
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
def get_access_token(ak: str = '', sk: str = '') -> str: def get_access_token(ak: str = None, sk: str = None) -> str:
''' '''
Get Access Token Get Access Token
...@@ -16,15 +19,16 @@ def get_access_token(ak: str = '', sk: str = '') -> str: ...@@ -16,15 +19,16 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
Return: Return:
access_token(str): Access Token access_token(str): Access Token
''' '''
ak = ak if ak else os.getenv('WENXIN_AK')
sk = sk if sk else os.getenv('WENXIN_SK')
assert ak and sk, RuntimeError(
'Please go to the wenxin official website to apply for AK and SK and set the parameters “ak” and “sk” correctly, or set the environment variables “WENXIN_AK” and “WENXIN_SK”.'
)
url = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' url = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
headers = { headers = {'Content-Type': 'application/x-www-form-urlencoded'}
'Content-Type': 'application/x-www-form-urlencoded' datas = {'grant_type': 'client_credentials', 'client_id': ak, 'client_secret': sk}
}
datas = {
'grant_type': 'client_credentials',
'client_id': ak if ak != '' else 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE',
'client_secret': sk if sk != '' else 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
}
responses = requests.post(url, datas, headers=headers) responses = requests.post(url, datas, headers=headers)
...@@ -37,16 +41,15 @@ def get_access_token(ak: str = '', sk: str = '') -> str: ...@@ -37,16 +41,15 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
return results['data'] return results['data']
@moduleinfo( @moduleinfo(name='ernie_zeus',
name='ernie_zeus',
type='nlp/text_generation', type='nlp/text_generation',
author='paddlepaddle', author='paddlepaddle',
author_email='', author_email='',
summary='ernie_zeus', summary='ernie_zeus',
version='1.0.0' version='1.1.0')
)
class ERNIEZeus: class ERNIEZeus:
def __init__(self, ak: str = '', sk: str = '') -> None:
def __init__(self, ak: str = None, sk: str = None) -> None:
self.access_token = get_access_token(ak, sk) self.access_token = get_access_token(ak, sk)
def custom_generation(self, def custom_generation(self,
...@@ -92,9 +95,7 @@ class ERNIEZeus: ...@@ -92,9 +95,7 @@ class ERNIEZeus:
''' '''
url = 'https://wenxin.baidu.com/moduleApi/portal/api/rest/1.0/ernie/3.0.28/zeus?from=paddlehub' url = 'https://wenxin.baidu.com/moduleApi/portal/api/rest/1.0/ernie/3.0.28/zeus?from=paddlehub'
access_token = self.access_token access_token = self.access_token
headers = { headers = {'Content-Type': 'application/x-www-form-urlencoded'}
'Content-Type': 'application/x-www-form-urlencoded'
}
datas = { datas = {
'access_token': access_token, 'access_token': access_token,
'text': text, 'text': text,
...@@ -131,8 +132,7 @@ class ERNIEZeus: ...@@ -131,8 +132,7 @@ class ERNIEZeus:
''' '''
文本生成 文本生成
''' '''
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -144,8 +144,7 @@ class ERNIEZeus: ...@@ -144,8 +144,7 @@ class ERNIEZeus:
is_unidirectional=True, is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]', min_dec_penalty_text='。?:![<S>]',
logits_bias=-10, logits_bias=-10,
mask_type='paragraph' mask_type='paragraph')
)
def text_summarization(self, def text_summarization(self,
text: str, text: str,
...@@ -157,8 +156,7 @@ class ERNIEZeus: ...@@ -157,8 +156,7 @@ class ERNIEZeus:
摘要生成 摘要生成
''' '''
text = "文章:{} 摘要:".format(text) text = "文章:{} 摘要:".format(text)
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -170,8 +168,7 @@ class ERNIEZeus: ...@@ -170,8 +168,7 @@ class ERNIEZeus:
is_unidirectional=False, is_unidirectional=False,
min_dec_penalty_text='', min_dec_penalty_text='',
logits_bias=-10000, logits_bias=-10000,
mask_type='word' mask_type='word')
)
def copywriting_generation(self, def copywriting_generation(self,
text: str, text: str,
...@@ -183,8 +180,7 @@ class ERNIEZeus: ...@@ -183,8 +180,7 @@ class ERNIEZeus:
文案生成 文案生成
''' '''
text = "标题:{} 文案:".format(text) text = "标题:{} 文案:".format(text)
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -196,8 +192,7 @@ class ERNIEZeus: ...@@ -196,8 +192,7 @@ class ERNIEZeus:
is_unidirectional=False, is_unidirectional=False,
min_dec_penalty_text='', min_dec_penalty_text='',
logits_bias=-10000, logits_bias=-10000,
mask_type='word' mask_type='word')
)
def novel_continuation(self, def novel_continuation(self,
text: str, text: str,
...@@ -209,8 +204,7 @@ class ERNIEZeus: ...@@ -209,8 +204,7 @@ class ERNIEZeus:
小说续写 小说续写
''' '''
text = "上文:{} 下文:".format(text) text = "上文:{} 下文:".format(text)
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -222,8 +216,7 @@ class ERNIEZeus: ...@@ -222,8 +216,7 @@ class ERNIEZeus:
is_unidirectional=True, is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]', min_dec_penalty_text='。?:![<S>]',
logits_bias=-5, logits_bias=-5,
mask_type='paragraph' mask_type='paragraph')
)
def answer_generation(self, def answer_generation(self,
text: str, text: str,
...@@ -235,8 +228,7 @@ class ERNIEZeus: ...@@ -235,8 +228,7 @@ class ERNIEZeus:
自由问答 自由问答
''' '''
text = "问题:{} 回答:".format(text) text = "问题:{} 回答:".format(text)
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -248,8 +240,7 @@ class ERNIEZeus: ...@@ -248,8 +240,7 @@ class ERNIEZeus:
is_unidirectional=True, is_unidirectional=True,
min_dec_penalty_text='。?:![<S>]', min_dec_penalty_text='。?:![<S>]',
logits_bias=-5, logits_bias=-5,
mask_type='paragraph' mask_type='paragraph')
)
def couplet_continuation(self, def couplet_continuation(self,
text: str, text: str,
...@@ -261,8 +252,7 @@ class ERNIEZeus: ...@@ -261,8 +252,7 @@ class ERNIEZeus:
对联续写 对联续写
''' '''
text = "上联:{} 下联:".format(text) text = "上联:{} 下联:".format(text)
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -274,8 +264,7 @@ class ERNIEZeus: ...@@ -274,8 +264,7 @@ class ERNIEZeus:
is_unidirectional=False, is_unidirectional=False,
min_dec_penalty_text='', min_dec_penalty_text='',
logits_bias=-10000, logits_bias=-10000,
mask_type='word' mask_type='word')
)
def composition_generation(self, def composition_generation(self,
text: str, text: str,
...@@ -287,8 +276,7 @@ class ERNIEZeus: ...@@ -287,8 +276,7 @@ class ERNIEZeus:
作文创作 作文创作
''' '''
text = "作文题目:{} 正文:".format(text) text = "作文题目:{} 正文:".format(text)
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -300,8 +288,7 @@ class ERNIEZeus: ...@@ -300,8 +288,7 @@ class ERNIEZeus:
is_unidirectional=False, is_unidirectional=False,
min_dec_penalty_text='', min_dec_penalty_text='',
logits_bias=-10000, logits_bias=-10000,
mask_type='word' mask_type='word')
)
def text_cloze(self, def text_cloze(self,
text: str, text: str,
...@@ -312,8 +299,7 @@ class ERNIEZeus: ...@@ -312,8 +299,7 @@ class ERNIEZeus:
''' '''
完形填空 完形填空
''' '''
return self.custom_generation( return self.custom_generation(text,
text,
min_dec_len, min_dec_len,
seq_len, seq_len,
topp, topp,
...@@ -325,13 +311,11 @@ class ERNIEZeus: ...@@ -325,13 +311,11 @@ class ERNIEZeus:
is_unidirectional=False, is_unidirectional=False,
min_dec_penalty_text='', min_dec_penalty_text='',
logits_bias=-10000, logits_bias=-10000,
mask_type='word' mask_type='word')
)
@runnable @runnable
def cmd(self, argvs): def cmd(self, argvs):
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Run the {}".format(self.name),
description="Run the {}".format(self.name),
prog="hub run {}".format(self.name), prog="hub run {}".format(self.name),
usage='%(prog)s', usage='%(prog)s',
add_help=True) add_help=True)
...@@ -370,12 +354,7 @@ class ERNIEZeus: ...@@ -370,12 +354,7 @@ class ERNIEZeus:
kwargs.pop('min_dec_penalty_text') kwargs.pop('min_dec_penalty_text')
kwargs.pop('logits_bias') kwargs.pop('logits_bias')
kwargs.pop('mask_type') kwargs.pop('mask_type')
default_kwargs = { default_kwargs = {'min_dec_len': 1, 'seq_len': 128, 'topp': 1.0, 'penalty_score': 1.0}
'min_dec_len': 1,
'seq_len': 128,
'topp': 1.0,
'penalty_score': 1.0
}
else: else:
default_kwargs = { default_kwargs = {
'min_dec_len': 1, 'min_dec_len': 1,
...@@ -400,52 +379,3 @@ class ERNIEZeus: ...@@ -400,52 +379,3 @@ class ERNIEZeus:
kwargs.pop(k) kwargs.pop(k)
return func(**kwargs) return func(**kwargs)
if __name__ == '__main__':
ernie_zeus = ERNIEZeus()
result = ernie_zeus.custom_generation(
'你好,'
)
print(result)
result = ernie_zeus.text_generation(
'给宠物猫起一些可爱的名字。名字:'
)
print(result)
result = ernie_zeus.text_summarization(
'在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。'
)
print(result)
result = ernie_zeus.copywriting_generation(
'芍药香氛的沐浴乳'
)
print(result)
result = ernie_zeus.novel_continuation(
'昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。'
)
print(result)
result = ernie_zeus.answer_generation(
'交朋友的原则是什么?'
)
print(result)
result = ernie_zeus.couplet_continuation(
'五湖四海皆春色'
)
print(result)
result = ernie_zeus.composition_generation(
'诚以养德,信以修身'
)
print(result)
result = ernie_zeus.text_cloze(
'她有着一双[MASK]的眼眸。'
)
print(result)
import unittest
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.module = hub.Module(name='ernie_zeus')
def test_custom_generation(self):
results = self.module.custom_generation('你好,')
self.assertIsInstance(results, str)
def test_text_generation(self):
results = self.module.text_generation('给宠物猫起一些可爱的名字。名字:')
self.assertIsInstance(results, str)
def test_text_summarization(self):
results = self.module.text_summarization(
'在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。'
)
self.assertIsInstance(results, str)
def test_copywriting_generation(self):
results = self.module.copywriting_generation('芍药香氛的沐浴乳')
self.assertIsInstance(results, str)
def test_modulenovel_continuation(self):
results = self.module.novel_continuation('昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。')
self.assertIsInstance(results, str)
def test_answer_generation(self):
results = self.module.answer_generation('交朋友的原则是什么?')
self.assertIsInstance(results, str)
def test_couplet_continuation(self):
results = self.module.couplet_continuation('五湖四海皆春色')
self.assertIsInstance(results, str)
def test_composition_generation(self):
results = self.module.composition_generation('诚以养德,信以修身')
self.assertIsInstance(results, str)
def test_text_cloze(self):
results = self.module.text_cloze('她有着一双[MASK]的眼眸。')
self.assertIsInstance(results, str)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册