import os
import json
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .inference import Inference

logger = logging.getLogger(__name__)


class BaiChuanInference(Inference):
    """Inference wrapper around a locally stored Baichuan causal-LM checkpoint."""

    def __init__(self):
        super().__init__()
        self.params_url = "../llm_set/params/baichuan.json"
        self.paras_dict = self.get_params()
        self.paras_dict.update(self.paras_base_dict)

        self.temperature = self.paras_dict.get("temperature")
        self.quantize = self.paras_dict.get("quantize")
        self.device = self.paras_dict.get("device")
        self.model_path = self.paras_dict.get("model_path")
        self.DEV = torch.device(self.device)

        # trust_remote_code is required because Baichuan ships its own modeling code.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )

        self.max_length = self.paras_dict.get("max_length")
        self.min_length = self.paras_dict.get("min_length")
        self.top_p = self.paras_dict.get("top_p")
        self.top_k = self.paras_dict.get("top_k")

    def get_params(self):
        """Load the model-specific generation parameters from the JSON config file."""
        if not os.path.exists(self.params_url):
            # Previously this only logged and then tried to open the missing file;
            # return an empty dict instead of crashing with FileNotFoundError.
            logger.error(f"params_url: {self.params_url} does not exist.")
            return {}
        with open(self.params_url, encoding="utf-8") as f:
            return json.load(f)

    def inference(self, message):
        """Generate a completion for `message` using sampling-based decoding."""
        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids,
                do_sample=True,
                min_length=self.min_length,
                max_length=self.max_length,
                top_p=self.top_p,
                top_k=self.top_k,
                temperature=self.temperature,
            )
        # decode() accepts the token tensor directly; the per-element .item()
        # round-trip is unnecessary. skip_special_tokens drops EOS/BOS markers.
        output = self.tokenizer.decode(generation_output[0], skip_special_tokens=True)
        return output.strip()
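

# --- Usage sketch (illustration only, not part of the original module) ---
# The params file at ../llm_set/params/baichuan.json is assumed to carry the
# keys read in __init__ above; the values below are placeholders, not the
# project's actual configuration:
#
#   {
#       "temperature": 0.8,
#       "quantize": false,
#       "device": "cuda:0",
#       "model_path": "baichuan-inc/Baichuan-13B-Chat",
#       "max_length": 2048,
#       "min_length": 10,
#       "top_p": 0.85,
#       "top_k": 50
#   }
#
# A minimal driver, assuming the package layout makes this module importable
# and a CUDA device is available (import path is hypothetical):
#
#   from llm_inference.baichuan import BaiChuanInference
#
#   engine = BaiChuanInference()
#   print(engine.inference("Hello, who are you?"))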