#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2023/7/26 10:28
# @Author  : clong
# @File    : llama2_inference.py


import os
import json
import torch
import logging
from .llama import Llama
from .inference import Inference

logger = logging.getLogger(__name__)


class LLAMA2Inference(Inference):
    """Text-completion wrapper around a Llama 2 checkpoint.

    Generation settings are loaded from the JSON file at ``params_url`` and
    then merged with ``self.paras_base_dict`` (provided by the ``Inference``
    base class). On key clashes the base-class values win, because they are
    applied last with ``dict.update``.
    """

    def __init__(self):
        super().__init__()
        # Relative path: resolved against the current working directory,
        # not against this module's location.
        self.params_url = "../llm_set/params/llama2.json"
        self.paras_dict = self.get_params()
        # Base-class settings override values from llama2.json.
        self.paras_dict.update(self.paras_base_dict)

        # Missing keys yield None via dict.get.
        self.temperature = self.paras_dict.get("temperature")
        self.model_path = self.paras_dict.get("model_path")
        self.tokenizer_path = self.paras_dict.get("tokenizer_path")
        self.max_batch_size = self.paras_dict.get("max_batch_size")

        self.max_length = self.paras_dict.get("max_length")
        # min_length and top_k are stored but not used anywhere in this class.
        self.min_length = self.paras_dict.get("min_length")
        self.top_p = self.paras_dict.get("top_p")
        self.top_k = self.paras_dict.get("top_k")

        # max_length doubles as the model's maximum sequence length.
        self.model = Llama.build(
            ckpt_dir=self.model_path,
            tokenizer_path=self.tokenizer_path,
            max_seq_len=self.max_length,
            max_batch_size=self.max_batch_size,
        )

    def get_params(self):
        """Load the JSON parameter file at ``self.params_url``.

        Returns:
            The decoded JSON content (expected to be a dict, since the caller
            invokes ``.update`` and ``.get`` on it).

        Raises:
            FileNotFoundError: if ``self.params_url`` does not exist.
            json.JSONDecodeError: if the file does not contain valid JSON.
        """
        if not os.path.exists(self.params_url):
            logger.error(f"params_url:{self.params_url} is not exists.")
            # Fail explicitly instead of falling through to open(), which
            # would raise the same exception type with a less clear message.
            raise FileNotFoundError(
                f"LLaMA2 params file not found: {self.params_url}"
            )
        # JSON text is UTF-8 by spec; don't depend on the locale encoding,
        # and close the handle deterministically.
        with open(self.params_url, encoding="utf-8") as params_file:
            return json.load(params_file)

    def inference(self, message):
        """Run a single-prompt completion and return the generated text.

        Args:
            message: the prompt, passed as the only element of the batch
                (presumably a plain string -- confirm against callers).

        Returns:
            The ``'generation'`` field of the first (only) result returned by
            ``self.model.text_completion``.
        """
        # NOTE(review): max_gen_len reuses max_length, which is also the
        # max_seq_len given to Llama.build; if the model caps prompt+output
        # at max_seq_len (as the reference implementation does), long
        # prompts shrink the real generation budget -- confirm intent.
        results = self.model.text_completion(
            [message],
            max_gen_len=self.max_length,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        return results[0]['generation']
