提交 80c66006 编写于 作者: C chenjian

fix

上级 8a513e59
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
|模型名称|ernie_vilg| |模型名称|ernie_vilg|
| :--- | :---: | | :--- | :---: |
|类别|多模态-文图生成| |类别|图像-文图生成|
|网络|-| |网络|ERNIE-ViLG|
|数据集|-| |数据集|-|
|是否支持Fine-tuning|否| |是否支持Fine-tuning|否|
|模型大小|-| |模型大小|-|
...@@ -58,16 +58,14 @@ ...@@ -58,16 +58,14 @@
module = hub.Module(name="ernie_vilg") module = hub.Module(name="ernie_vilg")
text_prompts = ["宁静的小镇"] text_prompts = ["宁静的小镇"]
images = module.generate_image(text_prompts=text_prompts, style='油画', output_dir='./ernie_vilg_out/') images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/')
``` ```
- ### 3、API - ### 3、API
- ```python - ```python
def generate_image( def generate_image(
text_prompts: Optional[List[str]] = [ text_prompts:str,
"宁静的乡村"
],
style: Optional[str] = "油画", style: Optional[str] = "油画",
output_dir: Optional[str] = 'ernievilg_output') output_dir: Optional[str] = 'ernievilg_output')
``` ```
...@@ -76,13 +74,14 @@ ...@@ -76,13 +74,14 @@
- **参数** - **参数**
- text_prompts(Optional[List[str]]): 输入的语句,描述想要生成的图像的内容。 - text_prompts(str): 输入的语句,描述想要生成的图像的内容。
- style(Optional[str]): 生成图像的风格,当前支持 油画、水彩画、中国画。 - style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。
- topk(Optional[int]): 保存前多少张图,最多保存10张。
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。 - output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
- **返回** - **返回**
- images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式,每个prompt生成10张图像 - images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式。
## 四、更新历史 ## 四、更新历史
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import ast import ast
import os import os
...@@ -34,21 +21,22 @@ from paddlehub.module.module import serving ...@@ -34,21 +21,22 @@ from paddlehub.module.module import serving
@moduleinfo(name="ernie_vilg", @moduleinfo(name="ernie_vilg",
version="1.0.0", version="1.0.0",
type="MultiModal/image_generation", type="image/text_to_image",
summary="", summary="",
author="paddlepaddle", author="baidu-nlp",
author_email="paddle-dev@baidu.com") author_email="paddle-dev@baidu.com")
class ErnieVilG: class ErnieVilG:
def generate_image(self, def generate_image(self,
text_prompts: Optional[List[str]] = ["宁静的乡村"], text_prompts,
style: Optional[str] = "油画", style: Optional[str] = "油画",
topk: Optional[int] = 10,
output_dir: Optional[str] = 'ernievilg_output'): output_dir: Optional[str] = 'ernievilg_output'):
""" """
Create image by text prompts using ErnieVilG model. Create image by text prompts using ErnieVilG model.
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param style: Image style, currently supported 油画、水彩画、中国画 :param style: Image style, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画
:param output_dir: Output directory :param output_dir: Output directory
""" """
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
...@@ -66,10 +54,10 @@ class ErnieVilG: ...@@ -66,10 +54,10 @@ class ErnieVilG:
res = response.json() res = response.json()
if res['code'] != 0: if res['code'] != 0:
print('Request access token error.') print('Request access token error.')
exit(-1) raise RuntimeError("Request access token error.")
else: else:
print('Request access token error.') print('Request access token error.')
exit(-1) raise RuntimeError("Request access token error.")
token = res['data'] token = res['data']
create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img' create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img'
...@@ -77,7 +65,6 @@ class ErnieVilG: ...@@ -77,7 +65,6 @@ class ErnieVilG:
if isinstance(text_prompts, str): if isinstance(text_prompts, str):
text_prompts = [text_prompts] text_prompts = [text_prompts]
taskids = [] taskids = []
error = False
for text_prompt in text_prompts: for text_prompt in text_prompts:
res = requests.post(create_url, res = requests.post(create_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'}, headers={'Content-Type': 'application/x-www-form-urlencoded'},
...@@ -89,18 +76,16 @@ class ErnieVilG: ...@@ -89,18 +76,16 @@ class ErnieVilG:
res = res.json() res = res.json()
if res['code'] == 4001: if res['code'] == 4001:
print('请求参数错误') print('请求参数错误')
error = True raise RuntimeError("请求参数错误")
elif res['code'] == 4002: elif res['code'] == 4002:
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
error = True raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
elif res['code'] == 4003: elif res['code'] == 4003:
print('请求参数中,图片风格不在可选范围内') print('请求参数中,图片风格不在可选范围内')
error = True raise RuntimeError("请求参数中,图片风格不在可选范围内")
elif res['code'] == 4004: elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
error = True raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
if error == True:
exit(-1)
taskids.append(res['data']["taskId"]) taskids.append(res['data']["taskId"])
start_time = time.time() start_time = time.time()
...@@ -122,18 +107,16 @@ class ErnieVilG: ...@@ -122,18 +107,16 @@ class ErnieVilG:
res = res.json() res = res.json()
if res['code'] == 4001: if res['code'] == 4001:
print('请求参数错误') print('请求参数错误')
error = True raise RuntimeError("请求参数错误")
elif res['code'] == 4002: elif res['code'] == 4002:
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
error = True raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
elif res['code'] == 4003: elif res['code'] == 4003:
print('请求参数中,图片风格不在可选范围内') print('请求参数中,图片风格不在可选范围内')
error = True raise RuntimeError("请求参数中,图片风格不在可选范围内")
elif res['code'] == 4004: elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
error = True raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
if error == True:
exit(-1)
if res['data']['status'] == 1: if res['data']['status'] == 1:
has_done.append(res['data']['taskId']) has_done.append(res['data']['taskId'])
results[res['data']['text']] = { results[res['data']['text']] = {
...@@ -161,6 +144,8 @@ class ErnieVilG: ...@@ -161,6 +144,8 @@ class ErnieVilG:
image = Image.open(BytesIO(requests.get(imgdata['image']).content)) image = Image.open(BytesIO(requests.get(imgdata['image']).content))
image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx))) image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx)))
result_images.append(image) result_images.append(image)
if idx + 1 >= topk:
break
print('Done') print('Done')
return result_images return result_images
...@@ -176,13 +161,21 @@ class ErnieVilG: ...@@ -176,13 +161,21 @@ class ErnieVilG:
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.add_module_input_arg() self.add_module_input_arg()
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.generate_image(text_prompts=args.text_prompts, style=args.style, output_dir=args.output_dir) results = self.generate_image(text_prompts=args.text_prompts,
style=args.style,
topk=args.topk,
output_dir=args.output_dir)
return results return results
def add_module_input_arg(self): def add_module_input_arg(self):
""" """
Add the command input options. Add the command input options.
""" """
self.arg_input_group.add_argument('--text_prompts', type=str, default='宁静的小镇') self.arg_input_group.add_argument('--text_prompts', type=str)
self.arg_input_group.add_argument('--style', type=str, default='油画', choices=['油画', '水彩画', '中国画'], help="绘画风格") self.arg_input_group.add_argument('--style',
type=str,
default='油画',
choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'],
help="绘画风格")
self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张")
self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output') self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册