提交 b32bb9f9 编写于 作者: C chenjian

fix

上级 d6ac509d
......@@ -63,13 +63,20 @@
- ### 3、API
- ```python
def __init__(ak: Optional[str]=None, sk: Optional[str]=None)
```
- 初始化模块,可自定义用于申请访问文心API的ak和sk。
- **参数**
- ak:(Optional[str]): 用于申请文心api使用token的ak,可不填。
- sk:(Optional[str]): 用于申请文心api使用token的sk,可不填。
- ```python
def generate_image(
text_prompts:str,
style: Optional[str] = "油画",
topk: Optional[int] = 10,
ak: Optional[str] = None,
sk: Optional[str] = None,
output_dir: Optional[str] = 'ernievilg_output')
```
......@@ -80,8 +87,6 @@
- text_prompts(str): 输入的语句,描述想要生成的图像的内容。
- style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。
- topk(Optional[int]): 保存前多少张图,最多保存10张。
- ak:(Optional[str]): 用于申请文心api使用token的ak,可不填。
- sk:(Optional[str]): 用于申请文心api使用token的sk,可不填。
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
......
......@@ -27,9 +27,17 @@ from paddlehub.module.module import serving
author_email="paddle-dev@baidu.com")
class ErnieVilG:
def __init__(self):
self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
def __init__(self, ak=None, sk=None):
"""
:param ak: ak for applying token to request wenxin api.
:param sk: sk for applying token to request wenxin api.
"""
if ak is None or sk is None:
self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
else:
self.ak = ak
self.sk = sk
self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
self.token = self._apply_token(self.ak, self.sk)
......@@ -57,8 +65,6 @@ class ErnieVilG:
text_prompts,
style: Optional[str] = "油画",
topk: Optional[int] = 10,
ak: Optional[str] = None,
sk: Optional[str] = None,
output_dir: Optional[str] = 'ernievilg_output'):
"""
Create image by text prompts using ErnieVilG model.
......@@ -66,16 +72,11 @@ class ErnieVilG:
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画
:param topk: Top k images to save.
:param ak: ak for applying token to request wenxin api.
:param sk: sk for applying token to request wenxin api.
:output_dir: Output directory
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
token = self.token
if ak is not None and sk is not None:
token = self._apply_token(ak, sk)
create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub'
get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub'
if isinstance(text_prompts, str):
......@@ -103,7 +104,7 @@ class ErnieVilG:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111:
token = self._apply_token(ak, sk)
token = self._apply_token(self.ak, self.sk)
res = requests.post(create_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
......@@ -148,7 +149,7 @@ class ErnieVilG:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111:
token = self._apply_token(ak, sk)
token = self._apply_token(self.ak, self.sk)
res = requests.post(get_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
......@@ -203,11 +204,13 @@ class ErnieVilG:
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
if args.ak is not None and args.sk is not None:
self.ak = args.ak
self.sk = args.sk
self.token = self._apply_token(self.ak, self.sk)
results = self.generate_image(text_prompts=args.text_prompts,
style=args.style,
topk=args.topk,
ak=args.ak,
sk=args.sk,
output_dir=args.output_dir)
return results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册