import os
import json
import logging

import torch
from transformers import AutoModel, AutoTokenizer

from inference import Inference

logger = logging.getLogger(__name__)


class GLMInference(Inference):

    def __init__(self):
        self.params_url = "llm_set/params/chatgpt.json"
        self.paras_dict = self.get_params()
        self.paras_dict.update(self.paras_base_dict)

        self.temperature = self.paras_dict.get("temperature")
        self.quantize = self.paras_dict.get("quantize")
        self.device = self.paras_dict.get("device")
        self.model_path = self.paras_dict.get("model_path")

        if self.device == "cpu":
            # CPU inference needs full-precision (float32) weights.
            self.model = AutoModel.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                ignore_mismatched_sizes=True,
            ).float()
            self.DEV = torch.device("cpu")
        else:
            # Pin the process to the requested GPU, e.g. "cuda:0" -> "0".
            os.environ["CUDA_VISIBLE_DEVICES"] = self.device.split(":")[1]
            self.DEV = torch.device(self.device)
            if self.quantize:
                # 8-bit quantization to reduce GPU memory usage.
                self.model = AutoModel.from_pretrained(
                    self.model_path, trust_remote_code=True
                ).quantize(8).cuda()
            else:
                self.model = AutoModel.from_pretrained(
                    self.model_path, trust_remote_code=True
                ).cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.model = self.model.eval()

        self.max_length = self.paras_dict.get("max_length")
        self.min_length = self.paras_dict.get("min_length")
        self.top_p = self.paras_dict.get("top_p")
        self.top_k = self.paras_dict.get("top_k")

    def get_params(self):
        if not os.path.exists(self.params_url):
            logger.error(f"params_url: {self.params_url} does not exist.")
            raise FileNotFoundError(self.params_url)
        with open(self.params_url, encoding="utf-8") as f:
            return json.load(f)

    def inference(self, message):
        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids,
                do_sample=True,
                min_length=self.min_length,
                max_length=self.max_length,
                top_p=self.top_p,
                top_k=self.top_k,
                temperature=self.temperature,
            )
        output = self.tokenizer.decode(generation_output[0])
        # Keep only the model's reply that follows the "ChatCSDN:" marker;
        # fall back to the full output if the marker is absent.
        parts = output.split("ChatCSDN:")
        sentence = parts[1].strip() if len(parts) > 1 else output.strip()
        return sentence
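

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving GLMInference end to end. It assumes the
# params file at llm_set/params/chatgpt.json exists with the keys read
# above (temperature, quantize, device, model_path, max_length,
# min_length, top_p, top_k), and that the Inference base class supplies
# paras_base_dict. The prompt string is hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    glm = GLMInference()
    reply = glm.inference("What is the CSDN community?")
    print(reply)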