baichuan_inference.py 2.2 KB
Newer Older
CSDN-Ada助手's avatar
CSDN-Ada助手 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60


import os
import json
import torch
import logging
from .inference import Inference
from transformers import AutoModelForCausalLM, AutoTokenizer


logger = logging.getLogger(__name__)


class BaiChuanInference(Inference):
    """Text-generation wrapper around a causal LM loaded with ``transformers``.

    Generation settings are read from the JSON file at ``params_url`` and then
    updated with ``self.paras_base_dict`` (presumably provided by the
    ``Inference`` base class -- confirm there), so on a key collision the base
    value wins.
    """

    def __init__(self):
        """Load the parameters, the tokenizer and the fp16 model."""
        super().__init__()
        self.params_url = "../llm_set/params/baichuan.json"
        self.paras_dict = self.get_params()
        # Base parameters are applied last, so they override the file values.
        self.paras_dict.update(self.paras_base_dict)

        self.temperature = self.paras_dict.get("temperature")
        self.quantize = self.paras_dict.get("quantize")
        self.device = self.paras_dict.get("device")
        self.model_path = self.paras_dict.get("model_path")
        # Device that the encoded prompt is moved to in ``inference``.
        self.DEV = torch.device(self.device)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )
        self.max_length = self.paras_dict.get("max_length")
        self.min_length = self.paras_dict.get("min_length")
        self.top_p = self.paras_dict.get("top_p")
        self.top_k = self.paras_dict.get("top_k")

    def get_params(self):
        """Read and parse the JSON parameter file at ``self.params_url``.

        Returns:
            The decoded JSON content (expected to be a dict, since the caller
            invokes ``update``/``get`` on it).

        Raises:
            FileNotFoundError: if the parameter file does not exist. The
                problem is logged before the exception propagates.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        try:
            # The context manager closes the handle even if parsing fails;
            # UTF-8 is requested explicitly instead of the locale default.
            with open(self.params_url, encoding="utf-8") as params_file:
                return json.load(params_file)
        except FileNotFoundError:
            logger.error(f"params_url:{self.params_url} is not exists.")
            raise

    def inference(self, message):
        """Generate text for ``message`` by sampling from the model.

        Args:
            message: Prompt text passed to the tokenizer.

        Returns:
            The decoded first generated sequence with surrounding whitespace
            stripped.
        """
        # NOTE(review): the model is loaded with device_map="auto" while the
        # prompt goes to the configured device -- confirm both always agree.
        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)

        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids,
                do_sample=True,
                min_length=self.min_length,
                max_length=self.max_length,
                top_p=self.top_p,
                top_k=self.top_k,
                temperature=self.temperature,
            )

        # Only the first sequence is decoded; tolist() yields plain ints.
        token_ids = generation_output[0].tolist()
        return self.tokenizer.decode(token_ids).strip()