import json
import logging
import os
import traceback
from urllib.parse import quote

import requests

from utils import exception_reconnect

from .inference import Inference

logger = logging.getLogger(__name__)


class BBTInference(Inference):
    """Inference client that sends code-completion prompts to the BBT HTTP endpoint.

    Settings are loaded from a JSON parameter file and merged on top of
    ``self.paras_base_dict`` (presumably populated by ``Inference.__init__``
    -- verify against the base class).
    """

    def __init__(self):
        super().__init__()
        # NOTE: relative path, so it is resolved against the process's current
        # working directory, not against this module's location.
        self.params_url = "../llm_set/params/bbt.json"

        # Model-specific values override the base defaults; the merged base
        # dict then becomes the effective configuration.
        self.paras_dict = self.get_params()
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict

        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        self.timeout = self.paras_dict.get("timeout")

        # Generation options forwarded to the service (JSON-encoded per request).
        # Options left disabled in the original configuration:
        #   'top_k': 80, 'do_sample': True, 'num_beams': 4,
        #   'no_repeat_ngram_size': 4, 'repetition_penalty': 1.1,
        #   'length_penalty': 1
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        """Load and return the JSON parameter file at ``self.params_url``.

        Returns:
            The parsed JSON content (expected to be a dict of settings).

        Raises:
            FileNotFoundError: if the parameter file does not exist (an error
                is logged first).
            json.JSONDecodeError: if the file is not valid JSON.
        """
        if not os.path.exists(self.params_url):
            # Log for diagnosis; the open() below still raises so the caller
            # sees the failure exactly as before.
            logger.error("params_url:%s is not exists.", self.params_url)
        # Context manager guarantees the handle is closed; JSON is UTF-8 by spec.
        with open(self.params_url, encoding="utf-8") as params_file:
            return json.load(params_file)

    @exception_reconnect
    def inference(self, query_text):
        """Ask the service to complete the method body given in ``query_text``.

        Args:
            query_text: Source text of the method to complete; a fixed
                instruction prefix is prepended before sending.

        Returns:
            The ``"msg"`` field of the JSON response on HTTP 200, or ``None``
            when the request fails or the service answers with another status.
        """
        query_text = "补全以下方法体:\n" + query_text
        data = {
            'content': query_text,
            'user_id': self.user_id,
            'config': json.dumps(self.modelConfig),
        }
        try:
            # Honour the configured timeout so an unresponsive endpoint cannot
            # block forever (None keeps requests' default of no timeout).
            resp = requests.post(self.url, data=data, timeout=self.timeout)
        except Exception:
            # Best-effort call: record the full traceback and report failure.
            logger.exception("BBT inference request failed")
            return None

        if resp.status_code != 200:
            logger.error("BBT inference returned HTTP status %s", resp.status_code)
            return None

        resp_json = json.loads(resp.text)
        return resp_json["msg"]