# chatglm2_inference.py (2.5 KB)
# Committed by CSDN-Ada助手.
import os
import json
import torch
import logging
from transformers import AutoModel, AutoTokenizer
from inference import Inference

logger = logging.getLogger(__name__)


class GLMInference(Inference):
    """ChatGLM text-generation backend built on Hugging Face ``transformers``.

    Settings (temperature, sampling limits, device, quantization, model path)
    are read from a JSON parameter file, then overlaid with the base class's
    ``paras_base_dict``.
    """

    # Speaker tag that precedes the model's reply in the decoded output.
    _REPLY_TAG = "ChatCSDN:"

    def __init__(self):
        """Load parameters, then the model and its tokenizer.

        This was previously misspelled ``__int__``. As a result it never ran
        when the object was constructed; Python only calls ``__int__`` from
        ``int(obj)``.
        """
        # Run the base initializer first: before the rename it was the only
        # initializer that executed, and it presumably provides
        # ``paras_base_dict`` — TODO confirm against ``inference.Inference``.
        super().__init__()
        # NOTE(review): a GLM backend reading "chatgpt.json" looks like a
        # copy-paste leftover; kept unchanged — confirm the intended file.
        self.params_url = "llm_set/params/chatgpt.json"
        self.paras_dict = self.get_params()
        # Base-class values win over the per-model file on key collisions.
        self.paras_dict.update(self.paras_base_dict)

        self.temperature = self.paras_dict.get("temperature")
        self.quantize = self.paras_dict.get("quantize")
        self.device = self.paras_dict.get("device")
        self.model_path = self.paras_dict.get("model_path")
        self.max_length = self.paras_dict.get("max_length")
        self.min_length = self.paras_dict.get("min_length")
        self.top_p = self.paras_dict.get("top_p")
        self.top_k = self.paras_dict.get("top_k")

        self.model, self.DEV = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)

    def _load_model(self):
        """Load the model in eval mode on the configured device; return ``(model, torch.device)``."""
        if self.device == "cpu":
            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True,
                                              ignore_mismatched_sizes=True).float()
            return model.eval(), torch.device("cpu")

        # "cuda:N" -> expose only physical GPU N to this process. A bare
        # "cuda" (no index) used to raise IndexError; now it leaves the
        # variable alone. Setting it only has an effect if CUDA has not been
        # initialised in this process yet.
        _, _, gpu_index = self.device.partition(":")
        if gpu_index:
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index

        model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        if self.quantize:
            # Honor an explicit 4/8 from the config; any other truthy value
            # (e.g. JSON ``true``) keeps the historical 8-bit default.
            bits = self.quantize if self.quantize in (4, 8) else 8
            model = model.quantize(bits)

        # ``.cuda()`` and ``torch.device("cuda")`` both resolve to the current
        # CUDA device, so the model and the inputs always land together.
        # After masking, the chosen GPU is logical index 0, so the former
        # ``torch.device(self.device)`` named a non-existent ordinal for N > 0.
        return model.cuda().eval(), torch.device("cuda")

    def get_params(self):
        """Read and parse the JSON parameter file at ``self.params_url``.

        Returns:
            The decoded JSON value. The caller uses ``.update``/``.get`` on
            it, so it presumably is a JSON object.

        Raises:
            FileNotFoundError: if the file does not exist (logged first).
            json.JSONDecodeError: if the file is not valid JSON.
        """
        if not os.path.exists(self.params_url):
            logger.error(f"params_url:{self.params_url} is not exists.")
            # Previously execution fell through to open() and failed there;
            # raise explicitly so the error names the missing path.
            raise FileNotFoundError(f"params file not found: {self.params_url}")
        # The context manager closes the handle (the bare open() leaked it).
        # JSON text is UTF-8, so do not depend on the locale's encoding.
        with open(self.params_url, encoding="utf-8") as fh:
            return json.load(fh)

    def inference(self, message):
        """Generate a sampled reply to *message*.

        Args:
            message: prompt text. The decoded output (prompt plus generated
                tokens) is expected to contain ``"ChatCSDN:"`` before the
                reply, presumably because the prompt ends with that tag.

        Returns:
            str: the stripped text between the first ``"ChatCSDN:"`` tag and
            the next one, if any. If the tag is absent, a warning is logged
            and the stripped decoding of only the newly generated tokens is
            returned; this previously raised ``IndexError``.
        """
        input_ids = self.tokenizer.encode(message, return_tensors="pt").to(self.DEV)

        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids,
                do_sample=True,
                min_length=self.min_length,
                max_length=self.max_length,
                top_p=self.top_p,
                top_k=self.top_k,
                temperature=self.temperature,
            )

        output_ids = generation_output[0].tolist()
        output = self.tokenizer.decode(output_ids)
        segments = output.split(self._REPLY_TAG)
        if len(segments) > 1:
            # split (not partition) ends the reply at a second tag, i.e. when
            # the model starts writing another conversational turn.
            return segments[1].strip()

        logger.warning("reply tag %r not found in model output; returning generated text only",
                       self._REPLY_TAG)
        # The returned sequence starts with the prompt, so drop those tokens.
        return self.tokenizer.decode(output_ids[input_ids.shape[-1]:]).strip()