未验证 提交 088b37d6 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ERNIE VILG (#2126)

* update ERNIE VILG

* update README

* update
上级 7eef3bfd
...@@ -54,12 +54,17 @@ ...@@ -54,12 +54,17 @@
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
- ### 3. 使用申请(可选)
- 请前往 [文心旸谷社区](https://wenxin.baidu.com/moduleApi/key) 申请使用本模型所需的 API key 和 Secret Key。
## 三、模型API预测 ## 三、模型API预测
- ### 1、命令行预测 - ### 1、命令行预测
- ```shell - ```shell
# 请设置 '--ak' 和 '--sk' 参数
# 或者设置 'WENXIN_AK' 和 'WENXIN_SK' 环境变量
# 更多细节参考下方 API 说明
$ hub run ernie_vilg --text_prompts "宁静的小镇" --style "油画" --output_dir ernie_vilg_out $ hub run ernie_vilg --text_prompts "宁静的小镇" --style "油画" --output_dir ernie_vilg_out
``` ```
...@@ -68,6 +73,9 @@ ...@@ -68,6 +73,9 @@
- ```python - ```python
import paddlehub as hub import paddlehub as hub
# 请设置 'ak' 和 'sk' 参数
# 或者设置 'WENXIN_AK' 和 'WENXIN_SK' 环境变量
# 更多细节参考下方 API 说明
module = hub.Module(name="ernie_vilg") module = hub.Module(name="ernie_vilg")
text_prompts = ["宁静的小镇"] text_prompts = ["宁静的小镇"]
images = module.generate_image(text_prompts=text_prompts, style='油画', output_dir='./ernie_vilg_out/') images = module.generate_image(text_prompts=text_prompts, style='油画', output_dir='./ernie_vilg_out/')
...@@ -75,12 +83,27 @@ ...@@ -75,12 +83,27 @@
- ### 3、API - ### 3、API
- ```python
def __init__(
ak: Optional[str] = None,
sk: Optional[str] = None
)
```
- 初始化 API。
- **参数**
- ak(Optional[str]): 文心 API AK,默认为 None,即从环境变量 'WENXIN_AK' 中获取;
- sk(Optional[str]): 文心 API SK,默认为 None,即从环境变量 'WENXIN_SK' 中获取。
- ```python - ```python
def generate_image( def generate_image(
text_prompts:str, text_prompts:str,
style: Optional[str] = "探索无限", style: Optional[str] = "探索无限",
topk: Optional[int] = 6, topk: Optional[int] = 6,
output_dir: Optional[str] = 'ernievilg_output') output_dir: Optional[str] = 'ernievilg_output'
)
``` ```
- 文图生成API,生成文本描述内容的图像。 - 文图生成API,生成文本描述内容的图像。
...@@ -390,7 +413,7 @@ DiscoDiffusion Prompt 技巧资料:https://docs.google.com/document/d/1l8s7uS2 ...@@ -390,7 +413,7 @@ DiscoDiffusion Prompt 技巧资料:https://docs.google.com/document/d/1l8s7uS2
* 1.2.0 * 1.2.0
移除分辨率参数 移除分辨率参数,移除默认 AK 和 SK
```shell ```shell
$ hub install ernie_vilg == 1.2.0 $ hub install ernie_vilg == 1.2.0
......
...@@ -27,19 +27,17 @@ class ErnieVilG: ...@@ -27,19 +27,17 @@ class ErnieVilG:
:param ak: ak for applying token to request wenxin api. :param ak: ak for applying token to request wenxin api.
:param sk: sk for applying token to request wenxin api. :param sk: sk for applying token to request wenxin api.
""" """
if ak is None or sk is None: self.ak = ak
self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE' self.sk = sk
self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
else:
self.ak = ak
self.sk = sk
self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
self.token = self._apply_token(self.ak, self.sk) self.token = self._apply_token(self.ak, self.sk)
def _apply_token(self, ak, sk): def _apply_token(self, ak, sk):
if ak is None or sk is None: ak = ak if ak else os.getenv('WENXIN_AK')
ak = self.ak sk = sk if sk else os.getenv('WENXIN_SK')
sk = self.sk assert ak and sk, RuntimeError(
'Please go to the wenxin official website to apply for AK and SK and set the parameters “ak” and “sk” correctly, or set the environment variables “WENXIN_AK” and “WENXIN_SK”.'
)
response = requests.get(self.token_host, response = requests.get(self.token_host,
params={ params={
'grant_type': 'client_credentials', 'grant_type': 'client_credentials',
...@@ -145,7 +143,7 @@ class ErnieVilG: ...@@ -145,7 +143,7 @@ class ErnieVilG:
time.sleep(5) time.sleep(5)
continue continue
else: else:
time.sleep(6) time.sleep(10)
if not taskids: if not taskids:
break break
has_done = [] has_done = []
...@@ -226,10 +224,9 @@ class ErnieVilG: ...@@ -226,10 +224,9 @@ class ErnieVilG:
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.add_module_input_arg() self.add_module_input_arg()
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
if args.ak is not None and args.sk is not None: self.ak = args.ak
self.ak = args.ak self.sk = args.sk
self.sk = args.sk self.token = self._apply_token(self.ak, self.sk)
self.token = self._apply_token(self.ak, self.sk)
results = self.generate_image(text_prompts=args.text_prompts, results = self.generate_image(text_prompts=args.text_prompts,
style=args.style, style=args.style,
topk=args.topk, topk=args.topk,
......
...@@ -16,7 +16,7 @@ class TestHubModule(unittest.TestCase): ...@@ -16,7 +16,7 @@ class TestHubModule(unittest.TestCase):
def test_generate_image(self): def test_generate_image(self):
self.module.generate_image(text_prompts=['戴眼镜的猫'], self.module.generate_image(text_prompts=['戴眼镜的猫'],
style="探索无限", style="像素风格",
topk=6, topk=6,
visualization=True, visualization=True,
output_dir='ernievilg_output') output_dir='ernievilg_output')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册