diff --git a/modules/image/text_to_image/ernie_vilg/README.md b/modules/image/text_to_image/ernie_vilg/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2ef41dfdc943b01aee11805678a3edd64c3c9287 --- /dev/null +++ b/modules/image/text_to_image/ernie_vilg/README.md @@ -0,0 +1,95 @@ +# ernie_vilg + +|模型名称|ernie_vilg| +| :--- | :---: | +|类别|多模态-文图生成| +|网络|-| +|数据集|-| +|是否支持Fine-tuning|否| +|模型大小|-| +|最新更新日期|2022-08-02| +|数据指标|-| + +## 一、模型基本信息 + +### 应用效果展示 + + - 输入文本 "宁静的小镇" 风格 "油画" + + - 输出图像 +
+
+
+
+
+### 模型介绍
+
+文心ERNIE-ViLG参数规模达到100亿,是目前为止全球最大规模中文跨模态生成模型,在文本生成图像、图像描述等跨模态生成任务上效果全球领先,在图文生成领域MS-COCO、COCO-CN、AIC-ICC等数据集上取得最好效果。你可以输入一段文本描述以及生成风格,模型就会根据输入的内容自动创作出符合要求的图像。
+
+## 二、安装
+
+- ### 1、环境依赖
+
+ - paddlepaddle >= 2.0.0
+
+ - paddlehub >= 2.2.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+
+- ### 2、安装
+
+ - ```shell
+ $ hub install ernie_vilg
+ ```
+ - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+ | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## 三、模型API预测
+
+- ### 1、命令行预测
+
+ - ```shell
+ $ hub run ernie_vilg --text_prompts "宁静的小镇" --output_dir ernie_vilg_out
+ ```
+
+- ### 2、预测代码示例
+
+ - ```python
+ import paddlehub as hub
+
+ module = hub.Module(name="ernie_vilg")
+ text_prompts = ["宁静的小镇"]
+ images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/')
+ ```
+
+- ### 3、API
+
+ - ```python
+ def generate_image(
+ text_prompts: Optional[List[str]] = [
+ "宁静的乡村"
+ ],
+ style: Optional[str] = "油画",
+ output_dir: Optional[str] = 'ernievilg_output')
+ ```
+
+ - 文图生成API,生成文本描述内容的图像。
+
+ - **参数**
+
+ - text_prompts(Optional[List[str]]): 输入的语句,描述想要生成的图像的内容。
+ - style(Optional[str]): 生成图像的风格,当前支持 油画、水彩画、中国画。
+ - output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
+
+
+ - **返回**
+ - images(List[PIL.Image]): 返回生成的所有图像列表,PIL的Image格式,每个prompt生成10张图像。
+
+## 四、更新历史
+
+* 1.0.0
+
+ 初始发布
+
+ ```shell
+ $ hub install ernie_vilg==1.0.0
+ ```
diff --git a/modules/image/text_to_image/ernie_vilg/module.py b/modules/image/text_to_image/ernie_vilg/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1b613d854e9e2e5a2a6666c364fc3a7d82b5eba
--- /dev/null
+++ b/modules/image/text_to_image/ernie_vilg/module.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import ast
+import os
+import re
+import sys
+import time
+from functools import partial
+from io import BytesIO
+from typing import List
+from typing import Optional
+
+import requests
+from PIL import Image
+from tqdm.auto import tqdm
+
+import paddlehub as hub
+from paddlehub.module.module import moduleinfo
+from paddlehub.module.module import runnable
+from paddlehub.module.module import serving
+
+
+@moduleinfo(name="ernie_vilg",
+            version="1.0.0",
+            type="MultiModal/image_generation",
+            summary="",
+            author="paddlepaddle",
+            author_email="paddle-dev@baidu.com")
+class ErnieVilG:
+    """Client for the hosted ERNIE-ViLG text-to-image web API."""
+
+    # NOTE(review): these credentials are committed in source and shared by every
+    # user of the module; consider loading them from configuration instead.
+    _AK = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
+    _SK = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
+    _TOKEN_URL = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
+    _CREATE_URL = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img'
+    _QUERY_URL = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg'
+    _FORM_HEADERS = {'Content-Type': 'application/x-www-form-urlencoded'}
+    # Seconds to wait between two polling rounds.
+    _POLL_INTERVAL = 5
+    # Messages printed for the service error codes this module recognises.
+    _API_ERRORS = {
+        4001: '请求参数错误',
+        4002: '请求参数格式错误,请检查必传参数是否齐全,参数类型等',
+        4003: '请求参数中,图片风格不在可选范围内',
+        4004: 'API服务内部错误,可能引起原因有请求超时、模型推理错误等',
+    }
+
+    def generate_image(self,
+                       text_prompts: Optional[List[str]] = ["宁静的乡村"],
+                       style: Optional[str] = "油画",
+                       output_dir: Optional[str] = 'ernievilg_output'):
+        """
+        Create images from text prompts using the ErnieVilG web service.
+
+        :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
+            A single string is accepted and treated as one prompt.
+        :param style: Image style, currently supported 油画、水彩画、中国画
+        :param output_dir: Output directory, created if it does not exist
+        :return: List of the downloaded images as PIL.Image objects
+        :raises SystemExit: Status -1 when the token request fails or the service returns a known error code
+        """
+        os.makedirs(output_dir, exist_ok=True)
+        token = self._request_token()
+        if isinstance(text_prompts, str):
+            text_prompts = [text_prompts]
+
+        # Submit one generation task per prompt.
+        taskids = []
+        for text_prompt in text_prompts:
+            res = requests.post(self._CREATE_URL,
+                                headers=self._FORM_HEADERS,
+                                data={
+                                    'access_token': token,
+                                    "text": text_prompt,
+                                    "style": style
+                                }).json()
+            self._exit_on_api_error(res)
+            taskids.append(res['data']["taskId"])
+
+        # Poll until every task reports status 1, driving the progress bar from the
+        # elapsed time relative to the waiting time reported by the service.
+        start_time = time.time()
+        process_bar = tqdm(total=100, unit='%')
+        results = {}
+        while taskids:
+            total_time = 0
+            has_done = []
+            for taskid in taskids:
+                res = requests.post(self._QUERY_URL,
+                                    headers=self._FORM_HEADERS,
+                                    data={
+                                        'access_token': token,
+                                        'taskId': taskid
+                                    }).json()
+                self._exit_on_api_error(res)
+                data = res['data']
+                if data['status'] == 1:
+                    has_done.append(taskid)
+                results[data['text']] = {
+                    'imgUrls': data['imgUrls'],
+                    'waiting': data['waiting'],
+                    'taskId': data['taskId']
+                }
+                # Use the longest estimate so one fast task does not skew the bar.
+                total_time = max(total_time, self._waiting_seconds(data['waiting']))
+            elapsed = time.time() - start_time
+            progress_rate = int(elapsed / total_time * 100) if total_time != 0 else 100
+            # The bar only moves forward and is capped at its total of 100.
+            process_bar.update(max(min(progress_rate, 100) - process_bar.n, 0))
+            taskids = [taskid for taskid in taskids if taskid not in has_done]
+            if taskids:
+                time.sleep(self._POLL_INTERVAL)
+        process_bar.close()
+
+        print('Saving Images...')
+        result_images = []
+        for text, data in results.items():
+            # The prompt becomes part of the file name; keep path separators out of it.
+            file_stem = re.sub(r'[\\/]', '_', text)
+            for idx, imgdata in enumerate(data['imgUrls']):
+                image = Image.open(BytesIO(requests.get(imgdata['image']).content))
+                image.save(os.path.join(output_dir, '{}_{}.png'.format(file_stem, idx)))
+                result_images.append(image)
+        print('Done')
+        return result_images
+
+    def _request_token(self):
+        """Fetch an access token with the module credentials; exit with status -1 on failure."""
+        response = requests.get(self._TOKEN_URL,
+                                params={
+                                    'grant_type': 'client_credentials',
+                                    'client_id': self._AK,
+                                    'client_secret': self._SK
+                                })
+        # A Response is falsy for HTTP error statuses; a non-zero 'code' is an application-level failure.
+        res = response.json() if response else None
+        if res is None or res['code'] != 0:
+            print('Request access token error.')
+            sys.exit(-1)
+        return res['data']
+
+    @classmethod
+    def _exit_on_api_error(cls, res):
+        """Print the message for a recognised service error code and exit with status -1."""
+        message = cls._API_ERRORS.get(res['code'])
+        if message is not None:
+            print(message)
+            sys.exit(-1)
+
+    @staticmethod
+    def _waiting_seconds(waiting):
+        """Read the leading number of the 'waiting' field as minutes and return it in seconds (0 if absent)."""
+        match = re.match(r'[0-9]+', str(waiting))
+        return int(match.group(0)) * 60 if match else 0
+
+    @runnable
+    def run_cmd(self, argvs):
+        """
+        Run as a command.
+
+        :param argvs: Argument list parsed with argparse (--text_prompts, --style, --output_dir)
+        :return: The images returned by generate_image
+        """
+        self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
+                                              prog='hub run {}'.format(self.name),
+                                              usage='%(prog)s',
+                                              add_help=True)
+        self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+        self.add_module_input_arg()
+        args = self.parser.parse_args(argvs)
+        results = self.generate_image(text_prompts=args.text_prompts, style=args.style, output_dir=args.output_dir)
+        return results
+
+    def add_module_input_arg(self):
+        """
+        Add the command input options.
+        """
+        self.arg_input_group.add_argument('--text_prompts', type=str, default='宁静的小镇')
+        self.arg_input_group.add_argument('--style', type=str, default='油画', choices=['油画', '水彩画', '中国画'], help="绘画风格")
+        self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output')
diff --git a/modules/image/text_to_image/ernie_vilg/requirements.txt b/modules/image/text_to_image/ernie_vilg/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5bb8c66c68ea2361106e2f5bbca7f136537e62b8
--- /dev/null
+++ b/modules/image/text_to_image/ernie_vilg/requirements.txt
@@ -0,0 +1,2 @@
+requests
+tqdm