提交 f7f8bbf5 编写于 作者: C chenjian

fix

上级 70532e06
# ernie_vilg
|模型名称|ernie_vilg|
| :--- | :---: |
|类别|图像-文图生成|
|网络|ERNIE-ViLG|
|数据集|-|
|是否支持Fine-tuning|否|
|模型大小|-|
|最新更新日期|2022-08-02|
|数据指标|-|
## 一、模型基本信息
### 应用效果展示
- 输入文本 "宁静的小镇" 风格 "油画"
- 输出图像
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/183041589-57debf50-80ec-496f-8bb5-42d9d38646dd.png" width = "80%" hspace='10'/>
<br />
### 模型介绍
文心ERNIE-ViLG参数规模达到100亿,是目前为止全球最大规模中文跨模态生成模型,在文本生成图像、图像描述等跨模态生成任务上效果全球领先,在图文生成领域MS-COCO、COCO-CN、AIC-ICC等数据集上取得最好效果。你可以输入一段文本描述以及生成风格,模型就会根据输入的内容自动创作出符合要求的图像。
## 二、安装
- ### 1、环境依赖
- paddlepaddle >= 2.0.0
- paddlehub >= 2.2.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2、安装
- ```shell
$ hub install ernie_vilg
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## 三、模型API预测
- ### 1、命令行预测
- ```shell
$ hub run ernie_vilg --text_prompts "宁静的小镇" --output_dir ernie_vilg_out
```
- ### 2、预测代码示例
- ```python
import paddlehub as hub
module = hub.Module(name="ernie_vilg")
text_prompts = ["宁静的小镇"]
images = module.generate_image(text_prompts=text_prompts, output_dir='./ernie_vilg_out/')
```
- ### 3、API
- ```python
def __init__(ak: Optional[str]=None, sk: Optional[str]=None)
```
- 初始化模块,可自定义用于申请访问文心API的ak和sk。
- **参数**
- ak:(Optional[str]): 用于申请文心api使用token的ak,可不填。
- sk:(Optional[str]): 用于申请文心api使用token的sk,可不填。
- ```python
def generate_image(
text_prompts:str,
style: Optional[str] = "油画",
topk: Optional[int] = 10,
output_dir: Optional[str] = 'ernievilg_output')
```
- 文图生成API,生成文本描述内容的图像。
- **参数**
- text_prompts(str): 输入的语句,描述想要生成的图像的内容。
- style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。
- topk(Optional[int]): 保存前多少张图,最多保存10张。
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
- **返回**
- images(List(PIL.Image)): 返回生成的所有图像列表,PIL的Image格式。
## 四、更新历史
* 1.0.0
初始发布
```shell
$ hub install ernie_vilg == 1.0.0
```
import argparse
import ast
import os
import re
import sys
import time
from functools import partial
from io import BytesIO
from typing import List
from typing import Optional
import requests
from PIL import Image
from tqdm.auto import tqdm
import paddlehub as hub
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
@moduleinfo(name="ernie_vilg",
version="1.0.0",
type="image/text_to_image",
summary="",
author="baidu-nlp",
author_email="paddle-dev@baidu.com")
class ErnieVilG:
def __init__(self, ak=None, sk=None):
"""
:param ak: ak for applying token to request wenxin api.
:param sk: sk for applying token to request wenxin api.
"""
if ak is None or sk is None:
self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
else:
self.ak = ak
self.sk = sk
self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
self.token = self._apply_token(self.ak, self.sk)
def _apply_token(self, ak, sk):
if ak is None or sk is None:
ak = self.ak
sk = self.sk
response = requests.get(self.token_host,
params={
'grant_type': 'client_credentials',
'client_id': ak,
'client_secret': sk
})
if response:
res = response.json()
if res['code'] != 0:
print('Request access token error.')
raise RuntimeError("Request access token error.")
else:
print('Request access token error.')
raise RuntimeError("Request access token error.")
return res['data']
def generate_image(self,
text_prompts,
style: Optional[str] = "油画",
topk: Optional[int] = 10,
output_dir: Optional[str] = 'ernievilg_output'):
"""
Create image by text prompts using ErnieVilG model.
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画
:param topk: Top k images to save.
:output_dir: Output directory
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
token = self.token
create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub'
get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub'
if isinstance(text_prompts, str):
text_prompts = [text_prompts]
taskids = []
for text_prompt in text_prompts:
res = requests.post(create_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
'access_token': token,
"text": text_prompt,
"style": style
})
res = res.json()
if res['code'] == 4001:
print('请求参数错误')
raise RuntimeError("请求参数错误")
elif res['code'] == 4002:
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
elif res['code'] == 4003:
print('请求参数中,图片风格不在可选范围内')
raise RuntimeError("请求参数中,图片风格不在可选范围内")
elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111:
token = self._apply_token(self.ak, self.sk)
res = requests.post(create_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
'access_token': token,
"text": text_prompt,
"style": style
})
res = res.json()
if res['code'] != 0:
print("Token失效重新请求后依然发生错误,请检查输入的参数")
raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数")
taskids.append(res['data']["taskId"])
start_time = time.time()
process_bar = tqdm(total=100, unit='%')
results = {}
first_iter = True
while True:
if not taskids:
break
total_time = 0
has_done = []
for taskid in taskids:
res = requests.post(get_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
'access_token': token,
'taskId': {taskid}
})
res = res.json()
if res['code'] == 4001:
print('请求参数错误')
raise RuntimeError("请求参数错误")
elif res['code'] == 4002:
print('请求参数格式错误,请检查必传参数是否齐全,参数类型等')
raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等")
elif res['code'] == 4003:
print('请求参数中,图片风格不在可选范围内')
raise RuntimeError("请求参数中,图片风格不在可选范围内")
elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111:
token = self._apply_token(self.ak, self.sk)
res = requests.post(get_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
'access_token': token,
'taskId': {taskid}
})
res = res.json()
if res['code'] != 0:
print("Token失效重新请求后依然发生错误,请检查输入的参数")
raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数")
if res['data']['status'] == 1:
has_done.append(res['data']['taskId'])
results[res['data']['text']] = {
'imgUrls': res['data']['imgUrls'],
'waiting': res['data']['waiting'],
'taskId': res['data']['taskId']
}
total_time = int(re.match('[0-9]+', str(res['data']['waiting'])).group(0)) * 60
end_time = time.time()
progress_rate = int(((end_time - start_time) / total_time * 100)) if total_time != 0 else 100
if progress_rate > process_bar.n:
increase_rate = progress_rate - process_bar.n
if progress_rate >= 100:
increase_rate = 100 - process_bar.n
else:
increase_rate = 0
process_bar.update(increase_rate)
time.sleep(5)
for taskid in has_done:
taskids.remove(taskid)
print('Saving Images...')
result_images = []
for text, data in results.items():
for idx, imgdata in enumerate(data['imgUrls']):
image = Image.open(BytesIO(requests.get(imgdata['image']).content))
image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx)))
result_images.append(image)
if idx + 1 >= topk:
break
print('Done')
return result_images
@runnable
def run_cmd(self, argvs):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
if args.ak is not None and args.sk is not None:
self.ak = args.ak
self.sk = args.sk
self.token = self._apply_token(self.ak, self.sk)
results = self.generate_image(text_prompts=args.text_prompts,
style=args.style,
topk=args.topk,
output_dir=args.output_dir)
return results
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--text_prompts', type=str)
self.arg_input_group.add_argument('--style',
type=str,
default='油画',
choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'],
help="绘画风格")
self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张")
self.arg_input_group.add_argument('--ak', type=str, default=None, help="申请文心api使用token的ak")
self.arg_input_group.add_argument('--sk', type=str, default=None, help="申请文心api使用token的sk")
self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册