# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from .model import ModelArgs, Transformer
from .tokenizer import Tokenizer

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """A single chat turn: who spoke and what was said."""

    role: Role
    content: str


class CompletionPrediction(TypedDict, total=False):
    """Result of a text completion for one prompt."""

    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    """Result of a chat completion for one dialog."""

    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


Dialog = List[Message]

# Delimiters of the chat prompt format: [INST] ... [/INST] wraps a user turn,
# <<SYS>> ... <</SYS>> wraps the system prompt inside the first user turn.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# NOTE(review): the paragraph break before "If a question" is restored from the
# published Llama 2 default prompt -- confirm against the reference copy.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""


class Llama:
    """Bundles a loaded ``Transformer`` with its ``Tokenizer`` for generation."""

    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
    ) -> "Llama":
        """Load a checkpoint and tokenizer and return a ready ``Llama``.

        Args:
            ckpt_dir: Directory holding the ``*.pth`` shards and ``params.json``.
            tokenizer_path: Path passed to ``Tokenizer(model_path=...)``.
            max_seq_len: Forwarded to ``ModelArgs``.
            max_batch_size: Forwarded to ``ModelArgs``.
            model_parallel_size: Number of model-parallel ranks. Defaults to
                the ``WORLD_SIZE`` environment variable (1 when unset).

        Raises:
            AssertionError: If no checkpoint is found, or the number of
                checkpoint shards differs from ``model_parallel_size``.
        """
        if not torch.distributed.is_initialized():
            # Only fall back to a local rendezvous when the launcher has not
            # already exported one; never clobber user-provided settings.
            os.environ.setdefault("MASTER_ADDR", "localhost")
            os.environ.setdefault("MASTER_PORT", "5678")
            torch.distributed.init_process_group(
                "nccl",
                rank=int(os.environ.get("RANK", 0)),
                world_size=int(os.environ.get("WORLD_SIZE", 1)),
            )

        # Resolve the size up front so the shard-count check below is valid
        # even when model parallelism was initialized by the caller.
        if model_parallel_size is None:
            model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
        if not model_parallel_is_initialized():
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(1)

        if local_rank > 0:
            # Silence every rank except the first; the handle is intentionally
            # kept open for the lifetime of the process.
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.load(f)

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer)

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        """Store the model and tokenizer used by the generation methods."""
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """Generate continuations for a batch of tokenized prompts.

        Args:
            prompt_tokens: One token-id list per prompt.
            max_gen_len: Maximum number of tokens generated per prompt.
            temperature: Sampling temperature; ``0`` (or less) selects the
                arg-max token instead of sampling.
            top_p: Nucleus threshold handed to ``sample_top_p``.
            logprobs: When true, also return per-token log-probabilities.
            echo: When true, the prompt tokens are included in the output.

        Returns:
            ``(tokens, logprobs)``: one token list per prompt, truncated at
            the first EOS token, and the matching log-probability lists, or
            ``None`` in their place when ``logprobs`` is false.

        Raises:
            AssertionError: If the batch is larger than ``max_batch_size`` or
                a prompt is longer than ``max_seq_len``.
        """
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        # (bsz, total_len) buffer: prompts left-aligned, remainder is padding.
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        # True wherever a position holds a prompt token rather than padding.
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            # Only the tokens added since the previous step are fed forward.
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if logprobs:
                # assumes logits are (batch, seq, vocab) -- TODO confirm
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            # An EOS only counts when it was generated, not read from a prompt.
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if eos_reached.all():
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        eos_id = self.tokenizer.eos_id
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            prompt_len = len(prompt_tokens[i])
            start = 0 if echo else prompt_len
            end = prompt_len + max_gen_len
            toks = toks[start:end]
            # Always bind the per-sequence value so it is never undefined and
            # never picks up the sampling tensor left over from the loop above.
            seq_logprobs = token_logprobs[i][start:end] if logprobs else None
            # cut to eos tok if any
            if eos_id in toks:
                eos_idx = toks.index(eos_id)
                toks = toks[:eos_idx]
                if seq_logprobs is not None:
                    seq_logprobs = seq_logprobs[:eos_idx]
            out_tokens.append(toks)
            out_logprobs.append(seq_logprobs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """Complete each text prompt and return the decoded generations.

        Args:
            prompts: Raw prompt strings; each is encoded with BOS and no EOS.
            temperature: Forwarded to ``generate``.
            top_p: Forwarded to ``generate``.
            max_gen_len: Generation budget; defaults to ``max_seq_len - 1``.
            logprobs: When true, each result also carries ``tokens`` and
                ``logprobs``.
            echo: Forwarded to ``generate``.

        Returns:
            One ``CompletionPrediction`` per prompt, in input order.
        """
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """Generate the next assistant message for each dialog.

        A dialog that does not start with a system message gets
        ``DEFAULT_SYSTEM_PROMPT`` prepended. The system message is then folded
        into the first user message between ``B_SYS`` and ``E_SYS``.

        Args:
            dialogs: Dialogs to continue; after the optional system message
                the roles must alternate user/assistant and end on a user turn.
            temperature: Forwarded to ``generate``.
            top_p: Forwarded to ``generate``.
            max_gen_len: Generation budget; defaults to ``max_seq_len - 1``.
            logprobs: When true, each result also carries ``tokens`` and
                ``logprobs``.

        Returns:
            One ``ChatPrediction`` per dialog, in input order.

        Raises:
            AssertionError: If the roles do not alternate as required or the
                last message is not from the user.
        """
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = []
        for dialog in dialogs:
            if dialog[0]["role"] != "system":
                dialog = [
                    {
                        "role": "system",
                        "content": DEFAULT_SYSTEM_PROMPT,
                    }
                ] + dialog
            # Merge the system prompt into the first user message.
            dialog = [
                {
                    "role": dialog[1]["role"],
                    "content": B_SYS
                    + dialog[0]["content"]
                    + E_SYS
                    + dialog[1]["content"],
                }
            ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            # Each completed (user, assistant) exchange is encoded on its own,
            # with both BOS and EOS, and appended in order.
            dialog_tokens: List[int] = []
            for prompt, answer in zip(dialog[::2], dialog[1::2]):
                dialog_tokens.extend(
                    self.tokenizer.encode(
                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                        bos=True,
                        eos=True,
                    )
                )
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            # The pending user turn is left open (no EOS) for the model to answer.
            dialog_tokens += self.tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
                bos=True,
                eos=False,
            )
            prompt_tokens.append(dialog_tokens)

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t),
                    },
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [
            {"generation": {"role": "assistant", "content": self.tokenizer.decode(t)}}
            for t in generation_tokens
        ]


def sample_top_p(probs, p):
    """Sample one index per row from the top-p (nucleus) part of ``probs``.

    Entries are sorted in descending order along the last dimension; an entry
    is zeroed when the cumulative mass of the entries before it already
    exceeds ``p``. The survivors are renormalized and sampled from.

    Args:
        probs: Probabilities along the last dimension.
        p: Cumulative-probability threshold.

    Returns:
        Sampled indices into the last dimension of ``probs``, with a trailing
        dimension of size 1.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Subtracting the entry itself yields the mass strictly before it, so the
    # highest-probability entry is always kept.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map positions in the sorted order back to the original indices.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token