diff --git a/modules/image/text_to_image/ernie_vilg/README.md b/modules/image/text_to_image/ernie_vilg/README.md deleted file mode 100644 index c85e52d2bcaf23662582125360efe167ab82fc00..0000000000000000000000000000000000000000 --- a/modules/image/text_to_image/ernie_vilg/README.md +++ /dev/null @@ -1,104 +0,0 @@ -# ernie_vilg - -|模型名称|ernie_vilg| -| :--- | :---: | -|类别|图像-文图生成| -|网络|ERNIE-ViLG| -|数据集|-| -|是否支持Fine-tuning|否| -|模型大小|-| -|最新更新日期|2022-08-02| -|数据指标|-| - -## 一、模型基本信息 - -### 应用效果展示 - - - 输入文本 "宁静的小镇" 风格 "油画" - - - 输出图像 -

- -
- - -### 模型介绍 - -文心ERNIE-ViLG参数规模达到100亿,是目前为止全球最大规模中文跨模态生成模型,在文本生成图像、图像描述等跨模态生成任务上效果全球领先,在图文生成领域MS-COCO、COCO-CN、AIC-ICC等数据集上取得最好效果。你可以输入一段文本描述以及生成风格,模型就会根据输入的内容自动创作出符合要求的图像。 - -## 二、安装 - -- ### 1、环境依赖 - - - paddlepaddle >= 2.0.0 - - - paddlehub >= 2.2.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst) - -- ### 2、安装 - - - ```shell - $ hub install ernie_vilg - ``` - - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) - | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) - - -## 三、模型API预测 - -- ### 1、命令行预测 - - - ```shell - $ hub run ernie_vilg --text_prompts "宁静的小镇" --output_dir ernie_vilg_out - ``` - -- ### 2、预测代码示例 - - - ```python - import paddlehub as hub - - module = hub.Module(name="ernie_vilg") - text_prompts = ["宁静的小镇"] - images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/') - ``` - -- ### 3、API - - - ```python - def __init__(ak: Optional[str]=None, sk: Optional[str]=None) - ``` - - 初始化模块,可自定义用于申请访问文心API的ak和sk。 - - - **参数** - - ak:(Optional[str]): 用于申请文心api使用token的ak,可不填。 - - sk:(Optional[str]): 用于申请文心api使用token的sk,可不填。 - - - ```python - def generate_image( - text_prompts:str, - style: Optional[str] = "油画", - topk: Optional[int] = 10, - output_dir: Optional[str] = 'ernievilg_output') - ``` - - - 文图生成API,生成文本描述内容的图像。 - - - **参数** - - - text_prompts(str): 输入的语句,描述想要生成的图像的内容。 - - style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。 - - topk(Optional[int]): 保存前多少张图,最多保存10张。 - - output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。 - - - - **返回** - - images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式。 - -## 四、更新历史 - -* 1.0.0 - - 初始发布 - - ```shell - $ hub install ernie_vilg == 1.0.0 - ``` diff --git a/modules/image/text_to_image/ernie_vilg/module.py b/modules/image/text_to_image/ernie_vilg/module.py deleted file mode 100644 index 7af5abb0c3a335402823c19e39a26aed71787528..0000000000000000000000000000000000000000 --- a/modules/image/text_to_image/ernie_vilg/module.py +++ /dev/null @@ -1,230 +0,0 @@ -import argparse -import ast -import os -import re -import sys -import time -from functools import partial -from io import BytesIO -from typing import List -from typing import Optional - -import requests -from PIL import Image -from tqdm.auto import tqdm - -import paddlehub as hub -from paddlehub.module.module import moduleinfo -from paddlehub.module.module import runnable -from paddlehub.module.module import serving - - -@moduleinfo(name="ernie_vilg", - version="1.0.0", - type="image/text_to_image", - summary="", - author="baidu-nlp", - author_email="paddle-dev@baidu.com") -class ErnieVilG: - - def __init__(self, ak=None, sk=None): - """ - :param ak: ak for applying token to request wenxin api. - :param sk: sk for applying token to request wenxin api. - """ - if ak is None or sk is None: - self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE' - self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs' - else: - self.ak = ak - self.sk = sk - self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' - self.token = self._apply_token(self.ak, self.sk) - - def _apply_token(self, ak, sk): - if ak is None or sk is None: - ak = self.ak - sk = self.sk - response = requests.get(self.token_host, - params={ - 'grant_type': 'client_credentials', - 'client_id': ak, - 'client_secret': sk - }) - if response: - res = response.json() - if res['code'] != 0: - print('Request access token error.') - raise RuntimeError("Request access token error.") - else: - print('Request access token error.') - raise RuntimeError("Request access token error.") - return res['data'] - - def generate_image(self, - text_prompts, - style: Optional[str] = "油画", - topk: Optional[int] = 10, - output_dir: Optional[str] = 'ernievilg_output'): - """ - Create image by text prompts using ErnieVilG model. - - :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. - :param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画 - :param topk: Top k images to save. - :output_dir: Output directory - """ - if not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - token = self.token - create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub' - get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub' - if isinstance(text_prompts, str): - text_prompts = [text_prompts] - taskids = [] - for text_prompt in text_prompts: - res = requests.post(create_url, - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - data={ - 'access_token': token, - "text": text_prompt, - "style": style - }) - res = res.json() - if res['code'] == 4001: - print('请求参数错误') - raise RuntimeError("请求参数错误") - elif res['code'] == 4002: - print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') - raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") - elif res['code'] == 4003: - print('请求参数中,图片风格不在可选范围内') - raise RuntimeError("请求参数中,图片风格不在可选范围内") - elif res['code'] == 4004: - print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') - raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") - elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111: - token = self._apply_token(self.ak, self.sk) - res = requests.post(create_url, - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - data={ - 'access_token': token, - "text": text_prompt, - "style": style - }) - res = res.json() - if res['code'] != 0: - print("Token失效重新请求后依然发生错误,请检查输入的参数") - raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") - - taskids.append(res['data']["taskId"]) - - start_time = time.time() - process_bar = tqdm(total=100, unit='%') - results = {} - first_iter = True - while True: - if not taskids: - break - total_time = 0 - has_done = [] - for taskid in taskids: - res = requests.post(get_url, - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - data={ - 'access_token': token, - 'taskId': {taskid} - }) - res = res.json() - if res['code'] == 4001: - print('请求参数错误') - raise RuntimeError("请求参数错误") - elif res['code'] == 4002: - print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') - raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") - elif res['code'] == 4003: - print('请求参数中,图片风格不在可选范围内') - raise RuntimeError("请求参数中,图片风格不在可选范围内") - elif res['code'] == 4004: - print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') - raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") - elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111: - token = self._apply_token(self.ak, self.sk) - res = requests.post(get_url, - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - data={ - 'access_token': token, - 'taskId': {taskid} - }) - res = res.json() - if res['code'] != 0: - print("Token失效重新请求后依然发生错误,请检查输入的参数") - raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") - if res['data']['status'] == 1: - has_done.append(res['data']['taskId']) - results[res['data']['text']] = { - 'imgUrls': res['data']['imgUrls'], - 'waiting': res['data']['waiting'], - 'taskId': res['data']['taskId'] - } - total_time = int(re.match('[0-9]+', str(res['data']['waiting'])).group(0)) * 60 - end_time = time.time() - progress_rate = int(((end_time - start_time) / total_time * 100)) if total_time != 0 else 100 - if progress_rate > process_bar.n: - increase_rate = progress_rate - process_bar.n - if progress_rate >= 100: - increase_rate = 100 - process_bar.n - else: - increase_rate = 0 - process_bar.update(increase_rate) - time.sleep(5) - for taskid in has_done: - taskids.remove(taskid) - print('Saving Images...') - result_images = [] - for text, data in results.items(): - for idx, imgdata in enumerate(data['imgUrls']): - image = Image.open(BytesIO(requests.get(imgdata['image']).content)) - image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx))) - result_images.append(image) - if idx + 1 >= topk: - break - print('Done') - return result_images - - @runnable - def run_cmd(self, argvs): - """ - Run as a command. - """ - self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name), - prog='hub run {}'.format(self.name), - usage='%(prog)s', - add_help=True) - self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") - self.add_module_input_arg() - args = self.parser.parse_args(argvs) - if args.ak is not None and args.sk is not None: - self.ak = args.ak - self.sk = args.sk - self.token = self._apply_token(self.ak, self.sk) - results = self.generate_image(text_prompts=args.text_prompts, - style=args.style, - topk=args.topk, - output_dir=args.output_dir) - return results - - def add_module_input_arg(self): - """ - Add the command input options. - """ - self.arg_input_group.add_argument('--text_prompts', type=str) - self.arg_input_group.add_argument('--style', - type=str, - default='油画', - choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'], - help="绘画风格") - self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张") - self.arg_input_group.add_argument('--ak', type=str, default=None, help="申请文心api使用token的ak") - self.arg_input_group.add_argument('--sk', type=str, default=None, help="申请文心api使用token的sk") - self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output') diff --git a/modules/image/text_to_image/ernie_vilg/requirements.txt b/modules/image/text_to_image/ernie_vilg/requirements.txt deleted file mode 100644 index 5bb8c66c68ea2361106e2f5bbca7f136537e62b8..0000000000000000000000000000000000000000 --- a/modules/image/text_to_image/ernie_vilg/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -requests -tqdm