Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
a900ca05
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a900ca05
编写于
11月 18, 2022
作者:
jm_12138
提交者:
GitHub
11月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update ERNIE Zeus (#2127)
* update ERNIE Zeus * update README * update README
上级
088b37d6
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
216 addition
and
227 deletion
+216
-227
modules/text/text_generation/ernie_zeus/README.md
modules/text/text_generation/ernie_zeus/README.md
+25
-18
modules/text/text_generation/ernie_zeus/__init__.py
modules/text/text_generation/ernie_zeus/__init__.py
+0
-0
modules/text/text_generation/ernie_zeus/module.py
modules/text/text_generation/ernie_zeus/module.py
+139
-209
modules/text/text_generation/ernie_zeus/test.py
modules/text/text_generation/ernie_zeus/test.py
+52
-0
未找到文件。
modules/text/text_generation/ernie_zeus/README.md
浏览文件 @
a900ca05
...
@@ -52,6 +52,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
...
@@ -52,6 +52,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
-
```
bash
-
```
bash
# 作文创作
# 作文创作
# 请设置 '--ak' 和 '--sk' 参数
# 或者设置 'WENXIN_AK' 和 'WENXIN_SK' 环境变量
# 更多细节参考下方 API 说明
$ hub run ernie_zeus
\
$ hub run ernie_zeus
\
--task composition_generation
\
--task composition_generation
\
--text '诚以养德,信以修身'
--text '诚以养德,信以修身'
...
@@ -67,7 +70,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
...
@@ -67,7 +70,9 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
-
```
python
-
```
python
import
paddlehub
as
hub
import
paddlehub
as
hub
#
加载模型
#
请设置
'ak'
和
'sk'
参数
#
或者设置
'WENXIN_AK'
和
'WENXIN_SK'
环境变量
#
更多细节参考下方
API
说明
model
=
hub
.
Module
(
name
=
'ernie_zeus'
)
model
=
hub
.
Module
(
name
=
'ernie_zeus'
)
#
作文创作
#
作文创作
...
@@ -81,8 +86,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
...
@@ -81,8 +86,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
-
### 3. API
-
### 3. API
-
```python
-
```python
def __init__(
def __init__(
a
pi_key: str = '',
a
k: Optional[str] = None,
s
ecret_key: str = ''
s
k: Optional[str] = None
) -> None
) -> None
```
```
...
@@ -90,8 +95,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
...
@@ -90,8 +95,8 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
- **参数**
- **参数**
-
api_key(str): API Key。(可选)
-
ak(Optional[str]): 文心 API AK,默认为 None,即从环境变量 'WENXIN_AK' 中获取;
-
secret_key(str): Secret Key。(可选)
-
sk(Optional[str]): 文心 API SK,默认为 None,即从环境变量 'WENXIN_SK' 中获取。
-
```python
-
```python
def custom_generation(
def custom_generation(
...
@@ -107,9 +112,7 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
...
@@ -107,9 +112,7 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
is_unidirectional: bool = False,
is_unidirectional: bool = False,
min_dec_penalty_text: str = '',
min_dec_penalty_text: str = '',
logits_bias: int = -10000,
logits_bias: int = -10000,
mask_type: str = 'word',
mask_type: str = 'word'
api_key: str = '',
secret_key: str = ''
) -> str
) -> str
```
```
-
自定义文本生成 API
-
自定义文本生成 API
...
@@ -292,6 +295,10 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
...
@@ -292,6 +295,10 @@ ERNIE 3.0 Zeus 是 ERNIE 3.0 系列模型的最新升级。其除了对无标注
初始发布
初始发布
*
1.1.0
移除默认 AK 和 SK
```
shell
```
shell
$
hub
install
ernie_zeus
==
1.
0
.0
$
hub
install
ernie_zeus
==
1.
1
.0
```
```
modules/text/text_generation/ernie_zeus/__init__.py
0 → 100644
浏览文件 @
a900ca05
modules/text/text_generation/ernie_zeus/module.py
浏览文件 @
a900ca05
import
json
import
argparse
import
argparse
import
json
import
os
import
requests
import
requests
from
paddlehub.module.module
import
moduleinfo
,
runnable
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.module
import
runnable
def
get_access_token
(
ak
:
str
=
''
,
sk
:
str
=
''
)
->
str
:
def
get_access_token
(
ak
:
str
=
None
,
sk
:
str
=
None
)
->
str
:
'''
'''
Get Access Token
Get Access Token
...
@@ -16,15 +19,16 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
...
@@ -16,15 +19,16 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
Return:
Return:
access_token(str): Access Token
access_token(str): Access Token
'''
'''
ak
=
ak
if
ak
else
os
.
getenv
(
'WENXIN_AK'
)
sk
=
sk
if
sk
else
os
.
getenv
(
'WENXIN_SK'
)
assert
ak
and
sk
,
RuntimeError
(
'Please go to the wenxin official website to apply for AK and SK and set the parameters “ak” and “sk” correctly, or set the environment variables “WENXIN_AK” and “WENXIN_SK”.'
)
url
=
'https://wenxin.baidu.com/younger/portal/api/oauth/token'
url
=
'https://wenxin.baidu.com/younger/portal/api/oauth/token'
headers
=
{
headers
=
{
'Content-Type'
:
'application/x-www-form-urlencoded'
}
'Content-Type'
:
'application/x-www-form-urlencoded'
datas
=
{
'grant_type'
:
'client_credentials'
,
'client_id'
:
ak
,
'client_secret'
:
sk
}
}
datas
=
{
'grant_type'
:
'client_credentials'
,
'client_id'
:
ak
if
ak
!=
''
else
'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
,
'client_secret'
:
sk
if
sk
!=
''
else
'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
}
responses
=
requests
.
post
(
url
,
datas
,
headers
=
headers
)
responses
=
requests
.
post
(
url
,
datas
,
headers
=
headers
)
...
@@ -37,16 +41,15 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
...
@@ -37,16 +41,15 @@ def get_access_token(ak: str = '', sk: str = '') -> str:
return
results
[
'data'
]
return
results
[
'data'
]
@
moduleinfo
(
@
moduleinfo
(
name
=
'ernie_zeus'
,
name
=
'ernie_zeus'
,
type
=
'nlp/text_generation'
,
type
=
'nlp/text_generation'
,
author
=
'paddlepaddle'
,
author
=
'paddlepaddle'
,
author_email
=
''
,
author_email
=
''
,
summary
=
'ernie_zeus'
,
summary
=
'ernie_zeus'
,
version
=
'1.0.0'
version
=
'1.1.0'
)
)
class
ERNIEZeus
:
class
ERNIEZeus
:
def
__init__
(
self
,
ak
:
str
=
''
,
sk
:
str
=
''
)
->
None
:
def
__init__
(
self
,
ak
:
str
=
None
,
sk
:
str
=
None
)
->
None
:
self
.
access_token
=
get_access_token
(
ak
,
sk
)
self
.
access_token
=
get_access_token
(
ak
,
sk
)
def
custom_generation
(
self
,
def
custom_generation
(
self
,
...
@@ -92,9 +95,7 @@ class ERNIEZeus:
...
@@ -92,9 +95,7 @@ class ERNIEZeus:
'''
'''
url
=
'https://wenxin.baidu.com/moduleApi/portal/api/rest/1.0/ernie/3.0.28/zeus?from=paddlehub'
url
=
'https://wenxin.baidu.com/moduleApi/portal/api/rest/1.0/ernie/3.0.28/zeus?from=paddlehub'
access_token
=
self
.
access_token
access_token
=
self
.
access_token
headers
=
{
headers
=
{
'Content-Type'
:
'application/x-www-form-urlencoded'
}
'Content-Type'
:
'application/x-www-form-urlencoded'
}
datas
=
{
datas
=
{
'access_token'
:
access_token
,
'access_token'
:
access_token
,
'text'
:
text
,
'text'
:
text
,
...
@@ -131,8 +132,7 @@ class ERNIEZeus:
...
@@ -131,8 +132,7 @@ class ERNIEZeus:
'''
'''
文本生成
文本生成
'''
'''
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -144,8 +144,7 @@ class ERNIEZeus:
...
@@ -144,8 +144,7 @@ class ERNIEZeus:
is_unidirectional
=
True
,
is_unidirectional
=
True
,
min_dec_penalty_text
=
'。?:![<S>]'
,
min_dec_penalty_text
=
'。?:![<S>]'
,
logits_bias
=-
10
,
logits_bias
=-
10
,
mask_type
=
'paragraph'
mask_type
=
'paragraph'
)
)
def
text_summarization
(
self
,
def
text_summarization
(
self
,
text
:
str
,
text
:
str
,
...
@@ -157,8 +156,7 @@ class ERNIEZeus:
...
@@ -157,8 +156,7 @@ class ERNIEZeus:
摘要生成
摘要生成
'''
'''
text
=
"文章:{} 摘要:"
.
format
(
text
)
text
=
"文章:{} 摘要:"
.
format
(
text
)
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -170,8 +168,7 @@ class ERNIEZeus:
...
@@ -170,8 +168,7 @@ class ERNIEZeus:
is_unidirectional
=
False
,
is_unidirectional
=
False
,
min_dec_penalty_text
=
''
,
min_dec_penalty_text
=
''
,
logits_bias
=-
10000
,
logits_bias
=-
10000
,
mask_type
=
'word'
mask_type
=
'word'
)
)
def
copywriting_generation
(
self
,
def
copywriting_generation
(
self
,
text
:
str
,
text
:
str
,
...
@@ -183,8 +180,7 @@ class ERNIEZeus:
...
@@ -183,8 +180,7 @@ class ERNIEZeus:
文案生成
文案生成
'''
'''
text
=
"标题:{} 文案:"
.
format
(
text
)
text
=
"标题:{} 文案:"
.
format
(
text
)
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -196,8 +192,7 @@ class ERNIEZeus:
...
@@ -196,8 +192,7 @@ class ERNIEZeus:
is_unidirectional
=
False
,
is_unidirectional
=
False
,
min_dec_penalty_text
=
''
,
min_dec_penalty_text
=
''
,
logits_bias
=-
10000
,
logits_bias
=-
10000
,
mask_type
=
'word'
mask_type
=
'word'
)
)
def
novel_continuation
(
self
,
def
novel_continuation
(
self
,
text
:
str
,
text
:
str
,
...
@@ -209,8 +204,7 @@ class ERNIEZeus:
...
@@ -209,8 +204,7 @@ class ERNIEZeus:
小说续写
小说续写
'''
'''
text
=
"上文:{} 下文:"
.
format
(
text
)
text
=
"上文:{} 下文:"
.
format
(
text
)
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -222,8 +216,7 @@ class ERNIEZeus:
...
@@ -222,8 +216,7 @@ class ERNIEZeus:
is_unidirectional
=
True
,
is_unidirectional
=
True
,
min_dec_penalty_text
=
'。?:![<S>]'
,
min_dec_penalty_text
=
'。?:![<S>]'
,
logits_bias
=-
5
,
logits_bias
=-
5
,
mask_type
=
'paragraph'
mask_type
=
'paragraph'
)
)
def
answer_generation
(
self
,
def
answer_generation
(
self
,
text
:
str
,
text
:
str
,
...
@@ -235,8 +228,7 @@ class ERNIEZeus:
...
@@ -235,8 +228,7 @@ class ERNIEZeus:
自由问答
自由问答
'''
'''
text
=
"问题:{} 回答:"
.
format
(
text
)
text
=
"问题:{} 回答:"
.
format
(
text
)
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -248,8 +240,7 @@ class ERNIEZeus:
...
@@ -248,8 +240,7 @@ class ERNIEZeus:
is_unidirectional
=
True
,
is_unidirectional
=
True
,
min_dec_penalty_text
=
'。?:![<S>]'
,
min_dec_penalty_text
=
'。?:![<S>]'
,
logits_bias
=-
5
,
logits_bias
=-
5
,
mask_type
=
'paragraph'
mask_type
=
'paragraph'
)
)
def
couplet_continuation
(
self
,
def
couplet_continuation
(
self
,
text
:
str
,
text
:
str
,
...
@@ -261,8 +252,7 @@ class ERNIEZeus:
...
@@ -261,8 +252,7 @@ class ERNIEZeus:
对联续写
对联续写
'''
'''
text
=
"上联:{} 下联:"
.
format
(
text
)
text
=
"上联:{} 下联:"
.
format
(
text
)
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -274,8 +264,7 @@ class ERNIEZeus:
...
@@ -274,8 +264,7 @@ class ERNIEZeus:
is_unidirectional
=
False
,
is_unidirectional
=
False
,
min_dec_penalty_text
=
''
,
min_dec_penalty_text
=
''
,
logits_bias
=-
10000
,
logits_bias
=-
10000
,
mask_type
=
'word'
mask_type
=
'word'
)
)
def
composition_generation
(
self
,
def
composition_generation
(
self
,
text
:
str
,
text
:
str
,
...
@@ -287,8 +276,7 @@ class ERNIEZeus:
...
@@ -287,8 +276,7 @@ class ERNIEZeus:
作文创作
作文创作
'''
'''
text
=
"作文题目:{} 正文:"
.
format
(
text
)
text
=
"作文题目:{} 正文:"
.
format
(
text
)
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -300,8 +288,7 @@ class ERNIEZeus:
...
@@ -300,8 +288,7 @@ class ERNIEZeus:
is_unidirectional
=
False
,
is_unidirectional
=
False
,
min_dec_penalty_text
=
''
,
min_dec_penalty_text
=
''
,
logits_bias
=-
10000
,
logits_bias
=-
10000
,
mask_type
=
'word'
mask_type
=
'word'
)
)
def
text_cloze
(
self
,
def
text_cloze
(
self
,
text
:
str
,
text
:
str
,
...
@@ -312,8 +299,7 @@ class ERNIEZeus:
...
@@ -312,8 +299,7 @@ class ERNIEZeus:
'''
'''
完形填空
完形填空
'''
'''
return
self
.
custom_generation
(
return
self
.
custom_generation
(
text
,
text
,
min_dec_len
,
min_dec_len
,
seq_len
,
seq_len
,
topp
,
topp
,
...
@@ -325,13 +311,11 @@ class ERNIEZeus:
...
@@ -325,13 +311,11 @@ class ERNIEZeus:
is_unidirectional
=
False
,
is_unidirectional
=
False
,
min_dec_penalty_text
=
''
,
min_dec_penalty_text
=
''
,
logits_bias
=-
10000
,
logits_bias
=-
10000
,
mask_type
=
'word'
mask_type
=
'word'
)
)
@
runnable
@
runnable
def
cmd
(
self
,
argvs
):
def
cmd
(
self
,
argvs
):
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
"Run the {}"
.
format
(
self
.
name
),
description
=
"Run the {}"
.
format
(
self
.
name
),
prog
=
"hub run {}"
.
format
(
self
.
name
),
prog
=
"hub run {}"
.
format
(
self
.
name
),
usage
=
'%(prog)s'
,
usage
=
'%(prog)s'
,
add_help
=
True
)
add_help
=
True
)
...
@@ -370,12 +354,7 @@ class ERNIEZeus:
...
@@ -370,12 +354,7 @@ class ERNIEZeus:
kwargs
.
pop
(
'min_dec_penalty_text'
)
kwargs
.
pop
(
'min_dec_penalty_text'
)
kwargs
.
pop
(
'logits_bias'
)
kwargs
.
pop
(
'logits_bias'
)
kwargs
.
pop
(
'mask_type'
)
kwargs
.
pop
(
'mask_type'
)
default_kwargs
=
{
default_kwargs
=
{
'min_dec_len'
:
1
,
'seq_len'
:
128
,
'topp'
:
1.0
,
'penalty_score'
:
1.0
}
'min_dec_len'
:
1
,
'seq_len'
:
128
,
'topp'
:
1.0
,
'penalty_score'
:
1.0
}
else
:
else
:
default_kwargs
=
{
default_kwargs
=
{
'min_dec_len'
:
1
,
'min_dec_len'
:
1
,
...
@@ -400,52 +379,3 @@ class ERNIEZeus:
...
@@ -400,52 +379,3 @@ class ERNIEZeus:
kwargs
.
pop
(
k
)
kwargs
.
pop
(
k
)
return
func
(
**
kwargs
)
return
func
(
**
kwargs
)
if
__name__
==
'__main__'
:
ernie_zeus
=
ERNIEZeus
()
result
=
ernie_zeus
.
custom_generation
(
'你好,'
)
print
(
result
)
result
=
ernie_zeus
.
text_generation
(
'给宠物猫起一些可爱的名字。名字:'
)
print
(
result
)
result
=
ernie_zeus
.
text_summarization
(
'在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。'
)
print
(
result
)
result
=
ernie_zeus
.
copywriting_generation
(
'芍药香氛的沐浴乳'
)
print
(
result
)
result
=
ernie_zeus
.
novel_continuation
(
'昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。'
)
print
(
result
)
result
=
ernie_zeus
.
answer_generation
(
'交朋友的原则是什么?'
)
print
(
result
)
result
=
ernie_zeus
.
couplet_continuation
(
'五湖四海皆春色'
)
print
(
result
)
result
=
ernie_zeus
.
composition_generation
(
'诚以养德,信以修身'
)
print
(
result
)
result
=
ernie_zeus
.
text_cloze
(
'她有着一双[MASK]的眼眸。'
)
print
(
result
)
modules/text/text_generation/ernie_zeus/test.py
0 → 100644
浏览文件 @
a900ca05
import
unittest
import
paddlehub
as
hub
class TestHubModule(unittest.TestCase):
    """Smoke tests for the ``ernie_zeus`` PaddleHub module.

    Every test passes a single prompt to one generation method of the
    loaded module and asserts that the returned value is a ``str``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # The module object is created once and shared by all test methods.
        cls.module = hub.Module(name='ernie_zeus')

    def _check_returns_str(self, task: str, prompt: str) -> None:
        """Call ``self.module.<task>(prompt)`` and assert the result is a str."""
        generate = getattr(self.module, task)
        self.assertIsInstance(generate(prompt), str)

    def test_custom_generation(self):
        self._check_returns_str('custom_generation', '你好,')

    def test_text_generation(self):
        self._check_returns_str('text_generation', '给宠物猫起一些可爱的名字。名字:')

    def test_text_summarization(self):
        self._check_returns_str(
            'text_summarization',
            '在芬兰、瑞典提交“入约”申请近一个月来,北约成员国内部尚未对此达成一致意见。与此同时,俄罗斯方面也多次对北约“第六轮扩张”发出警告。据北约官网显示,北约秘书长斯托尔滕贝格将于本月12日至13日出访瑞典和芬兰,并将分别与两国领导人进行会晤。'
        )

    def test_copywriting_generation(self):
        self._check_returns_str('copywriting_generation', '芍药香氛的沐浴乳')

    def test_modulenovel_continuation(self):
        self._check_returns_str(
            'novel_continuation',
            '昆仑山可以说是天下龙脉的根源,所有的山脉都可以看作是昆仑的分支。这些分出来的枝枝杈杈,都可以看作是一条条独立的龙脉。'
        )

    def test_answer_generation(self):
        self._check_returns_str('answer_generation', '交朋友的原则是什么?')

    def test_couplet_continuation(self):
        self._check_returns_str('couplet_continuation', '五湖四海皆春色')

    def test_composition_generation(self):
        self._check_returns_str('composition_generation', '诚以养德,信以修身')

    def test_text_cloze(self):
        self._check_returns_str('text_cloze', '她有着一双[MASK]的眼眸。')


# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录