# chatgpt_inference.py (submitted by CSDN-Ada助手)

import os
import json
import logging
import quote
import requests
import traceback
from inference import Inference


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ChatGPTInference(Inference):
    """Inference backend that queries a ChatGPT-style HTTP endpoint via GET.

    Settings are read from the JSON file at ``params_url`` and then overlaid
    with ``self.paras_base_dict``, so base values win on key collisions.

    NOTE(review): ``paras_base_dict`` is not defined in this file; it is
    presumably provided by ``Inference`` -- confirm against the base class.
    """

    def __init__(self, *args, **kwargs):
        """Load the endpoint parameters and cache them on the instance.

        This method was previously misspelled ``__int__`` and therefore never
        ran on construction. Arguments are forwarded unchanged to
        ``Inference.__init__`` so existing constructor calls keep working.
        """
        super().__init__(*args, **kwargs)
        self.params_url = "llm_set/params/chatgpt.json"
        self.paras_dict = self.get_params()
        self.paras_dict.update(self.paras_base_dict)
        self.gpt_url = self.paras_dict.get("url")
        self.id = self.paras_dict.get("id")
        self.stream = self.paras_dict.get("stream")
        self.temperature = self.paras_dict.get("temperature")
        self.timeout = self.paras_dict.get("timeout")

    def get_params(self):
        """Read and parse the JSON parameter file at ``self.params_url``.

        Returns:
            The decoded JSON content of the file.

        Raises:
            FileNotFoundError: If the file does not exist (an error is
                logged before the exception propagates).
            json.JSONDecodeError: If the file content is not valid JSON.
        """
        try:
            # Context manager guarantees the handle is closed after reading.
            with open(self.params_url, encoding="utf-8") as params_file:
                return json.load(params_file)
        except FileNotFoundError:
            logger.error(f"params_url:{self.params_url} is not exists.")
            raise

    def inference(self, query_text):
        """Send ``query_text`` to the endpoint and return the reply text.

        Args:
            query_text: The prompt; it is percent-encoded before being placed
                in the ``q`` query parameter.

        Returns:
            The items of ``message.content.parts`` from the JSON response
            joined with newlines, or ``None`` when the request fails, the
            status code is not 200, or the payload does not have that shape.
        """
        # Imported locally: the module-level ``import quote`` binds a module
        # object, which is not callable.
        from urllib.parse import quote

        encoded_query = quote(query_text)
        param_str = (
            f"?id={self.id}&stream={self.stream}"
            f"&temperature={self.temperature}&q={encoded_query}"
        )
        try:
            resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
        except Exception:
            # Best-effort boundary: any failure here is logged with its
            # traceback and reported to the caller as None.
            logger.exception("request to %s failed", self.gpt_url)
            return None

        if resp.status_code != 200:
            logger.error("unexpected status code: %s", resp.status_code)
            return None

        try:
            parts = json.loads(resp.text)["message"]["content"]["parts"]
            return "\n".join(parts)
        except (ValueError, KeyError, TypeError):
            # Invalid JSON or a payload without message.content.parts.
            logger.exception("unexpected response payload")
            return None