From 80c6600677df2e9c0ffc98e1545e41fd37271564 Mon Sep 17 00:00:00 2001 From: chenjian Date: Tue, 16 Aug 2022 05:01:03 +0000 Subject: [PATCH] fix --- .../image/text_to_image/ernie_vilg/README.md | 17 +++-- .../image/text_to_image/ernie_vilg/module.py | 63 +++++++++---------- 2 files changed, 36 insertions(+), 44 deletions(-) diff --git a/modules/image/text_to_image/ernie_vilg/README.md b/modules/image/text_to_image/ernie_vilg/README.md index 34d6c131..937b971e 100644 --- a/modules/image/text_to_image/ernie_vilg/README.md +++ b/modules/image/text_to_image/ernie_vilg/README.md @@ -2,8 +2,8 @@ |模型名称|ernie_vilg| | :--- | :---: | -|类别|多模态-文图生成| -|网络|-| +|类别|图像-文图生成| +|网络|ERNIE-ViLG| |数据集|-| |是否支持Fine-tuning|否| |模型大小|-| @@ -58,16 +58,14 @@ module = hub.Module(name="ernie_vilg") text_prompts = ["宁静的小镇"] - images = module.generate_image(text_prompts=text_prompts, style='油画', output_dir='./ernie_vilg_out/') + images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/') ``` - ### 3、API - ```python def generate_image( - text_prompts: Optional[List[str]] = [ - "宁静的乡村" - ], + text_prompts:str, style: Optional[str] = "油画", output_dir: Optional[str] = 'ernievilg_output') ``` @@ -76,13 +74,14 @@ - **参数** - - text_prompts(Optional[List[str]]): 输入的语句,描述想要生成的图像的内容。 - - style(Optional[str]): 生成图像的风格,当前支持 油画、水彩画、中国画。 + - text_prompts(str): 输入的语句,描述想要生成的图像的内容。 + - style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。 + - topk(Optional[int]): 保存前多少张图,最多保存10张。 - output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。 - **返回** - - images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式,每个prompt生成10张图像。 + - images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式。 ## 四、更新历史 diff --git a/modules/image/text_to_image/ernie_vilg/module.py b/modules/image/text_to_image/ernie_vilg/module.py index a1b613d8..7fc56390 100644 --- a/modules/image/text_to_image/ernie_vilg/module.py +++ b/modules/image/text_to_image/ernie_vilg/module.py @@ -1,16 +1,3 @@ -# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import argparse import ast import os @@ -34,21 +21,22 @@ from paddlehub.module.module import serving @moduleinfo(name="ernie_vilg", version="1.0.0", - type="MultiModal/image_generation", + type="image/text_to_image", summary="", - author="paddlepaddle", + author="baidu-nlp", author_email="paddle-dev@baidu.com") class ErnieVilG: def generate_image(self, - text_prompts: Optional[List[str]] = ["宁静的乡村"], + text_prompts, style: Optional[str] = "油画", + topk: Optional[int] = 10, output_dir: Optional[str] = 'ernievilg_output'): """ Create image by text prompts using ErnieVilG model. :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. - :param style: Image stype, currently supported 油画、水彩画、中国画 + :param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画 :output_dir: Output directory """ if not os.path.exists(output_dir): @@ -66,10 +54,10 @@ class ErnieVilG: res = response.json() if res['code'] != 0: print('Request access token error.') - exit(-1) + raise RuntimeError("Request access token error.") else: print('Request access token error.') - exit(-1) + raise RuntimeError("Request access token error.") token = res['data'] create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img' @@ -77,7 +65,6 @@ class ErnieVilG: if isinstance(text_prompts, str): text_prompts = [text_prompts] taskids = [] - error = False for text_prompt in text_prompts: res = requests.post(create_url, headers={'Content-Type': 'application/x-www-form-urlencoded'}, @@ -89,18 +76,16 @@ class ErnieVilG: res = res.json() if res['code'] == 4001: print('请求参数错误') - error = True + raise RuntimeError("请求参数错误") elif res['code'] == 4002: print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') - error = True + raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") elif res['code'] == 4003: print('请求参数中,图片风格不在可选范围内') - error = True + raise RuntimeError("请求参数中,图片风格不在可选范围内") elif res['code'] == 4004: print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') - error = True - if error == True: - exit(-1) + raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") taskids.append(res['data']["taskId"]) start_time = time.time() @@ -122,18 +107,16 @@ class ErnieVilG: res = res.json() if res['code'] == 4001: print('请求参数错误') - error = True + raise RuntimeError("请求参数错误") elif res['code'] == 4002: print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') - error = True + raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") elif res['code'] == 4003: print('请求参数中,图片风格不在可选范围内') - error = True + raise RuntimeError("请求参数中,图片风格不在可选范围内") elif res['code'] == 4004: print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') - error = True - if error == True: - exit(-1) + raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") if res['data']['status'] == 1: has_done.append(res['data']['taskId']) results[res['data']['text']] = { @@ -161,6 +144,8 @@ class ErnieVilG: image = Image.open(BytesIO(requests.get(imgdata['image']).content)) image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx))) result_images.append(image) + if idx + 1 >= topk: + break print('Done') return result_images @@ -176,13 +161,21 @@ class ErnieVilG: self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.add_module_input_arg() args = self.parser.parse_args(argvs) - results = self.generate_image(text_prompts=args.text_prompts, style=args.style, output_dir=args.output_dir) + results = self.generate_image(text_prompts=args.text_prompts, + style=args.style, + topk=args.topk, + output_dir=args.output_dir) return results def add_module_input_arg(self): """ Add the command input options. """ - self.arg_input_group.add_argument('--text_prompts', type=str, default='宁静的小镇') - self.arg_input_group.add_argument('--style', type=str, default='油画', choices=['油画', '水彩画', '中国画'], help="绘画风格") + self.arg_input_group.add_argument('--text_prompts', type=str) + self.arg_input_group.add_argument('--style', + type=str, + default='油画', + choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'], + help="绘画风格") + self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张") self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output') -- GitLab