bbt_inference.py 1.9 KB
Newer Older
CSDN-Ada助手's avatar
fix bug  
CSDN-Ada助手 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58

import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect


logger = logging.getLogger(__name__)


class BBTInference(Inference):
    """HTTP inference client for the BBT model service.

    Parameters are read from the JSON file at ``params_url``. They are merged
    over ``paras_base_dict``, which is presumably populated by
    ``Inference.__init__`` (verify against the base class). Values from the
    file win. The merged dict becomes ``paras_dict``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): relative path, resolved against the process's current
        # working directory rather than this module's location — confirm that
        # callers always run from the expected directory.
        self.params_url = "../llm_set/params/bbt.json"
        self.paras_dict = self.get_params()
        # Overlay the file's values on the base defaults, then use the merged
        # dict as the effective parameter set.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        # Forwarded to requests.post; None (key absent) means no timeout.
        self.timeout = self.paras_dict.get("timeout")
        # Generation settings, sent to the service as a JSON-encoded string.
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            # Further generation options, currently disabled:
            # 'top_k': 80,
            # 'do_sample': True,
            # 'num_beams': 4,
            # 'no_repeat_ngram_size': 4,
            # 'repetition_penalty': 1.1,
            # 'length_penalty': 1,
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        """Load the parameter dict from the JSON file at ``self.params_url``.

        Returns:
            The decoded JSON content.

        Raises:
            FileNotFoundError: if the parameter file does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        if not os.path.exists(self.params_url):
            logger.error(f"params_url:{self.params_url} is not exists.")
            # Fail explicitly here instead of falling through to open(),
            # which would raise the same error type with less context.
            raise FileNotFoundError(f"params file not found: {self.params_url}")
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        # The context manager ensures the file handle is closed.
        with open(self.params_url, encoding="utf-8") as params_file:
            return json.load(params_file)

    @exception_reconnect
    def inference(self, query_text):
        """Send ``query_text`` to the BBT service and return its reply.

        The text is prefixed with a fixed Chinese instruction meaning
        "complete the following method body:".

        Args:
            query_text: the source text to be completed.

        Returns:
            The ``msg`` field of the JSON reply when the service answers with
            HTTP 200. Returns None when the request fails or the status is not
            200. A 200 reply that is not JSON, or that lacks ``msg``, raises.
            Any retry behaviour comes from ``exception_reconnect``, which is
            defined elsewhere.
        """
        query_text = "补全以下方法体:\n" + query_text
        try:
            data = {
                'content': query_text,
                'user_id': self.user_id,
                'config': json.dumps(self.modelConfig),
            }
            # Pass the configured timeout. Previously it was read but never
            # used, so an unresponsive server could block the caller forever.
            resp = requests.post(self.url, data=data, timeout=self.timeout)
        except Exception:
            # Best-effort contract: log with traceback and signal failure via None.
            logger.exception("BBT request to %s failed", self.url)
            return None

        if resp.status_code != 200:
            logger.warning("BBT request returned HTTP %s", resp.status_code)
            return None
        resp_json = json.loads(resp.text)
        return resp_json["msg"]