diff --git a/README.md b/README.md
index 26351ecbf834a6783992d27fc5cf4c87c6e1baad..0e2d942776dec82b7aa5b116648664386e7296f5 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,7 @@ python evaluate_humaneval_x.py --language_type python
 | bbt-13B | 2.49% | 0% | 1.90% | 1.83% | 0.61% |
 | chatglm2-6B | 7.93% | 5.45% | 0.61% | 6.70% | 1.83% |
 | codegeex2-6B | 29.90% | 27.43% | 6.70% | 24.40% | 17.68% |
+| llama2-7B | 5.49% | 8.54% | 1.22% | 3.66% | 6.10% |
diff --git a/llm_set/params/llama2.json b/llm_set/params/llama2.json
new file mode 100644
index 0000000000000000000000000000000000000000..0642912bdc933a3ee0d9c226234fdb5c3f3fc7ec
--- /dev/null
+++ b/llm_set/params/llama2.json
@@ -0,0 +1,8 @@
+{
+    "model_path": "../llm_set/models/llama-2-7b/",
+    "tokenizer_path": "../llm_set/models/llama-2-7b/tokenizer.model",
+    "device": "cuda:0",
+    "max_length": 1024,
+    "min_length": 10,
+    "max_batch_size": 8
+}
\ No newline at end of file
diff --git a/src/evaluate_humaneval_x.py b/src/evaluate_humaneval_x.py
index f8f0cf8b02e1529e6b8f0f8d2768cb2ac979ca65..5ea019ab4b272626806597c94b641b72c175a6b9 100644
--- a/src/evaluate_humaneval_x.py
+++ b/src/evaluate_humaneval_x.py
@@ -43,6 +43,8 @@ def process_humaneval_test(sample, problems, example_test=False):
             continue
         if line and line[0] != ' ' and line[0] != '\t':
            line = "    " + line
+        if not code_ and len(line) > 4 and line[0] == " " and line[1] == " " and line[2] == " " and line[3] != " ":
+            line = " " + line
         code_.append(line)
     code = "\n".join(code_)
     test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
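For reference: the two lines added to `process_humaneval_test` re-indent a completion whose first kept line comes back with exactly three leading spaces, presumably a quirk of llama2 output; together with the existing rule they normalize everything to the four-space body indent the harness expects. A minimal sketch of the combined effect (illustrative only, not part of the patch):

```python
# Illustrative sketch of the two indentation rules above (not part of the patch).
def normalize(completion: str) -> str:
    code_ = []
    for line in completion.split("\n"):
        if line and line[0] != ' ' and line[0] != '\t':
            line = "    " + line   # unindented line -> four-space body indent
        if not code_ and len(line) > 4 and line[:3] == "   " and line[3] != " ":
            line = " " + line      # first line with exactly three spaces -> four
        code_.append(line)
    return "\n".join(code_)

assert normalize("   return x + 1") == "    return x + 1"
```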
diff --git a/src/generate_humaneval_x.py b/src/generate_humaneval_x.py
index e18598fbcbe99160bfe59f567f3085ba985922c3..9e47a37d549cc6b5fa6a30ebc383590d379661ec 100644
--- a/src/generate_humaneval_x.py
+++ b/src/generate_humaneval_x.py
@@ -58,6 +58,9 @@ if __name__ == "__main__":
     elif args.model_name == "codegeex2":
         from inference.codegeex2_inference import CodeGeex2Inference
         model_inference = CodeGeex2Inference()
+    elif args.model_name == "llama2":
+        from inference.llama2_inference import LLAMA2Inference
+        model_inference = LLAMA2Inference()
     else:
         print(f"{args.model_name} not supported.")
diff --git a/src/inference/llama/__init__.py b/src/inference/llama/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..354342dd9cded69ae63de34229cc0cb8d1345029
--- /dev/null
+++ b/src/inference/llama/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from .generation import Llama
+from .model import ModelArgs, Transformer
+from .tokenizer import Tokenizer
diff --git a/src/inference/llama/generation.py b/src/inference/llama/generation.py
new file mode 100755
index 0000000000000000000000000000000000000000..91cb7224e629795ad4d059e8694f938f7a2fc93b
--- /dev/null
+++ b/src/inference/llama/generation.py
@@ -0,0 +1,305 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import json
+import os
+import sys
+import time
+from pathlib import Path
+from typing import List, Literal, Optional, Tuple, TypedDict
+
+import torch
+import torch.nn.functional as F
+from fairscale.nn.model_parallel.initialize import (
+    get_model_parallel_rank,
+    initialize_model_parallel,
+    model_parallel_is_initialized,
+)
+
+from .model import ModelArgs, Transformer
+from .tokenizer import Tokenizer
+
+Role = Literal["system", "user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+class CompletionPrediction(TypedDict, total=False):
+    generation: str
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+class ChatPrediction(TypedDict, total=False):
+    generation: Message
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
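+
+# NOTE (added for clarity; not in Meta's original file): chat_completion() below
+# stitches these markers into the standard Llama 2 chat format. A single user
+# turn with the default system prompt is encoded (minus BOS/EOS) as:
+#   [INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]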
+
+
+class Llama:
+    @staticmethod
+    def build(
+        ckpt_dir: str,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        model_parallel_size: Optional[int] = None,
+    ) -> "Llama":
+        if not torch.distributed.is_initialized():
+            os.environ['MASTER_ADDR'] = 'localhost'
+            os.environ['MASTER_PORT'] = '5678'
+            torch.distributed.init_process_group("nccl", rank=0, world_size=int(os.environ.get("WORLD_SIZE", 1)))
+        if not model_parallel_is_initialized():
+            if model_parallel_size is None:
+                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
+            initialize_model_parallel(model_parallel_size)
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        torch.cuda.set_device(local_rank)
+
+        # seed must be the same in all processes
+        torch.manual_seed(1)
+
+        if local_rank > 0:
+            sys.stdout = open(os.devnull, "w")
+
+        start_time = time.time()
+        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
+        assert model_parallel_size == len(
+            checkpoints
+        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
+        ckpt_path = checkpoints[get_model_parallel_rank()]
+        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        with open(Path(ckpt_dir) / "params.json", "r") as f:
+            params = json.loads(f.read())
+
+        model_args: ModelArgs = ModelArgs(
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            **params,
+        )
+        tokenizer = Tokenizer(model_path=tokenizer_path)
+        model_args.vocab_size = tokenizer.n_words
+        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        model = Transformer(model_args)
+        model.load_state_dict(checkpoint, strict=False)
+        print(f"Loaded in {time.time() - start_time:.2f} seconds")
+
+        return Llama(model, tokenizer)
+
+    def __init__(self, model: Transformer, tokenizer: Tokenizer):
+        self.model = model
+        self.tokenizer = tokenizer
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        prompt_tokens: List[List[int]],
+        max_gen_len: int,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        logprobs: bool = False,
+        echo: bool = False,
+    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
+        params = self.model.params
+        bsz = len(prompt_tokens)
+        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+        min_prompt_len = min(len(t) for t in prompt_tokens)
+        max_prompt_len = max(len(t) for t in prompt_tokens)
+        assert max_prompt_len <= params.max_seq_len
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
+
+        pad_id = self.tokenizer.pad_id
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        for k, t in enumerate(prompt_tokens):
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        if logprobs:
+            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+
+        prev_pos = 0
+        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        input_text_mask = tokens != pad_id
+        for cur_pos in range(min_prompt_len, total_len):
+            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            if logprobs:
+                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
+                    input=logits.transpose(1, 2),
+                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
+                    reduction="none",
+                    ignore_index=pad_id,
+                )
+            if temperature > 0:
+                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+                next_token = sample_top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits[:, -1], dim=-1)
+
+            next_token = next_token.reshape(-1)
+            # only replace token if prompt has already been generated
+            next_token = torch.where(
+                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+            )
+            tokens[:, cur_pos] = next_token
+            eos_reached |= (~input_text_mask[:, cur_pos]) & (
+                next_token == self.tokenizer.eos_id
+            )
+            prev_pos = cur_pos
+            if all(eos_reached):
+                break
+
+        if logprobs:
+            token_logprobs = token_logprobs.tolist()
+        out_tokens, out_logprobs = [], []
+        for i, toks in enumerate(tokens.tolist()):
+            # cut to max gen len
+            start = 0 if echo else len(prompt_tokens[i])
+            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
+            probs = None
+            if logprobs:
+                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
+            # cut to eos tok if any
+            if self.tokenizer.eos_id in toks:
+                eos_idx = toks.index(self.tokenizer.eos_id)
+                toks = toks[:eos_idx]
+                probs = probs[:eos_idx] if logprobs else None
+            out_tokens.append(toks)
+            out_logprobs.append(probs)
+        return (out_tokens, out_logprobs if logprobs else None)
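+
+    # NOTE (added for clarity; not in Meta's original file): total_len above is
+    # capped at max_seq_len, so with max_length=1024 from llm_set/params/llama2.json
+    # a prompt of L tokens leaves at most 1024 - L tokens of generation, even
+    # though LLAMA2Inference also passes max_gen_len=1024.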
+
+    def text_completion(
+        self,
+        prompts: List[str],
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        logprobs: bool = False,
+        echo: bool = False,
+    ) -> List[CompletionPrediction]:
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+        generation_tokens, generation_logprobs = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+            echo=echo,
+        )
+        if logprobs:
+            return [
+                {
+                    "generation": self.tokenizer.decode(t),
+                    "tokens": [self.tokenizer.decode(x) for x in t],
+                    "logprobs": logprobs_i,
+                }
+                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
+            ]
+        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
+
+    def chat_completion(
+        self,
+        dialogs: List[Dialog],
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        logprobs: bool = False,
+    ) -> List[ChatPrediction]:
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = []
+        for dialog in dialogs:
+            if dialog[0]["role"] != "system":
+                dialog = [
+                    {
+                        "role": "system",
+                        "content": DEFAULT_SYSTEM_PROMPT,
+                    }
+                ] + dialog
+            dialog = [
+                {
+                    "role": dialog[1]["role"],
+                    "content": B_SYS
+                    + dialog[0]["content"]
+                    + E_SYS
+                    + dialog[1]["content"],
+                }
+            ] + dialog[2:]
+            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+                [msg["role"] == "assistant" for msg in dialog[1::2]]
+            ), (
+                "model only supports 'system', 'user' and 'assistant' roles, "
+                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
+            )
+            dialog_tokens: List[int] = sum(
+                [
+                    self.tokenizer.encode(
+                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                        bos=True,
+                        eos=True,
+                    )
+                    for prompt, answer in zip(
+                        dialog[::2],
+                        dialog[1::2],
+                    )
+                ],
+                [],
+            )
+            assert (
+                dialog[-1]["role"] == "user"
+            ), f"Last message must be from user, got {dialog[-1]['role']}"
+            dialog_tokens += self.tokenizer.encode(
+                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+                bos=True,
+                eos=False,
+            )
+            prompt_tokens.append(dialog_tokens)
+
+        generation_tokens, generation_logprobs = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+        )
+        if logprobs:
+            return [
+                {
+                    "generation": {
+                        "role": "assistant",
+                        "content": self.tokenizer.decode(t),
+                    },
+                    "tokens": [self.tokenizer.decode(x) for x in t],
+                    "logprobs": logprobs_i,
+                }
+                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
+            ]
+        return [
+            {"generation": {"role": "assistant", "content": self.tokenizer.decode(t)}}
+            for t in generation_tokens
+        ]
+
+
+def sample_top_p(probs, p):
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
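`sample_top_p` implements nucleus sampling: tokens are sorted by probability, every token whose *preceding* cumulative mass exceeds `p` is masked out, and the survivors are renormalized before `multinomial` draws one (e.g. sorted probs `[0.5, 0.3, 0.15, 0.05]` with `p = 0.9` drop only the last entry). With that in place, the completion API can be smoke-tested end to end; a minimal sketch, assuming a single CUDA device and the checkpoint layout from `llm_set/params/llama2.json`, run from `src/`:

```python
# Illustration only: build a single-shard Llama 2 and complete one prompt.
from inference.llama import Llama

llama = Llama.build(
    ckpt_dir="../llm_set/models/llama-2-7b/",
    tokenizer_path="../llm_set/models/llama-2-7b/tokenizer.model",
    max_seq_len=1024,
    max_batch_size=8,
)
out = llama.text_completion(["def fibonacci(n):"], max_gen_len=128, temperature=0.6, top_p=0.9)
print(out[0]["generation"])
```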
diff --git a/src/inference/llama/model.py b/src/inference/llama/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..258a7dc191cd1745a4ec0a73adeaa9df012e4bf6
--- /dev/null
+++ b/src/inference/llama/model.py
@@ -0,0 +1,288 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
+
+import fairscale.nn.model_parallel.initialize as fs_init
+import torch
+import torch.nn.functional as F
+from fairscale.nn.model_parallel.layers import (
+    ColumnParallelLinear,
+    ParallelEmbedding,
+    RowParallelLinear,
+)
+from torch import nn
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 4096
+    n_layers: int = 32
+    n_heads: int = 32
+    n_kv_heads: Optional[int] = None
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    ffn_dim_multiplier: Optional[float] = None
+    norm_eps: float = 1e-5
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    bs, slen, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )
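+
+
+# NOTE (added for clarity; not in Meta's original file): repeat_kv expands
+# (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim),
+# e.g. (2, 16, 8, 128) with n_rep=4 -> (2, 16, 32, 128). For llama-2-7b,
+# n_kv_heads is None, so n_rep == 1 and this is a no-op; it only matters for
+# grouped-query checkpoints such as the 70B model.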
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        model_parallel_size = fs_init.get_model_parallel_world_size()
+        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+        self.n_rep = self.n_local_heads // self.n_local_kv_heads
+        self.head_dim = args.dim // args.n_heads
+
+        self.wq = ColumnParallelLinear(
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wk = ColumnParallelLinear(
+            args.dim,
+            self.n_kv_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wv = ColumnParallelLinear(
+            args.dim,
+            self.n_kv_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wo = RowParallelLinear(
+            args.n_heads * self.head_dim,
+            args.dim,
+            bias=False,
+            input_is_parallel=True,
+            init_method=lambda x: x,
+        )
+
+        self.cache_k = torch.zeros(
+            (
+                args.max_batch_size,
+                args.max_seq_len,
+                self.n_local_kv_heads,
+                self.head_dim,
+            )
+        ).cuda()
+        self.cache_v = torch.zeros(
+            (
+                args.max_batch_size,
+                args.max_seq_len,
+                self.n_local_kv_heads,
+                self.head_dim,
+            )
+        ).cuda()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        self.cache_k = self.cache_k.to(xq)
+        self.cache_v = self.cache_v.to(xq)
+
+        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+        keys = self.cache_k[:bsz, : start_pos + seqlen]
+        values = self.cache_v[:bsz, : start_pos + seqlen]
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if mask is not None:
+            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+        self.w2 = RowParallelLinear(
+            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        )
+        self.w3 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
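+
+
+# NOTE (added for clarity; not in Meta's original file): with the 7B defaults
+# (dim=4096, ffn_dim_multiplier=None, multiple_of=256), FeedForward receives
+# hidden_dim = 4 * 4096 = 16384, shrinks it to int(2 * 16384 / 3) = 10922, and
+# rounds up to a multiple of 256, giving the checkpoint's FFN width of 11008.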
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.dim // args.n_heads
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(
+            dim=args.dim,
+            hidden_dim=4 * args.dim,
+            multiple_of=args.multiple_of,
+            ffn_dim_multiplier=args.ffn_dim_multiplier,
+        )
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        h = x + self.attention.forward(
+            self.attention_norm(x), start_pos, freqs_cis, mask
+        )
+        out = h + self.feed_forward.forward(self.ffn_norm(h))
+        return out
+
+
+class Transformer(nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+
+        self.tok_embeddings = ParallelEmbedding(
+            params.vocab_size, params.dim, init_method=lambda x: x
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers):
+            self.layers.append(TransformerBlock(layer_id, params))
+
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.output = ColumnParallelLinear(
+            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        )
+
+        self.freqs_cis = precompute_freqs_cis(
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        )
+
+    @torch.inference_mode()
+    def forward(self, tokens: torch.Tensor, start_pos: int):
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full(
+                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
+            )
+            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+        for layer in self.layers:
+            h = layer(h, start_pos, freqs_cis, mask)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
diff --git a/src/inference/llama/tokenizer.py b/src/inference/llama/tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..e3af011126ecdd4abcf41ad3035499796642483f
--- /dev/null
+++ b/src/inference/llama/tokenizer.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+from logging import getLogger
+from typing import List
+
+from sentencepiece import SentencePieceProcessor
+
+
+logger = getLogger()
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        return self.sp_model.decode(t)
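A quick check of the BOS/EOS plumbing (illustrative only; assumes the tokenizer model from `llm_set/params/llama2.json` and the standard Llama 2 IDs, BOS=1, EOS=2, pad=-1):

```python
from inference.llama import Tokenizer

tok = Tokenizer(model_path="../llm_set/models/llama-2-7b/tokenizer.model")
ids = tok.encode("Hello world", bos=True, eos=False)
print(ids[0])           # 1 -- BOS prepended, no EOS appended
print(tok.decode(ids))  # 'Hello world'
```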
diff --git a/src/inference/llama2_inference.py b/src/inference/llama2_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..39766bf6915b48a0dd7ee6756d7145e37c429931
--- /dev/null
+++ b/src/inference/llama2_inference.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time    : 2023/7/26 10:28
+# @Author  : clong
+# @File    : llama2_inference.py
+
+
+import os
+import json
+import torch
+import logging
+from .llama import Llama
+from .inference import Inference
+
+logger = logging.getLogger(__name__)
+
+
+class LLAMA2Inference(Inference):
+
+    def __init__(self):
+        super(LLAMA2Inference, self).__init__()
+        self.params_url = "../llm_set/params/llama2.json"
+        self.paras_dict = self.get_params()
+        self.paras_dict.update(self.paras_base_dict)
+
+        self.temperature = self.paras_dict.get("temperature")
+        self.model_path = self.paras_dict.get("model_path")
+        self.tokenizer_path = self.paras_dict.get("tokenizer_path")
+        self.max_batch_size = self.paras_dict.get("max_batch_size")
+
+        self.max_length = self.paras_dict.get("max_length")
+        self.min_length = self.paras_dict.get("min_length")
+        self.top_p = self.paras_dict.get("top_p")
+        self.top_k = self.paras_dict.get("top_k")
+
+        self.model = Llama.build(
+            ckpt_dir=self.model_path,
+            tokenizer_path=self.tokenizer_path,
+            max_seq_len=self.max_length,
+            max_batch_size=self.max_batch_size,
+        )
+
+    def get_params(self):
+        if not os.path.exists(self.params_url):
+            logger.error(f"params_url: {self.params_url} does not exist.")
+            raise FileNotFoundError(self.params_url)
+        with open(self.params_url) as f:
+            return json.loads(f.read())
+
+    def inference(self, message):
+        results = self.model.text_completion(
+            [message],
+            max_gen_len=self.max_length,
+            temperature=self.temperature,
+            top_p=self.top_p
+        )
+        return results[0]['generation']
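This wrapper is what `generate_humaneval_x.py --model_name llama2` instantiates. For completeness, a minimal sketch of direct use (illustrative only; run from `src/` so the relative `params_url` resolves, with the checkpoint layout from `llm_set/params/llama2.json`):

```python
# Illustration only: mirrors the llama2 branch added to generate_humaneval_x.py.
from inference.llama2_inference import LLAMA2Inference

model_inference = LLAMA2Inference()  # loads ../llm_set/params/llama2.json
print(model_inference.inference("def add(a, b):"))
```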