Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
2aeceff2
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2aeceff2
编写于
8月 17, 2022
作者:
C
chenjian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix
上级
6802506b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
56 addition
and
20 deletion
+56
-20
modules/image/text_to_image/ernie_vilg/module.py
modules/image/text_to_image/ernie_vilg/module.py
+56
-20
未找到文件。
modules/image/text_to_image/ernie_vilg/module.py
浏览文件 @
2aeceff2
...
@@ -27,6 +27,32 @@ from paddlehub.module.module import serving
...
@@ -27,6 +27,32 @@ from paddlehub.module.module import serving
author_email
=
"paddle-dev@baidu.com"
)
author_email
=
"paddle-dev@baidu.com"
)
class
ErnieVilG
:
class
ErnieVilG
:
def
__init__
(
self
):
self
.
ak
=
'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
self
.
sk
=
'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
self
.
token_host
=
'https://wenxin.baidu.com/younger/portal/api/oauth/token'
self
.
token
=
self
.
_apply_token
(
self
.
ak
,
self
.
sk
)
def
_apply_token
(
self
,
ak
,
sk
):
if
ak
is
None
or
sk
is
None
:
ak
=
self
.
ak
sk
=
self
.
sk
response
=
requests
.
get
(
self
.
token_host
,
params
=
{
'grant_type'
:
'client_credentials'
,
'client_id'
:
ak
,
'client_secret'
:
sk
})
if
response
:
res
=
response
.
json
()
if
res
[
'code'
]
!=
0
:
print
(
'Request access token error.'
)
raise
RuntimeError
(
"Request access token error."
)
else
:
print
(
'Request access token error.'
)
raise
RuntimeError
(
"Request access token error."
)
return
res
[
'data'
]
def
generate_image
(
self
,
def
generate_image
(
self
,
text_prompts
,
text_prompts
,
style
:
Optional
[
str
]
=
"油画"
,
style
:
Optional
[
str
]
=
"油画"
,
...
@@ -46,27 +72,10 @@ class ErnieVilG:
...
@@ -46,27 +72,10 @@ class ErnieVilG:
"""
"""
if
not
os
.
path
.
exists
(
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
if
ak
==
None
:
token
=
self
.
token
ak
=
'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
if
ak
is
not
None
and
sk
is
not
None
:
if
sk
==
None
:
token
=
self
.
_apply_token
(
ak
,
sk
)
sk
=
'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
token_host
=
'https://wenxin.baidu.com/younger/portal/api/oauth/token'
response
=
requests
.
get
(
token_host
,
params
=
{
'grant_type'
:
'client_credentials'
,
'client_id'
:
ak
,
'client_secret'
:
sk
})
if
response
:
res
=
response
.
json
()
if
res
[
'code'
]
!=
0
:
print
(
'Request access token error.'
)
raise
RuntimeError
(
"Request access token error."
)
else
:
print
(
'Request access token error.'
)
raise
RuntimeError
(
"Request access token error."
)
token
=
res
[
'data'
]
create_url
=
'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub'
create_url
=
'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub'
get_url
=
'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub'
get_url
=
'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub'
if
isinstance
(
text_prompts
,
str
):
if
isinstance
(
text_prompts
,
str
):
...
@@ -93,6 +102,20 @@ class ErnieVilG:
...
@@ -93,6 +102,20 @@ class ErnieVilG:
elif
res
[
'code'
]
==
4004
:
elif
res
[
'code'
]
==
4004
:
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
raise
RuntimeError
(
"API服务内部错误,可能引起原因有请求超时、模型推理错误等"
)
raise
RuntimeError
(
"API服务内部错误,可能引起原因有请求超时、模型推理错误等"
)
elif
res
[
'code'
]
==
100
or
res
[
'code'
]
==
110
or
res
[
'code'
]
==
111
:
token
=
self
.
_apply_token
(
ak
,
sk
)
res
=
requests
.
post
(
create_url
,
headers
=
{
'Content-Type'
:
'application/x-www-form-urlencoded'
},
data
=
{
'access_token'
:
token
,
"text"
:
text_prompt
,
"style"
:
style
})
res
=
res
.
json
()
if
res
[
'code'
]
!=
0
:
print
(
"Token失效重新请求后依然发生错误,请检查输入的参数"
)
raise
RuntimeError
(
"Token失效重新请求后依然发生错误,请检查输入的参数"
)
taskids
.
append
(
res
[
'data'
][
"taskId"
])
taskids
.
append
(
res
[
'data'
][
"taskId"
])
start_time
=
time
.
time
()
start_time
=
time
.
time
()
...
@@ -124,6 +147,19 @@ class ErnieVilG:
...
@@ -124,6 +147,19 @@ class ErnieVilG:
elif
res
[
'code'
]
==
4004
:
elif
res
[
'code'
]
==
4004
:
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
raise
RuntimeError
(
"API服务内部错误,可能引起原因有请求超时、模型推理错误等"
)
raise
RuntimeError
(
"API服务内部错误,可能引起原因有请求超时、模型推理错误等"
)
elif
res
[
'code'
]
==
100
or
res
[
'code'
]
==
110
or
res
[
'code'
]
==
111
:
token
=
self
.
_apply_token
(
ak
,
sk
)
res
=
requests
.
post
(
create_url
,
headers
=
{
'Content-Type'
:
'application/x-www-form-urlencoded'
},
data
=
{
'access_token'
:
token
,
"text"
:
text_prompt
,
"style"
:
style
})
res
=
res
.
json
()
if
res
[
'code'
]
!=
0
:
print
(
"Token失效重新请求后依然发生错误,请检查输入的参数"
)
raise
RuntimeError
(
"Token失效重新请求后依然发生错误,请检查输入的参数"
)
if
res
[
'data'
][
'status'
]
==
1
:
if
res
[
'data'
][
'status'
]
==
1
:
has_done
.
append
(
res
[
'data'
][
'taskId'
])
has_done
.
append
(
res
[
'data'
][
'taskId'
])
results
[
res
[
'data'
][
'text'
]]
=
{
results
[
res
[
'data'
][
'text'
]]
=
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录