import os
import json
import logging
from urllib.parse import quote, urlencode
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect

logger = logging.getLogger(__name__)


class ChatGPTInference(Inference):
    """Inference backend that sends completion prompts to a ChatGPT HTTP endpoint.

    Settings are loaded from the JSON file at ``params_url`` and merged on top of
    ``paras_base_dict`` (provided by the ``Inference`` base class); file values
    override base values.
    """

    def __init__(self):
        super().__init__()
        # NOTE: relative path, resolved against the process's current working
        # directory rather than this module's location.
        self.params_url = "../llm_set/params/chatgpt.json"
        self.paras_dict = self.get_params()
        # Layer file settings over the base defaults, then use the merged dict.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.gpt_url = self.paras_dict.get("url")  # endpoint URL (no query string)
        self.id = self.paras_dict.get("id")  # sent as the `id` query parameter
        self.stream = self.paras_dict.get("stream")  # sent as the `stream` query parameter
        self.temperature = self.paras_dict.get("temperature")  # sent as `temperature`
        # NOTE: if "timeout" is absent this is None, and requests will then wait
        # indefinitely for a response.
        self.timeout = self.paras_dict.get("timeout")

    def get_params(self):
        """Load this backend's settings from the JSON file at ``self.params_url``.

        Returns:
            The parsed JSON content.

        Raises:
            FileNotFoundError: if ``self.params_url`` does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        if not os.path.exists(self.params_url):
            logger.error(f"params_url:{self.params_url} is not exists.")
            raise FileNotFoundError(self.params_url)
        # JSON text is UTF-8 by specification; do not depend on the locale encoding.
        with open(self.params_url, encoding="utf-8") as fh:
            return json.load(fh)

    @exception_reconnect
    def inference(self, query_text):
        """Ask the endpoint to complete a method body.

        Args:
            query_text: Source text to complete; it is prefixed with a fixed
                instruction before being sent as the ``q`` query parameter.

        Returns:
            On HTTP 200, the strings in ``message.content.parts`` of the JSON
            reply joined with newlines; ``None`` if the request itself failed or
            the status code was not 200.

        Raises:
            KeyError, TypeError, ValueError: if a 200 response is not JSON of the
                expected shape (propagated to ``exception_reconnect``).
        """
        # The prefix means "Complete the following method body:".
        prompt = "补全以下方法体:\n" + query_text
        # Encode every parameter, not just the prompt. quote_via=quote with
        # safe="/" encodes `q` exactly as quote(prompt) would (spaces -> %20);
        # non-string values are rendered with str(), as an f-string would.
        param_str = "?" + urlencode(
            {
                "id": self.id,
                "stream": self.stream,
                "temperature": self.temperature,
                "q": prompt,
            },
            quote_via=quote,
            safe="/",
        )
        try:
            resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
        except Exception:
            # Best-effort: any failure building or sending the request yields None.
            logger.exception("request to %s failed", self.gpt_url)
            return None
        if resp.status_code != 200:
            logger.warning("request to %s returned HTTP %s", self.gpt_url, resp.status_code)
            return None
        resp_json = resp.json()
        return "\n".join(resp_json["message"]["content"]["parts"])