diff --git a/modules/image/text_to_image/ernie_vilg/module.py b/modules/image/text_to_image/ernie_vilg/module.py index 891169c43b172437a88cca35e957a2b00d31d688..906e18cb6b3198fbfcb9de84cf86f231847370ff 100644 --- a/modules/image/text_to_image/ernie_vilg/module.py +++ b/modules/image/text_to_image/ernie_vilg/module.py @@ -27,6 +27,32 @@ from paddlehub.module.module import serving author_email="paddle-dev@baidu.com") class ErnieVilG: + def __init__(self): + self.ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE' + self.sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs' + self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' + self.token = self._apply_token(self.ak, self.sk) + + def _apply_token(self, ak, sk): + if ak is None or sk is None: + ak = self.ak + sk = self.sk + response = requests.get(self.token_host, + params={ + 'grant_type': 'client_credentials', + 'client_id': ak, + 'client_secret': sk + }) + if response: + res = response.json() + if res['code'] != 0: + print('Request access token error.') + raise RuntimeError("Request access token error.") + else: + print('Request access token error.') + raise RuntimeError("Request access token error.") + return res['data'] + def generate_image(self, text_prompts, style: Optional[str] = "油画", @@ -46,27 +72,10 @@ class ErnieVilG: """ if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) - if ak == None: - ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE' - if sk == None: - sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs' - token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' - response = requests.get(token_host, - params={ - 'grant_type': 'client_credentials', - 'client_id': ak, - 'client_secret': sk - }) - if response: - res = response.json() - if res['code'] != 0: - print('Request access token error.') - raise RuntimeError("Request access token error.") - else: - print('Request access token error.') - raise RuntimeError("Request access token error.") + token = self.token + if ak is not None and sk is not None: + token = self._apply_token(ak, sk) - token = res['data'] create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub' get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub' if isinstance(text_prompts, str): @@ -93,6 +102,20 @@ class ErnieVilG: elif res['code'] == 4004: print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") + elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111: + token = self._apply_token(ak, sk) + res = requests.post(create_url, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + data={ + 'access_token': token, + "text": text_prompt, + "style": style + }) + res = res.json() + if res['code'] != 0: + print("Token失效重新请求后依然发生错误,请检查输入的参数") + raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") + taskids.append(res['data']["taskId"]) start_time = time.time() @@ -124,6 +147,19 @@ class ErnieVilG: elif res['code'] == 4004: print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") + elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111: + token = self._apply_token(ak, sk) + res = requests.post(create_url, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + data={ + 'access_token': token, + "text": text_prompt, + "style": style + }) + res = res.json() + if res['code'] != 0: + print("Token失效重新请求后依然发生错误,请检查输入的参数") + raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") if res['data']['status'] == 1: has_done.append(res['data']['taskId']) results[res['data']['text']] = {