
import os
import json
import logging
from urllib.parse import quote
import requests
import traceback
from .inference import Inference
from utils import exception_reconnect


logger = logging.getLogger(__name__)


class BBTInference(Inference):
    """Inference client that sends prompts to the BBT model over HTTP.

    Settings are read from the JSON file at ``params_url`` and layered on
    top of ``self.paras_base_dict`` (expected to be provided by
    ``Inference.__init__`` -- TODO confirm against the base class).
    """

    def __init__(self):
        """Load the BBT parameters and cache the request settings."""
        super().__init__()
        # NOTE: relative path, so it is resolved against the current working
        # directory of the process, not against this module's location.
        self.params_url = "../llm_set/params/bbt.json"
        self.paras_dict = self.get_params()
        # File values override the base defaults; the base dict is updated in
        # place and then reused as the effective parameter dict.
        self.paras_base_dict.update(self.paras_dict)
        self.paras_dict = self.paras_base_dict
        self.url = self.paras_dict.get("url")
        self.user_id = self.paras_dict.get("user_id")
        self.timeout = self.paras_dict.get("timeout")
        # Generation options forwarded to the service with every request.
        # Further options (top_k, do_sample, num_beams, no_repeat_ngram_size,
        # repetition_penalty, length_penalty) are currently not sent.
        self.modelConfig = {
            'temperature': self.paras_dict.get("temperature"),
            'top_p': self.paras_dict.get("top_p"),
            'max_new_tokens': self.paras_dict.get("max_new_tokens"),
        }

    def get_params(self):
        """Read and parse the JSON parameter file at ``self.params_url``.

        Returns:
            The decoded JSON content of the file.

        Raises:
            FileNotFoundError: If the parameter file does not exist (the
                problem is logged before the exception propagates).
            json.JSONDecodeError: If the file is not valid JSON.
        """
        try:
            # Context manager guarantees the handle is closed; JSON files are
            # UTF-8, so do not depend on the platform default encoding.
            with open(self.params_url, encoding="utf-8") as params_file:
                return json.load(params_file)
        except FileNotFoundError:
            logger.error(f"params_url:{self.params_url} is not exists.")
            raise

    @exception_reconnect
    def inference(self, query_text):
        """Ask the BBT service to complete ``query_text``.

        Args:
            query_text: Source snippet to complete; an instruction prefix is
                prepended before it is sent.

        Returns:
            The ``"msg"`` field of the JSON response on HTTP 200, or ``None``
            when the request fails or a non-200 status is returned.

        Raises:
            ValueError: If a 200 response body is not valid JSON.
            KeyError: If the decoded response has no ``"msg"`` key.
        """
        # The prefix is Chinese for "Complete the following method body:".
        query_text = "补全以下方法体：\n" + query_text
        try:
            data = {
                'content': query_text,
                'user_id': self.user_id,
                'config': json.dumps(self.modelConfig),
            }
            # Honour the configured timeout so a stalled server cannot block
            # the caller forever (None keeps the previous no-timeout
            # behaviour).
            # NOTE(review): assumes "timeout" is expressed in seconds, as
            # requests expects -- confirm against the params file.
            resp = requests.post(self.url, data=data, timeout=self.timeout)
        except Exception as e:
            # Best-effort call: record the full traceback and report failure
            # to the caller through a None result.
            logger.exception("BBT inference request failed: %s", e)
            return None

        if resp.status_code != 200:
            logger.error("BBT service returned HTTP status %s", resp.status_code)
            return None

        # Parsing errors are intentionally left to propagate, as before.
        resp_json = json.loads(resp.text)
        return resp_json["msg"]
