提交 2aeceff2 编写于 作者: C chenjian

fix

上级 6802506b
......@@ -27,6 +27,32 @@ from paddlehub.module.module import serving
author_email="paddle-dev@baidu.com")
class ErnieVilG:
def __init__(self):
self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
self.token = self._apply_token(self.ak, self.sk)
def _apply_token(self, ak, sk):
if ak is None or sk is None:
ak = self.ak
sk = self.sk
response = requests.get(self.token_host,
params={
'grant_type': 'client_credentials',
'client_id': ak,
'client_secret': sk
})
if response:
res = response.json()
if res['code'] != 0:
print('Request access token error.')
raise RuntimeError("Request access token error.")
else:
print('Request access token error.')
raise RuntimeError("Request access token error.")
return res['data']
def generate_image(self,
text_prompts,
style: Optional[str] = "油画",
......@@ -46,27 +72,10 @@ class ErnieVilG:
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
if ak == None:
ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
if sk == None:
sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
response = requests.get(token_host,
params={
'grant_type': 'client_credentials',
'client_id': ak,
'client_secret': sk
})
if response:
res = response.json()
if res['code'] != 0:
print('Request access token error.')
raise RuntimeError("Request access token error.")
else:
print('Request access token error.')
raise RuntimeError("Request access token error.")
token = self.token
if ak is not None and sk is not None:
token = self._apply_token(ak, sk)
token = res['data']
create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub'
get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub'
if isinstance(text_prompts, str):
......@@ -93,6 +102,20 @@ class ErnieVilG:
elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111:
token = self._apply_token(ak, sk)
res = requests.post(create_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
'access_token': token,
"text": text_prompt,
"style": style
})
res = res.json()
if res['code'] != 0:
print("Token失效重新请求后依然发生错误,请检查输入的参数")
raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数")
taskids.append(res['data']["taskId"])
start_time = time.time()
......@@ -124,6 +147,19 @@ class ErnieVilG:
elif res['code'] == 4004:
print('API服务内部错误,可能引起原因有请求超时、模型推理错误等')
raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等")
elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111:
token = self._apply_token(ak, sk)
res = requests.post(create_url,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data={
'access_token': token,
"text": text_prompt,
"style": style
})
res = res.json()
if res['code'] != 0:
print("Token失效重新请求后依然发生错误,请检查输入的参数")
raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数")
if res['data']['status'] == 1:
has_done.append(res['data']['taskId'])
results[res['data']['text']] = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册