提交 80c66006 编写于 作者: C chenjian

fix

上级 8a513e59
......@@ -2,8 +2,8 @@
|模型名称|ernie_vilg|
| :--- | :---: |
|类别|多模态-文图生成|
|网络|-|
|类别|图像-文图生成|
|网络|ERNIE-ViLG|
|数据集|-|
|是否支持Fine-tuning|否|
|模型大小|-|
......@@ -58,16 +58,14 @@
module = hub.Module(name="ernie_vilg")
text_prompts = ["宁静的小镇"]
images = module.generate_image(text_prompts=text_prompts, style='油画', output_dir='./ernie_vilg_out/')
images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/')
```
- ### 3、API
- ```python
def generate_image(
text_prompts: Optional[List[str]] = [
"宁静的乡村"
],
text_prompts:str,
style: Optional[str] = "油画",
output_dir: Optional[str] = 'ernievilg_output')
```
......@@ -76,13 +74,14 @@
- **参数**
- text_prompts(Optional[List[str]]): 输入的语句,描述想要生成的图像的内容。
- style(Optional[str]): 生成图像的风格,当前支持 油画、水彩画、中国画。
- text_prompts(str): 输入的语句,描述想要生成的图像的内容。
- style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。
- topk(Optional[int]): 保存前多少张图,最多保存10张。
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
- **返回**
- images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式,每个prompt生成10张图像
- images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式。
## 四、更新历史
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
......@@ -34,21 +21,22 @@ from paddlehub.module.module import serving
@moduleinfo(name="ernie_vilg",
version="1.0.0",
type="MultiModal/image_generation",
type="image/text_to_image",
summary="",
author="paddlepaddle",
author="baidu-nlp",
author_email="paddle-dev@baidu.com")
class ErnieVilG:
def generate_image(self,
text_prompts: Optional[List[str]] = ["宁静的乡村"],
text_prompts,
style: Optional[str] = "油画",
topk: Optional[int] = 10,
output_dir: Optional[str] = 'ernievilg_output'):
"""
Create image by text prompts using ErnieVilG model.
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param style: Image stype, currently supported 油画、水彩画、中国
:param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔
:output_dir: Output directory
"""
if not os.path.exists(output_dir):
......@@ -66,10 +54,10 @@ class ErnieVilG:
res = response.json()
if res['code'] != 0:
print('Request access token error.')
exit(-1)
raise RuntimeError("Request access token error.")
else:
print('Request access token error.')
exit(-1)
raise RuntimeError("Request access token error.")
token = res['data']
create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img'
......@@ -77,7 +65,6 @@ class ErnieVilG:
if isinstance(text_prompts, str):
text_prompts = [text_prompts]
taskids = []
error = False
for text_prompt in text_prompts:
res = requests.post(create_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
......@@ -89,18 +76,16 @@ class ErnieVilG:
res = res.json()
if res['code'] == 4001:
print('请求参数错误')
error = True
raise RuntimeError("请求参数错误")
elif res['code'] == 4002:
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
error = True
raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
elif res['code'] == 4003:
print('请求参数中,图片风格不在可选范围内')
error = True
raise RuntimeError("请求参数中,图片风格不在可选范围内")
elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
error = True
if error == True:
exit(-1)
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
taskids.append(res['data']["taskId"])
start_time = time.time()
......@@ -122,18 +107,16 @@ class ErnieVilG:
res = res.json()
if res['code'] == 4001:
print('请求参数错误')
error = True
raise RuntimeError("请求参数错误")
elif res['code'] == 4002:
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
error = True
raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
elif res['code'] == 4003:
print('请求参数中,图片风格不在可选范围内')
error = True
raise RuntimeError("请求参数中,图片风格不在可选范围内")
elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
error = True
if error == True:
exit(-1)
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
if res['data']['status'] == 1:
has_done.append(res['data']['taskId'])
results[res['data']['text']] = {
......@@ -161,6 +144,8 @@ class ErnieVilG:
image = Image.open(BytesIO(requests.get(imgdata['image']).content))
image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx)))
result_images.append(image)
if idx + 1 >= topk:
break
print('Done')
return result_images
......@@ -176,13 +161,21 @@ class ErnieVilG:
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.generate_image(text_prompts=args.text_prompts, style=args.style, output_dir=args.output_dir)
results = self.generate_image(text_prompts=args.text_prompts,
style=args.style,
topk=args.topk,
output_dir=args.output_dir)
return results
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--text_prompts', type=str, default='宁静的小镇')
self.arg_input_group.add_argument('--style', type=str, default='油画', choices=['油画', '水彩画', '中国画'], help="绘画风格")
self.arg_input_group.add_argument('--text_prompts', type=str)
self.arg_input_group.add_argument('--style',
type=str,
default='油画',
choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'],
help="绘画风格")
self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张")
self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册