
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect


# Module-scoped logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)


class ChatGPTInference(Inference):
    """Inference backend that sends completion prompts to a ChatGPT HTTP endpoint.

    Settings are loaded from the JSON file at ``params_url``. That path is relative,
    so it is resolved against the current working directory, not against this module.
    The settings are merged into ``self.paras_base_dict``, with ChatGPT-specific keys
    overriding base keys. ``paras_base_dict`` is presumably initialised by
    ``Inference.__init__``; verify there.
    """

    def __init__(self):
        super().__init__()
        self.params_url = "../llm_set/params/chatgpt.json"
        self.paras_dict = self.get_params()
        # Overlay the ChatGPT-specific settings onto the base settings.
        # NOTE(review): this mutates paras_base_dict in place. If Inference defines
        # it at class level, the change is shared across instances — confirm.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        # Each of these is None when its key is absent from the merged settings.
        self.gpt_url = self.paras_dict.get("url")
        self.id = self.paras_dict.get("id")
        self.stream = self.paras_dict.get("stream")
        self.temperature = self.paras_dict.get("temperature")
        self.timeout = self.paras_dict.get("timeout")

    def get_params(self):
        """Load and decode the JSON settings file at ``self.params_url``.

        Returns:
            The decoded JSON value. ``__init__`` uses it as a dict of settings.

        Raises:
            FileNotFoundError: if the file does not exist. The error is logged first.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        try:
            # JSON is UTF-8 by specification. The context manager closes the handle.
            with open(self.params_url, encoding="utf-8") as params_file:
                return json.load(params_file)
        except FileNotFoundError:
            logger.error(f"params_url:{self.params_url} is not exists.")
            raise

    @exception_reconnect
    def inference(self, query_text):
        """Ask the ChatGPT endpoint to complete a method body.

        The prompt is ``query_text`` prefixed with a Chinese instruction meaning
        "Complete the following method body:". It is sent as a GET request whose
        query string carries ``id``, ``stream``, ``temperature`` and ``q``.

        Args:
            query_text: Text appended to the instruction prefix. Presumably a method
                signature or partial body — confirm with callers.

        Returns:
            The strings in ``message.content.parts`` of the JSON response, joined
            with newlines. Returns ``None`` if the request raised or the server
            answered with a non-200 status.

        Raises:
            ValueError / KeyError: if a 200 response is not JSON or lacks the expected
                keys. These propagate to the ``exception_reconnect`` wrapper.
        """
        # Prompt prefix meaning "Complete the following method body:".
        prompt = "补全以下方法体：\n" + query_text
        # Percent-encode every value, not only the prompt, so characters such as
        # '&', '=' or '#' in any parameter cannot break the query string.
        # quote() leaves '/' unescaped, matching the previous encoding of q.
        query = {
            "id": self.id,
            "stream": self.stream,
            "temperature": self.temperature,
            "q": prompt,
        }
        param_str = "?" + "&".join(
            f"{key}={quote(str(value))}" for key, value in query.items()
        )
        try:
            resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
        except Exception:
            # Any failure here is logged and reported as None. This covers request
            # errors, timeouts, and a missing url (None + str raises TypeError).
            # NOTE(review): these errors therefore never reach @exception_reconnect
            # — confirm that is intended.
            logger.exception("ChatGPT request failed")
            return None

        if resp.status_code != 200:
            logger.warning("ChatGPT request returned HTTP %s", resp.status_code)
            return None

        resp_json = json.loads(resp.text)
        return "\n".join(resp_json["message"]["content"]["parts"])
