Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
80c66006
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
80c66006
编写于
8月 16, 2022
作者:
C
chenjian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix
上级
8a513e59
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
36 addition
and
44 deletion
+36
-44
modules/image/text_to_image/ernie_vilg/README.md
modules/image/text_to_image/ernie_vilg/README.md
+8
-9
modules/image/text_to_image/ernie_vilg/module.py
modules/image/text_to_image/ernie_vilg/module.py
+28
-35
未找到文件。
modules/image/text_to_image/ernie_vilg/README.md
浏览文件 @
80c66006
...
@@ -2,8 +2,8 @@
...
@@ -2,8 +2,8 @@
|模型名称|ernie_vilg|
|模型名称|ernie_vilg|
| :--- | :---: |
| :--- | :---: |
|类别|
多模态
-文图生成|
|类别|
图像
-文图生成|
|网络|
-
|
|网络|
ERNIE-ViLG
|
|数据集|-|
|数据集|-|
|是否支持Fine-tuning|否|
|是否支持Fine-tuning|否|
|模型大小|-|
|模型大小|-|
...
@@ -58,16 +58,14 @@
...
@@ -58,16 +58,14 @@
module = hub.Module(name="ernie_vilg")
module = hub.Module(name="ernie_vilg")
text_prompts = ["宁静的小镇"]
text_prompts = ["宁静的小镇"]
images = module.generate_image(text_prompts=text_prompts,
style='油画',
output_dir='./ernie_vilg_out/')
images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/')
```
```
-
### 3、API
-
### 3、API
-
```python
-
```python
def generate_image(
def generate_image(
text_prompts: Optional[List[str]] = [
text_prompts:str,
"宁静的乡村"
],
style: Optional[str] = "油画",
style: Optional[str] = "油画",
output_dir: Optional[str] = 'ernievilg_output')
output_dir: Optional[str] = 'ernievilg_output')
```
```
...
@@ -76,13 +74,14 @@
...
@@ -76,13 +74,14 @@
- **参数**
- **参数**
- text_prompts(Optional[List[str]]): 输入的语句,描述想要生成的图像的内容。
- text_prompts(str): 输入的语句,描述想要生成的图像的内容。
- style(Optional[str]): 生成图像的风格,当前支持 油画、水彩画、中国画。
- style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。
- topk(Optional[int]): 保存前多少张图,最多保存10张。
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
- **返回**
- **返回**
- images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式
,每个prompt生成10张图像
。
- images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式。
## 四、更新历史
## 四、更新历史
...
...
modules/image/text_to_image/ernie_vilg/module.py
浏览文件 @
80c66006
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
argparse
import
ast
import
ast
import
os
import
os
...
@@ -34,21 +21,22 @@ from paddlehub.module.module import serving
...
@@ -34,21 +21,22 @@ from paddlehub.module.module import serving
@
moduleinfo
(
name
=
"ernie_vilg"
,
@
moduleinfo
(
name
=
"ernie_vilg"
,
version
=
"1.0.0"
,
version
=
"1.0.0"
,
type
=
"
MultiModal/image_generation
"
,
type
=
"
image/text_to_image
"
,
summary
=
""
,
summary
=
""
,
author
=
"
paddlepaddle
"
,
author
=
"
baidu-nlp
"
,
author_email
=
"paddle-dev@baidu.com"
)
author_email
=
"paddle-dev@baidu.com"
)
class
ErnieVilG
:
class
ErnieVilG
:
def
generate_image
(
self
,
def
generate_image
(
self
,
text_prompts
:
Optional
[
List
[
str
]]
=
[
"宁静的乡村"
]
,
text_prompts
,
style
:
Optional
[
str
]
=
"油画"
,
style
:
Optional
[
str
]
=
"油画"
,
topk
:
Optional
[
int
]
=
10
,
output_dir
:
Optional
[
str
]
=
'ernievilg_output'
):
output_dir
:
Optional
[
str
]
=
'ernievilg_output'
):
"""
"""
Create image by text prompts using ErnieVilG model.
Create image by text prompts using ErnieVilG model.
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param style: Image stype, currently supported 油画、水彩
画、中国
画
:param style: Image stype, currently supported 油画、水彩
、粉笔画、卡通、儿童画、蜡笔
画
:output_dir: Output directory
:output_dir: Output directory
"""
"""
if
not
os
.
path
.
exists
(
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
...
@@ -66,10 +54,10 @@ class ErnieVilG:
...
@@ -66,10 +54,10 @@ class ErnieVilG:
res
=
response
.
json
()
res
=
response
.
json
()
if
res
[
'code'
]
!=
0
:
if
res
[
'code'
]
!=
0
:
print
(
'Request access token error.'
)
print
(
'Request access token error.'
)
exit
(
-
1
)
raise
RuntimeError
(
"Request access token error."
)
else
:
else
:
print
(
'Request access token error.'
)
print
(
'Request access token error.'
)
exit
(
-
1
)
raise
RuntimeError
(
"Request access token error."
)
token
=
res
[
'data'
]
token
=
res
[
'data'
]
create_url
=
'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img'
create_url
=
'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img'
...
@@ -77,7 +65,6 @@ class ErnieVilG:
...
@@ -77,7 +65,6 @@ class ErnieVilG:
if
isinstance
(
text_prompts
,
str
):
if
isinstance
(
text_prompts
,
str
):
text_prompts
=
[
text_prompts
]
text_prompts
=
[
text_prompts
]
taskids
=
[]
taskids
=
[]
error
=
False
for
text_prompt
in
text_prompts
:
for
text_prompt
in
text_prompts
:
res
=
requests
.
post
(
create_url
,
res
=
requests
.
post
(
create_url
,
headers
=
{
'Content-Type'
:
'application/x-www-form-urlencoded'
},
headers
=
{
'Content-Type'
:
'application/x-www-form-urlencoded'
},
...
@@ -89,18 +76,16 @@ class ErnieVilG:
...
@@ -89,18 +76,16 @@ class ErnieVilG:
res
=
res
.
json
()
res
=
res
.
json
()
if
res
[
'code'
]
==
4001
:
if
res
[
'code'
]
==
4001
:
print
(
'请求参数错误'
)
print
(
'请求参数错误'
)
error
=
True
raise
RuntimeError
(
"请求参数错误"
)
elif
res
[
'code'
]
==
4002
:
elif
res
[
'code'
]
==
4002
:
print
(
'请求参数格式错误,请检查必传参数是否齐全,参数类型等'
)
print
(
'请求参数格式错误,请检查必传参数是否齐全,参数类型等'
)
error
=
True
raise
RuntimeError
(
"请求参数格式错误,请检查必传参数是否齐全,参数类型等"
)
elif
res
[
'code'
]
==
4003
:
elif
res
[
'code'
]
==
4003
:
print
(
'请求参数中,图片风格不在可选范围内'
)
print
(
'请求参数中,图片风格不在可选范围内'
)
error
=
True
raise
RuntimeError
(
"请求参数中,图片风格不在可选范围内"
)
elif
res
[
'code'
]
==
4004
:
elif
res
[
'code'
]
==
4004
:
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
error
=
True
raise
RuntimeError
(
"API服务内部错误,可能引起原因有请求超时、模型推理错误等"
)
if
error
==
True
:
exit
(
-
1
)
taskids
.
append
(
res
[
'data'
][
"taskId"
])
taskids
.
append
(
res
[
'data'
][
"taskId"
])
start_time
=
time
.
time
()
start_time
=
time
.
time
()
...
@@ -122,18 +107,16 @@ class ErnieVilG:
...
@@ -122,18 +107,16 @@ class ErnieVilG:
res
=
res
.
json
()
res
=
res
.
json
()
if
res
[
'code'
]
==
4001
:
if
res
[
'code'
]
==
4001
:
print
(
'请求参数错误'
)
print
(
'请求参数错误'
)
error
=
True
raise
RuntimeError
(
"请求参数错误"
)
elif
res
[
'code'
]
==
4002
:
elif
res
[
'code'
]
==
4002
:
print
(
'请求参数格式错误,请检查必传参数是否齐全,参数类型等'
)
print
(
'请求参数格式错误,请检查必传参数是否齐全,参数类型等'
)
error
=
True
raise
RuntimeError
(
"请求参数格式错误,请检查必传参数是否齐全,参数类型等"
)
elif
res
[
'code'
]
==
4003
:
elif
res
[
'code'
]
==
4003
:
print
(
'请求参数中,图片风格不在可选范围内'
)
print
(
'请求参数中,图片风格不在可选范围内'
)
error
=
True
raise
RuntimeError
(
"请求参数中,图片风格不在可选范围内"
)
elif
res
[
'code'
]
==
4004
:
elif
res
[
'code'
]
==
4004
:
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
print
(
'API服务内部错误,可能引起原因有请求超时、模型推理错误等'
)
error
=
True
raise
RuntimeError
(
"API服务内部错误,可能引起原因有请求超时、模型推理错误等"
)
if
error
==
True
:
exit
(
-
1
)
if
res
[
'data'
][
'status'
]
==
1
:
if
res
[
'data'
][
'status'
]
==
1
:
has_done
.
append
(
res
[
'data'
][
'taskId'
])
has_done
.
append
(
res
[
'data'
][
'taskId'
])
results
[
res
[
'data'
][
'text'
]]
=
{
results
[
res
[
'data'
][
'text'
]]
=
{
...
@@ -161,6 +144,8 @@ class ErnieVilG:
...
@@ -161,6 +144,8 @@ class ErnieVilG:
image
=
Image
.
open
(
BytesIO
(
requests
.
get
(
imgdata
[
'image'
]).
content
))
image
=
Image
.
open
(
BytesIO
(
requests
.
get
(
imgdata
[
'image'
]).
content
))
image
.
save
(
os
.
path
.
join
(
output_dir
,
'{}_{}.png'
.
format
(
text
,
idx
)))
image
.
save
(
os
.
path
.
join
(
output_dir
,
'{}_{}.png'
.
format
(
text
,
idx
)))
result_images
.
append
(
image
)
result_images
.
append
(
image
)
if
idx
+
1
>=
topk
:
break
print
(
'Done'
)
print
(
'Done'
)
return
result_images
return
result_images
...
@@ -176,13 +161,21 @@ class ErnieVilG:
...
@@ -176,13 +161,21 @@ class ErnieVilG:
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
add_module_input_arg
()
self
.
add_module_input_arg
()
args
=
self
.
parser
.
parse_args
(
argvs
)
args
=
self
.
parser
.
parse_args
(
argvs
)
results
=
self
.
generate_image
(
text_prompts
=
args
.
text_prompts
,
style
=
args
.
style
,
output_dir
=
args
.
output_dir
)
results
=
self
.
generate_image
(
text_prompts
=
args
.
text_prompts
,
style
=
args
.
style
,
topk
=
args
.
topk
,
output_dir
=
args
.
output_dir
)
return
results
return
results
def
add_module_input_arg
(
self
):
def
add_module_input_arg
(
self
):
"""
"""
Add the command input options.
Add the command input options.
"""
"""
self
.
arg_input_group
.
add_argument
(
'--text_prompts'
,
type
=
str
,
default
=
'宁静的小镇'
)
self
.
arg_input_group
.
add_argument
(
'--text_prompts'
,
type
=
str
)
self
.
arg_input_group
.
add_argument
(
'--style'
,
type
=
str
,
default
=
'油画'
,
choices
=
[
'油画'
,
'水彩画'
,
'中国画'
],
help
=
"绘画风格"
)
self
.
arg_input_group
.
add_argument
(
'--style'
,
type
=
str
,
default
=
'油画'
,
choices
=
[
'油画'
,
'水彩'
,
'粉笔画'
,
'卡通'
,
'儿童画'
,
'蜡笔画'
],
help
=
"绘画风格"
)
self
.
arg_input_group
.
add_argument
(
'--topk'
,
type
=
int
,
default
=
10
,
help
=
"选取保存前多少张图,最多10张"
)
self
.
arg_input_group
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'ernievilg_output'
)
self
.
arg_input_group
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'ernievilg_output'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录