# chatgpt_inference.py (1.7 KB) — committed by CSDN-Ada助手

import os
import json
import logging
# (blame) commit "fix bug" by CSDN-Ada助手
from urllib.parse import quote
# (blame) commit by CSDN-Ada助手
import requests
import traceback
# (blame) commit "fix bug" by CSDN-Ada助手
from .inference import Inference
from utils import exception_reconnect
# (blame) commit by CSDN-Ada助手


logger = logging.getLogger(__name__)


class ChatGPTInference(Inference):
    """Method-body completion backed by a ChatGPT-style HTTP endpoint.

    Settings are loaded from the JSON file at ``params_url``. They are
    overlaid on ``self.paras_base_dict``, which is presumably populated by
    ``Inference.__init__`` (not visible here). The keys consumed are
    ``url``, ``id``, ``stream``, ``temperature`` and ``timeout``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): relative path, resolved against the process's current
        # working directory rather than this module's location — confirm the
        # expected launch directory.
        self.params_url = "../llm_set/params/chatgpt.json"
        self.paras_dict = self.get_params()
        # Values from the file override the base-class defaults; the merged
        # dict then becomes the effective configuration.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        # Endpoint URL; the query string built in inference() is appended verbatim.
        self.gpt_url = self.paras_dict.get("url")
        self.id = self.paras_dict.get("id")
        self.stream = self.paras_dict.get("stream")
        self.temperature = self.paras_dict.get("temperature")
        # NOTE(review): None when no "timeout" is configured, in which case
        # requests waits for the server indefinitely.
        self.timeout = self.paras_dict.get("timeout")

    def get_params(self):
        """Load and return the JSON parameter file at ``self.params_url``.

        Returns:
            The parsed JSON document.

        Raises:
            FileNotFoundError: the file does not exist (an error is logged first).
            json.JSONDecodeError: the file is not valid JSON.
        """
        try:
            # UTF-8 is the JSON interchange encoding. Being explicit keeps the
            # result independent of the platform locale. `with` closes the
            # handle promptly.
            with open(self.params_url, encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            logger.error(f"params_url:{self.params_url} is not exists.")
            raise

    @exception_reconnect
    def inference(self, query_text):
        """Request a completion of the method body in *query_text*.

        The prompt prefix "补全以下方法体:" ("complete the following method
        body:") and a newline are prepended to *query_text*. The result is
        sent as the ``q`` query parameter of a GET request, together with
        ``id``, ``stream`` and ``temperature``.

        Args:
            query_text: source code of the method to complete.

        Returns:
            The strings in the response's ``message.content.parts`` joined
            with newlines. Returns None if sending the request raised or the
            HTTP status was not 200.

        Raises:
            ValueError: a 200 response body is not valid JSON.
            KeyError, TypeError: the JSON lacks ``message.content.parts``.
        """
        params = {
            "id": self.id,
            "stream": self.stream,
            "temperature": self.temperature,
            "q": "补全以下方法体:\n" + query_text,
        }
        # Percent-encode every value, not only the prompt, so a setting that
        # contains '&', '=', '#' or spaces cannot corrupt the query string.
        # Plain alphanumeric values (and '.', '-', '_') pass through unchanged.
        param_str = "?" + "&".join(
            f"{key}={quote(str(value))}" for key, value in params.items()
        )
        try:
            resp = requests.get(self.gpt_url + param_str, timeout=self.timeout)
        except Exception:
            # NOTE(review): errors caught here never reach @exception_reconnect.
            # If that decorator is meant to retry on connection failures, this
            # handler defeats it — confirm intent.
            logger.exception("ChatGPT request failed")
            return None

        if resp.status_code != 200:
            logger.error(
                "ChatGPT request returned HTTP %s: %s",
                resp.status_code,
                resp.text[:500],
            )
            return None

        try:
            parts = resp.json()["message"]["content"]["parts"]
        except (ValueError, KeyError, TypeError):
            # Log the offending body for diagnosis, then re-raise so the
            # decorator and caller still observe the failure.
            logger.error("Unexpected ChatGPT response body: %s", resp.text[:500])
            raise
        return "\n".join(parts)