From 5f0fb2fdd4cac0f72038b7a4df8fb2095355c624 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Wed, 8 Mar 2023 15:06:17 +0800 Subject: [PATCH] add reward model --- forward_demo.py | 235 +++++++++++++++ src/rlhf/__init__.py | 3 + src/rlhf/lora.py | 34 +++ src/rlhf/optimizer.py | 39 +++ src/rlhf/palm.py | 530 +++++++++++++++++++++++++++++++++ src/rlhf/ppo.py | 660 ++++++++++++++++++++++++++++++++++++++++++ src/rlhf/reward.py | 124 ++++++++ src/rlhf/utils.py | 67 +++++ train_rm.py | 24 ++ 9 files changed, 1716 insertions(+) create mode 100644 forward_demo.py create mode 100644 src/rlhf/__init__.py create mode 100644 src/rlhf/lora.py create mode 100644 src/rlhf/optimizer.py create mode 100644 src/rlhf/palm.py create mode 100644 src/rlhf/ppo.py create mode 100644 src/rlhf/reward.py create mode 100644 src/rlhf/utils.py create mode 100644 train_rm.py diff --git a/forward_demo.py b/forward_demo.py new file mode 100644 index 0000000..1608ff5 --- /dev/null +++ b/forward_demo.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/3/1 11:54 +# @Author : clong +# @File : train_sft.py + + + +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +if __name__ == "__main__": + from argparse import ArgumentParser + from pytorch_lightning import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + + rank_zero_info("########## work in progress ##########") + + ######################################################################################################## + # + # example: train a simple L12-D768 RWKV on dummy data + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "" --data_type "dummy" --vocab_size 0 \ + # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \ + # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \ + # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: train a simple L6-D512 RWKV from scratch on enwik8 + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \ + # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \ + # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 
16G) NOTE: fp16 might overflow + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \ + # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1 + + parser = ArgumentParser() + + parser.add_argument("--load_model", default="", type=str) # full path, with .pth + parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb + parser.add_argument("--proj_dir", default="out", type=str) + parser.add_argument("--random_seed", default="-1", type=int) + + parser.add_argument("--data_file", default="", type=str) + parser.add_argument("--data_type", default="utf-8", type=str) + parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + + parser.add_argument("--ctx_len", default=1024, type=int) + parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps + parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final + parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" + + parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) + parser.add_argument("--n_layer", default=6, type=int) + parser.add_argument("--n_embd", default=512, type=int) + parser.add_argument("--dim_att", default=0, type=int) + parser.add_argument("--dim_ffn", default=0, type=int) + parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument("--head_qk", default=0, type=int) # my headQK trick + parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim + parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer + + parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument("--lr_final", default=1e-5, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence + parser.add_argument("--adam_eps", default=1e-8, type=float) + + parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode + parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift + parser.add_argument("--my_pile_edecay", default=0, type=int) + parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 
200 seems enough + # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) + + parser.add_argument("--my_img_version", default=0, type=str) + parser.add_argument("--my_img_size", default=0, type=int) + parser.add_argument("--my_img_bit", default=0, type=int) + parser.add_argument("--my_img_clip", default='x', type=str) + parser.add_argument("--my_img_clip_scale", default=1, type=float) + parser.add_argument("--my_img_l1_scale", default=0, type=float) + parser.add_argument("--my_img_encoder", default='x', type=str) + # parser.add_argument("--my_img_noise_scale", default=0, type=float) + parser.add_argument("--my_sample_len", default=0, type=int) + parser.add_argument("--my_ffn_shift", default=1, type=int) + parser.add_argument("--my_att_shift", default=1, type=int) + parser.add_argument("--my_pos_emb", default=0, type=int) + parser.add_argument("--load_partial", default=0, type=int) + parser.add_argument("--magic_prime", default=0, type=int) + parser.add_argument("--my_qa_mask", default=0, type=int) + parser.add_argument("--my_testing", default='', type=str) + + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args() + + ######################################################################################################## + + import os, warnings, math, datetime, sys, time + import numpy as np + import torch + from torch.utils.data import DataLoader + import deepspeed + import pytorch_lightning as pl + from pytorch_lightning import seed_everything + + if args.random_seed >= 0: + print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) + seed_everything(args.random_seed) + + np.set_printoptions(precision=4, suppress=True, linewidth=200) + warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") + warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") + # os.environ["WDS_SHOW_SEED"] = "1" + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + args.enable_checkpointing = False + args.replace_sampler_ddp = False + args.logger = False + args.gradient_clip_val = 1.0 + args.num_sanity_val_steps = 0 + args.check_val_every_n_epoch = int(1e20) + args.log_every_n_steps = int(1e20) + args.max_epochs = -1 # continue forever + args.betas = (args.beta1, args.beta2) + args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz + os.environ["RWKV_T_MAX"] = str(args.ctx_len) + os.environ["RWKV_MY_TESTING"] = args.my_testing + if args.dim_att <= 0: + args.dim_att = args.n_embd + if args.dim_ffn <= 0: + args.dim_ffn = args.n_embd * 4 + + args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + if not os.path.exists(args.proj_dir): + os.makedirs(args.proj_dir) + + samples_per_epoch = args.epoch_steps * args.real_bsz + tokens_per_epoch = samples_per_epoch * args.ctx_len + rank_zero_info( + f""" +############################################################################ +# +# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} +# +# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} +# +# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every 
{args.epoch_save} epoch +# +# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens +# +# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len +# +# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} +# +# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer +# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) +# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer +# +############################################################################ +""" + ) + rank_zero_info(str(vars(args)) + "\n") + + assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] + + if args.lr_final == 0 or args.lr_init == 0: + rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") + + assert args.precision in ["fp32", "tf32", "fp16", "bf16"] + os.environ["RWKV_FLOAT_MODE"] = args.precision + if args.precision == "fp32": + rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") + if args.precision == "fp16": + rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") + + os.environ["RWKV_JIT_ON"] = "1" + if "deepspeed_stage_3" in args.strategy: + os.environ["RWKV_JIT_ON"] = "0" + + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + if args.precision == "fp32": + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + else: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + if "32" in args.precision: + args.precision = 32 + elif args.precision == "fp16": + args.precision = 16 + else: + args.precision = "bf16" + + ######################################################################################################## + + from src.trainer import train_callback, generate_init_weight + + args.vocab_size = 20000 + + from src.model import RWKV + model = RWKV(args) + + seq = torch.randint(0, 20000, (1, 100)) + model(seq) + + import ipdb + ipdb.set_trace() + diff --git a/src/rlhf/__init__.py b/src/rlhf/__init__.py new file mode 100644 index 0000000..3a5b596 --- /dev/null +++ b/src/rlhf/__init__.py @@ -0,0 +1,3 @@ +# from palm_rlhf_pytorch.palm import PaLM +# from palm_rlhf_pytorch.reward import RewardModel +# from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic diff --git a/src/rlhf/lora.py b/src/rlhf/lora.py new file mode 100644 index 0000000..ea4af3f --- /dev/null +++ b/src/rlhf/lora.py @@ -0,0 +1,34 @@ +import torch +from torch import nn + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# LoRA - https://arxiv.org/abs/2106.09685 + +class LoRA(nn.Module): + def __init__( + self, + dim, + dim_out, + r = 8, + alpha = None + ): + super().__init__() + alpha = default(alpha, r) + self.scale = alpha / r + + self.A = nn.Parameter(torch.randn(dim, r)) + self.B = nn.Parameter(torch.zeros(r, dim_out)) + + @property + def weight(self): + return (self.A @ self.B) * self.scale + + def forward(self, x): + return x @ self.weight diff --git a/src/rlhf/optimizer.py b/src/rlhf/optimizer.py new file mode 100644 index 0000000..95e739c --- /dev/null +++ b/src/rlhf/optimizer.py @@ -0,0 +1,39 @@ +from torch.optim import AdamW, Adam +from lion_pytorch import Lion + +def 
separate_weight_decayable_params(params): + wd_params, no_wd_params = [], [] + for param in params: + param_list = no_wd_params if param.ndim < 2 else wd_params + param_list.append(param) + return wd_params, no_wd_params + +def get_optimizer( + params, + lr = 1e-4, + wd = 1e-2, + betas = (0.9, 0.99), + eps = 1e-8, + filter_by_requires_grad = False, + group_wd_params = True, + use_lion = True, + **kwargs +): + if filter_by_requires_grad: + params = list(filter(lambda t: t.requires_grad, params)) + + if group_wd_params and wd > 0: + wd_params, no_wd_params = separate_weight_decayable_params(params) + + params = [ + {'params': wd_params}, + {'params': no_wd_params, 'weight_decay': 0}, + ] + + if use_lion: + return Lion(params, lr = lr, betas = betas, weight_decay = wd) + + if wd == 0: + return Adam(params, lr = lr, betas = betas, eps = eps) + + return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) diff --git a/src/rlhf/palm.py b/src/rlhf/palm.py new file mode 100644 index 0000000..ba9836a --- /dev/null +++ b/src/rlhf/palm.py @@ -0,0 +1,530 @@ +import math +import copy +from pathlib import Path +from collections import namedtuple +from itertools import zip_longest + +from tqdm import tqdm +from beartype import beartype +from beartype.typing import Tuple, Optional + +import torch +from torch import einsum, nn +import torch.nn.functional as F + +from einops import rearrange, repeat, reduce, pack, unpack +from einops.layers.torch import Rearrange, Reduce + +from src.rlhf.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator +from src.rlhf.lora import LoRA + +# functions and decorators + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def identity(t, *args, **kwargs): + return t + +def l2norm(t): + return F.normalize(t, dim = -1) + +# normalization +# they use layernorm without bias, something that pytorch does not offer + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + +# residual + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + y = self.fn(x, **kwargs) + + if not any([t.requires_grad for t in (x, y)]): + return x.add_(y) + + return y + x + +# rotary positional embedding w/ xpos +# https://arxiv.org/abs/2104.09864 +# https://arxiv.org/abs/2212.10554v1 + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, scale_base = 512, use_xpos = True): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + self.use_xpos = use_xpos + self.scale_base = scale_base + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.register_buffer('scale', scale) + + def forward(self, seq_len, device): + t = torch.arange(seq_len, device = device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if not self.use_xpos: + return freqs, torch.ones(1, device = device) + + power = (t - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, 
t, scale = 1.): + return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale) + + +# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward +# https://arxiv.org/abs/2002.05202 + + +class SwiGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +# parallel attention and feedforward with residual +# discovered by Wang et al + EleutherAI from GPT-J fame + + +class ParallelTransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_head = 64, + causal = True, + heads = 8, + qk_rmsnorm = False, + qk_scale = 8, + ff_mult = 4, + attn_dropout = 0., + ff_dropout = 0., + use_xpos = True, + xpos_scale_base = 512 + ): + super().__init__() + self.norm = LayerNorm(dim) + + attn_inner_dim = dim_head * heads + ff_inner_dim = dim * ff_mult + self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) + + self.qk_rmsnorm = qk_rmsnorm + + if qk_rmsnorm: + self.q_scale = nn.Parameter(torch.ones(dim_head)) + self.k_scale = nn.Parameter(torch.ones(dim_head)) + + self.heads = heads + self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale + self.causal = causal + + self.rotary_emb = RotaryEmbedding(dim_head, scale_base = xpos_scale_base, use_xpos = use_xpos and causal) + + self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) + + self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) + self.attn_dropout = nn.Dropout(attn_dropout) + + # parallel feedforward tail + + self.ff_out = nn.Sequential( + SwiGLU(), + nn.Dropout(ff_dropout), + nn.Linear(ff_inner_dim, dim, bias=False) + ) + + # for caching causal mask and rotary embeddings + + self.register_buffer("mask", None, persistent=False) + self.register_buffer("pos_emb", None, persistent=False) + self.register_buffer("pos_emb_scale", None, persistent=False) + + def get_mask(self, n, device): + if exists(self.mask) and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def get_rotary_embedding(self, n, device): + if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n], self.pos_emb_scale[:n] + + pos_emb, scale = self.rotary_emb(n, device=device) + self.register_buffer("pos_emb", pos_emb, persistent=False) + self.register_buffer("pos_emb_scale", scale, persistent=False) + return pos_emb, scale + + def forward( + self, + x, + mask = None, + finetune_modules = None + ): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device, h = x.shape[1], x.device, self.heads + + # pre layernorm + + x = self.norm(x) + + # attention queries, keys, values, and feedforward inner + + q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) + + # finetune loras + + lora_q = lora_k = lora_v = lora_o = None + + if exists(finetune_modules): + lora_q, lora_k, lora_v, lora_o = finetune_modules + q = q + lora_q(x) + k = k + lora_k(x) + v = v + lora_v(x) + + # split heads + # they use multi-query single-key-value attention, yet another Noam Shazeer paper + # they found no performance loss past a certain scale, and more efficient decoding obviously + # https://arxiv.org/abs/1911.02150 + + q = rearrange(q, "b n (h d) -> b h n d", h=h) + + # qk rmsnorm + + if self.qk_rmsnorm: + q, k = map(l2norm, (q, k)) + q = q * self.q_scale + k = k * 
self.k_scale + + # rotary embeddings with xpos decay for better length extrapolation + + positions, scale = self.get_rotary_embedding(n, device) + + q = apply_rotary_pos_emb(positions, q, scale) + k = apply_rotary_pos_emb(positions, k, scale ** -1) + + # similarity + + sim = einsum("b h i d, b j d -> b h i j", q, k) * self.scale + + # key padding mask + + if exists(mask): + mask = rearrange(mask, 'b j -> b 1 1 j') + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # causal mask + + if self.causal: + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum("b h i j, b j d -> b h i d", attn, v) + + # merge heads + + out = rearrange(out, "b h n d -> b n (h d)") + + attn_out = self.attn_out(out) + + ff_out = self.ff_out(ff) + + if exists(lora_o): + attn_out = attn_out + lora_o(out) + + return attn_out + ff_out + +# transformer + +@beartype +class PaLM(nn.Module): + def __init__( + self, + *, + dim, + num_tokens, + depth, + causal = True, + dim_head = 64, + heads = 8, + ff_mult = 4, + attn_dropout = 0., + ff_dropout = 0., + qk_rmsnorm = False, + lora_r = 8, + rotary_xpos_scale_base = 512, + finetune_scopes = tuple(), + cross_entropy_ignore_index = 0 + ): + super().__init__() + self.dim = dim + self.dim_head = dim_head + self.heads = heads + self.causal = causal + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, dim) + self.layers = nn.ModuleList([]) + + for _ in range(depth): + block = Residual(ParallelTransformerBlock( + dim = dim, + causal = causal, + dim_head = dim_head, + heads = heads, + qk_rmsnorm = qk_rmsnorm, + ff_mult = ff_mult, + attn_dropout = attn_dropout, + ff_dropout = ff_dropout, + xpos_scale_base = rotary_xpos_scale_base + )) + + self.layers.append(block) + + self.norm = LayerNorm(dim) + self.to_logits = nn.Linear(dim, num_tokens, bias=False) + + self.to_logits.weight = self.token_emb.weight + + nn.init.normal_(self.token_emb.weight, std=0.02) + + # fine tuning related + + self.lora_r = lora_r + self.finetune_modules = nn.ModuleDict({}) + + for scope in finetune_scopes: + self.add_finetune_params(scope) + + # loss related + + self.cross_entropy_ignore_index = cross_entropy_ignore_index + + @property + def device(self): + return next(self.parameters()).device + + def load(self, path): + path = Path(path) + assert path.exists() + self.load_state_dict(torch.load(str(path))) + + def set_dropout(self, dropout): + for module in self.layers.modules(): + if isinstance(module, nn.Dropout): + module.p = dropout + return self + + def add_finetune_params(self, scope, lora_r = None): + assert scope not in self.finetune_modules, f'finetune scope {scope} already found' + dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device + + q_inner_dim = heads * dim_head + kv_inner_dim = dim_head + + lora_modules = nn.ModuleList([]) + + for _ in range(len(self.layers)): + lora_modules.append(nn.ModuleList([ + LoRA(dim, q_inner_dim, r = r), # queries + LoRA(dim, kv_inner_dim, r = r), # keys + LoRA(dim, kv_inner_dim, r = r), # values + LoRA(q_inner_dim, dim, r = r) # wo + ])) + + self.finetune_modules[scope] = lora_modules.to(device) + + def remove_finetune_params(self, scope): + assert scope in self.finetune_modules, f'finetune scope {scope} not found' + return self.finetune_modules.pop(scope) + + @torch.no_grad() + def 
merge_finetune_params(self, scope): + """ in the case one wants to merge the fine-tuned actor LORA parameters and do multiple rounds of fine tuning off different reward models """ + + assert scope in self.finetune_modules, f'finetune scope {scope} not found' + + lora_modules = self.finetune_modules.pop(scope) + + for layer, (lora_q, lora_k, lora_v, lora_o) in zip(self.layers, lora_modules): + block = layer.fn + + fused_attn_ff_weight = block.fused_attn_ff_proj.weight + attn_out_weight = block.attn_out.weight + + fused_proj_out_dim = fused_attn_ff_weight.shape[0] + + lora_qkv_weight, _ = pack([lora_q.weight, lora_k.weight, lora_v.weight], 'i *') + lora_qkv_weight = F.pad(lora_qkv_weight, (0, fused_proj_out_dim - lora_qkv_weight.shape[1])) + + lora_qkv_weight = rearrange(lora_qkv_weight, 'i o -> o i') + lora_o_weight = rearrange(lora_o.weight, 'i o -> o i') + + fused_attn_ff_weight.add_(lora_qkv_weight) + attn_out_weight.add_(lora_o_weight) + + # researcher train palm parameters first + # before finetuning + + def palm_parameters(self): + return set(self.parameters()) - set(self.finetune_modules.parameters()) + + def finetune_parameters(self, scope = 'default'): + assert scope in self.finetune_modules, f'finetune parameters of scope {scope} not found' + return self.finetune_modules[scope].parameters() + + # generate function + + @torch.no_grad() + @eval_decorator + def generate( + self, + seq_len, + prompt = None, + temperature = 1., + filter_logits_fn = top_k, + filter_thres = 0.9, + pad_value = 0., + eos_token = None, + return_seq_without_prompt = True, + use_tqdm = False, + **kwargs + ): + if not exists(prompt): + prompt = torch.randint(0, self.num_tokens, (1, 1)) + prompt = prompt.to(self.device) + return_seq_without_prompt = False + + prompt, leading_dims = pack([prompt], '* n') + + n, out = prompt.shape[-1], prompt.clone() + + wrapper_fn = identity if not use_tqdm else tqdm + sample_num_times = max(1, seq_len - prompt.shape[-1]) + + for _ in wrapper_fn(range(sample_num_times)): + logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs) + logits, embeds = logits[:, -1], embeds[:, -1] + + if exists(filter_logits_fn): + logits = filter_logits_fn(logits, thres = filter_thres) + + sample = gumbel_sample(logits, temperature = temperature, dim = -1) + out, _ = pack([out, sample], 'b *') + + if exists(eos_token): + is_eos_tokens = (out == eos_token) + + if is_eos_tokens.any(dim = -1).all(): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 + out = out.masked_fill(mask, pad_value) + break + + out, = unpack(out, leading_dims, '* n') + + if not return_seq_without_prompt: + return out + + return out[..., n:] + + def forward( + self, + x, + return_loss = False, + disable_lora = False, + finetune_scope = None, + extra_embed = None, + return_only_embedding = False, + return_logits_with_embedding = False + ): + if return_loss: + x, labels = x[:, :-1], x[:, 1:] + + # mask if encoder + # treat any token ids that are negative as tokens to mask out - only needed if not autoregressive + + if not self.causal: + mask = x >= 0 + x = x.masked_fill(~mask, 0) + else: + mask = None + + # get token embedding + x = self.token_emb(x) + + if exists(extra_embed): + x = x + extra_embed + + # finetune modules + + finetune_modules = tuple() + if exists(finetune_scope) and not disable_lora: + assert finetune_scope in self.finetune_modules + finetune_modules = 
self.finetune_modules[finetune_scope] + + # parallel attention / ff blocks, passing in finetuning loras + + for layer, finetune_modules in zip_longest(self.layers, finetune_modules): + x = layer(x, mask = mask, finetune_modules = finetune_modules) + + # final norm + + embeds = self.norm(x) + + if return_only_embedding: + return embeds + + # to logits + + logits = self.to_logits(x) + + ret = (logits, embeds) if return_logits_with_embedding else logits + + if not return_loss: + return ret + + logits = rearrange(logits, 'b n c -> b c n') + return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index) diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py new file mode 100644 index 0000000..e853f43 --- /dev/null +++ b/src/rlhf/ppo.py @@ -0,0 +1,660 @@ +import math +from pathlib import Path +import copy +from tqdm import tqdm +from functools import partial +from collections import deque, namedtuple +from random import randrange + +from beartype import beartype +from beartype.typing import List, Optional, Callable, Deque + +import torch +from torch import nn +import torch.nn.functional as F + +from torch.optim import Adam +from torch.utils.data import Dataset, DataLoader +from torch.nn.utils.rnn import pad_sequence + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from palm_rlhf_pytorch.palm import PaLM +from palm_rlhf_pytorch.reward import RewardModel +from palm_rlhf_pytorch.optimizer import get_optimizer +from palm_rlhf_pytorch.utils import masked_mean, eval_decorator + +from accelerate import Accelerator + +# actor critic - PaLM with lora + +PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [ + 'actions', + 'sequence', + 'mask', + 'prompt_mask', + 'action_logits', + 'values' +]) + +@beartype +class ActorCritic(nn.Module): + def __init__( + self, + palm: PaLM, + critic_palm: Optional[PaLM] = None, + pooled_values = False, + actor_lora = True, + critic_lora = True, + actor_lora_r = 8, + critic_lora_r = 8, + actor_lora_scope = 'actor', + critic_lora_scope = 'critic', + actor_dropout = 0., + critic_dropout = 0. + ): + super().__init__() + self.actor_palm = palm + + self.critic_palm = critic_palm + + if not exists(self.critic_palm): + self.critic_palm = copy.deepcopy(palm) + + self.actor_palm.set_dropout(actor_dropout) + self.critic_palm.set_dropout(critic_dropout) + + self.actor_lora = actor_lora + self.critic_lora = critic_lora + + self.actor_lora_scope = actor_lora_scope if actor_lora else None + self.critic_lora_scope = critic_lora_scope if critic_lora else None + + if self.actor_lora: + self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r) + + if self.critic_lora: + self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r) + + self.pooled_values = pooled_values + self.value_head = nn.Sequential( + nn.Linear(palm.dim, 1), + Rearrange('... 
1 -> ...') + ) + + nn.init.zeros_(self.value_head[0].bias) + nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2)) + + def actor_parameters(self): + if not self.actor_lora: + return self.actor_palm.parameters() + + return [ + *self.actor_palm.finetune_parameters(self.actor_lora_scope) + ] + + def critic_parameters(self): + if not self.actor_lora: + return [*self.critic_palm.parameters(), *self.value_head.parameters()] + + return [ + *self.critic_palm.finetune_parameters(self.critic_lora_scope), + *self.value_head.parameters() + ] + + @torch.no_grad() + @eval_decorator + def generate( + self, + state, + max_seq_len, + eos_token = None, + return_values = False, + **kwargs + ): + actions = self.actor_palm.generate( + max_seq_len, + prompt = state, + eos_token = eos_token, + finetune_scope = self.actor_lora_scope, + use_tqdm = True, + **kwargs + ) + + sequence = torch.cat((state, actions), dim = -1) + action_len = actions.shape[-1] + state_len = state.shape[-1] + + prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len + prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0]) + + action_mask = ~prompt_mask + + mask = None + if exists(eos_token): + mask = ((sequence == eos_token).cumsum(dim = -1) == 0) + mask = F.pad(mask, (1, -1), value = True) # include eos token + action_mask &= mask + + action_logits, value = self.forward( + sequence, + mask = action_mask, + return_values = return_values + ) + + return PPOActionCriticReturn( + actions, + sequence, + mask, + prompt_mask, + action_logits, + value + ) + + def forward( + self, + x, + mask = None, + return_values = True + ): + action_logits = self.actor_palm( + x, + finetune_scope = self.actor_lora_scope + ) + + if not return_values: + return action_logits, None + + critic_embeds = self.critic_palm( + x, + return_only_embedding = True, + finetune_scope = self.critic_lora_scope + ) + + if self.pooled_values: + critic_embeds = shift(critic_embeds, shift = 1, dim = -2) + critic_embeds = masked_mean(critic_embeds, mask, dim = 1) + + values = self.value_head(critic_embeds) + + return action_logits, values + +# data + +Memory = namedtuple('Memory', [ + 'sequence', + 'prompt_mask', + 'mask', + 'action_prob', + 'action_log_prob', + 'reward', + 'value' +]) + +@beartype +class ExperienceDataset(Dataset): + def __init__( + self, + data: List[torch.Tensor], + device = None + ): + super().__init__() + self.data = data + self.device = device + + def __len__(self): + return self.data[0].shape[0] + + def __getitem__(self, ind): + return tuple(map(lambda t: t[ind].to(self.device), self.data)) + +def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs): + ds = ExperienceDataset(data, device = device) + return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs) + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def masked_normalize(t, eps = 1e-5, mask = None, dim = None): + dim = default(dim, tuple(range(t.ndim))) + kwargs = dict(dim = dim, keepdim = True) + + mean = masked_mean(t, mask = mask, **kwargs) + mean_centered = t - mean + var = masked_mean(mean_centered ** 2, mask = mask, **kwargs) + + return mean_centered * var.clamp(min = eps).rsqrt() + +def pad_sequence_fixed(sequences, *args, **kwargs): + first_el = sequences[0] + has_no_dimension = first_el.ndim == 0 + + # if no dimensions, add a single dimension + if has_no_dimension: + sequences = tuple(map(lambda t: t[None], sequences)) + + 
out = pad_sequence(sequences, *args, **kwargs) + + if has_no_dimension: + out = rearrange(out, '... 1 -> ...') + + return out + +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +def log_prob(prob, indices): + assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match' + return log(prob.gather(-1, indices[..., None])).squeeze(-1) + +def shift(t, value = 0, shift = 1, dim = -1): + zeros = (0, 0) * (-dim - 1) + return F.pad(t, (*zeros, shift, -shift), value = value) + +def masked_entropy(prob, dim = -1, mask = None): + entropies = (prob * log(prob)).sum(dim = -1) + return masked_mean(entropies, mask = mask).mean() + +def masked_kl_div(prob1, prob2, mask = None): + """ + need to account for variable sequence lengths, therefore not using the built-in functional version + """ + kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1) + + if not exists(mask): + return kl_divs.mean() + + return masked_mean(kl_divs, mask).mean() + +def clipped_value_loss(values, rewards, old_values, clip): + value_clipped = old_values + (values - old_values).clamp(-clip, clip) + value_loss_1 = (value_clipped.flatten() - rewards) ** 2 + value_loss_2 = (values.flatten() - rewards) ** 2 + return torch.mean(torch.max(value_loss_1, value_loss_2)) + +# rlhf trainer + +@beartype +class RLHFTrainer(nn.Module): + def __init__( + self, + *, + prompts: Optional[List[str]] = None, + prompts_path: Optional[str] = None, + prompt_token_ids: Optional[torch.Tensor] = None, + tokenizer: Callable = None, + palm: PaLM, + reward_model: RewardModel, + actor_critic: Optional[ActorCritic] = None, + actor_lr = 1e-4, + critic_lr = 1e-4, + actor_wd = 0., + critic_wd = 0., + actor_adam_eps = 1e-7, + critic_adam_eps = 1e-7, + actor_lora = True, + critic_lora = True, + actor_lora_r = 8, + critic_lora_r = 8, + critic_pooled_values = True, + actor_dropout = 0., + critic_dropout = 0., + betas = (0.9, 0.999), + max_norm = None, + eps_clip = 0.2, + value_clip = 0.4, + beta_s = .01, + pad_value = 0., + minibatch_size = 16, + epochs = 1, + kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is + accelerate_kwargs: dict = {}, + use_lion = False + ): + super().__init__() + + self.accelerate = Accelerator(**accelerate_kwargs) + + # take care of prompts -> token ids + + assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1 + + if exists(prompts_path): + path = Path(prompts_path) + prompts = path.read_text().split('\n') + + if exists(prompts): + assert len(prompts) > 0, 'no prompts' + assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given' + prompt_token_ids = tokenizer(prompts) + + self.pad_value = pad_value # token pad value + self.num_prompts = prompt_token_ids.shape[0] + self.register_buffer('prompt_token_ids', prompt_token_ids) + + # models + + self.palm = palm + + if not exists(actor_critic): + actor_critic = ActorCritic( + palm = palm, + actor_lora = actor_lora, + critic_lora = critic_lora, + actor_lora_r = actor_lora_r, + critic_lora_r = critic_lora_r, + pooled_values = critic_pooled_values, + actor_dropout = actor_dropout, + critic_dropout = critic_dropout + ).to(palm.device) + + self.actor_critic = actor_critic + + self.reward_model = reward_model.eval() + + # train hyperparameters + + self.epochs = epochs + self.minibatch_size = minibatch_size + self.max_norm = max_norm + + self.kl_div_loss_weight = kl_div_loss_weight + + # optimizers + + 
self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = actor_lr, wd = actor_wd, betas = betas, eps = actor_adam_eps, use_lion = use_lion) + self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = critic_lr, wd = critic_wd, betas = betas, eps = critic_adam_eps, use_lion = use_lion) + + # ppo hyperparams + + self.eps_clip = eps_clip + self.value_clip = value_clip + self.beta_s = beta_s + + # prepare with accelerator + + ( + self.actor_critic, + self.reward_model, + self.actor_optim, + self.critic_optim + ) = self.accelerate.prepare( + self.actor_critic, + self.reward_model, + self.actor_optim, + self.critic_optim + ) + + + def print(self, msg): + return self.accelerate.print(msg) + + def save(self, filepath = './checkpoint.pt'): + torch.save(self.actor_critic.state_dict(), filepath) + + def load(self, filepath = './checkpoint.pt'): + state_dict = torch.load(filepath) + self.actor_critic.load_state_dict(state_dict) + + @property + def device(self): + return self.accelerate.device + + @torch.no_grad() + def generate( + self, + max_seq_len, + *args, + prompt, + num_samples = 4, # sample 4 per prompt and select the one with highest reward + **kwargs + ): + assert prompt.ndim == 1, 'only one prompt allowed at a time for now' + prompt = repeat(prompt, 'n -> b n', b = num_samples) + + actor_critic = self.accelerate.unwrap_model(self.actor_critic) + reward_model = self.accelerate.unwrap_model(self.reward_model) + + actor_critic.eval() + + ( + actions, + sequences, + mask, + prompt_mask, + action_logits, + _ + ) = actor_critic.generate( + prompt, + *args, + max_seq_len = max_seq_len, + return_values = False, + **kwargs + ) + + rewards = reward_model( + sequences, + prompt_mask = prompt_mask, + mask = mask, + sample = True + ) + + best_sequence_index = rewards.topk(1, dim = -1).indices + + best_sequence = sequences[best_sequence_index] + best_sequence = rearrange(best_sequence, '1 ... -> ...') + + return best_sequence + + def learn( + self, + memories: Deque[Memory] + ): + # stack all data stored in the memories + + all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories))) + + # prepare dataloader for policy phase training + + dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device) + + self.actor_critic.train() + + # PPO training + + for _ in range(self.epochs): + for ( + sequences, + prompt_masks, + masks, + old_action_probs, + old_log_probs, + rewards, + old_values + ) in dl: + action_masks = ~prompt_masks & masks + + action_logits, values = self.actor_critic( + sequences, + mask = action_masks + ) + + action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token + action_len = old_log_probs.shape[-1] + + action_probs = action_logits.softmax(dim = -1) + action_log_probs = log_prob(action_probs, sequences) + action_log_probs = action_log_probs[:, -action_len:] + + # calculate entropies, taking into account which part of the sequence is actually an action + + entropies = masked_entropy(action_probs, mask = action_masks) + + # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not + + kl_div_loss = 0. 
+ + if self.kl_div_loss_weight > 0: + kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight + + # handle non-pooled values + + normalize_kwargs = dict() + + if old_values.ndim == 2: + old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values)) + + old_values = old_values[:, -action_len:] + values = values[:, -action_len:] + rewards = rearrange(rewards, 'b -> b 1') + normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:]) + + if values.ndim < rewards.ndim: + values = rearrange(values, '... -> ... 1') + + # calculate clipped surrogate objective, classic PPO loss + + ratios = (action_log_probs - old_log_probs).exp() + advantages = masked_normalize(rewards - old_values, **normalize_kwargs) + + if advantages.ndim == 1: + advantages = rearrange(advantages, 'b -> b 1') + + surr1 = ratios * advantages + surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages + policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies + + # combine losses + + loss = policy_loss.mean() + kl_div_loss + + # update actor + + self.accelerate.backward(loss) + + self.print(f'policy_loss: {loss.item():.3f}') + + if exists(self.max_norm): + self.accelerator.clip_grad_norm_(self.actor_critic.actor_parameters(), self.max_norm) + + self.actor_optim.step() + self.actor_optim.zero_grad() + + # calculate value loss and update value network separate from policy network + + value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip) + value_loss = value_loss.mean() + + self.print(f'critic_loss: {value_loss.item():.3f}') + + self.accelerate.backward(value_loss) + + if exists(self.max_norm): + self.accelerator.clip_grad_norm_(self.actor_critic.critic_parameters(), self.max_norm) + + self.critic_optim.step() + self.critic_optim.zero_grad() + + def train( + self, + num_episodes = 50000, + max_timesteps = 500, + update_timesteps = 5000, + max_batch_size = 16, + max_seq_len = 2048, + eos_token = None, + temperature = 1. + ): + device = self.device + + time = 0 + memories = deque([]) + + for eps in tqdm(range(num_episodes), desc = 'episodes'): + for timestep in range(max_timesteps): + time += 1 + + # select a bunch of random states (prompts) + # and get the action (sampled sequence from palm as well as the action probs) + # also calculate the reward using reward model and store + + rand_prompt_index = randrange(0, self.num_prompts) + + state = self.prompt_token_ids[rand_prompt_index] + + # remove padding from state + + state_mask = state != self.pad_value + state = state[state_mask] + + # get predicted sequence + + ( + actions, + sequence, + mask, + prompt_mask, + action_logits, + value + ) = self.actor_critic.generate( + rearrange(state, 'n -> 1 n'), + max_seq_len = max_seq_len, + eos_token = eos_token, + temperature = temperature, + return_values = True + ) + action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token + + action_prob = action_logits.softmax(dim = -1) + + action_len = actions.shape[-1] + action_log_prob = log_prob(action_prob, sequence) + action_log_prob = action_log_prob[:, -action_len:] + + actions = rearrange(actions, '1 ... 
-> ...')
+
+                # get reward as given by supervised trained reward model
+
+                sequence = torch.cat((state, actions), dim = 0)
+
+                prompt_length = len(state)
+                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
+
+                sequence = rearrange(sequence, 'n -> 1 n')
+                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
+                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
+
+                reward = self.reward_model(
+                    sequence,
+                    prompt_mask = prompt_mask,
+                    mask = mask,
+                    sample = True
+                )
+
+                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')
+
+                # store memory for learning
+
+                memories.append(Memory(*map(detach_to_cpu_, (
+                    sequence,
+                    prompt_mask,
+                    mask,
+                    action_prob,
+                    action_log_prob,
+                    reward,
+                    value
+                ))))
+
+                # learn from the stored memories
+
+                if time % update_timesteps == 0:
+                    self.learn(memories)
+                    memories.clear()
+
+        print('rlhf training complete')
diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
new file mode 100644
index 0000000..b264e2d
--- /dev/null
+++ b/src/rlhf/reward.py
@@ -0,0 +1,124 @@
+import copy
+from pathlib import Path
+
+from tqdm import tqdm
+from beartype import beartype
+from beartype.typing import Tuple, Optional
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat, reduce, pack, unpack
+from einops.layers.torch import Rearrange, Reduce
+
+from src.rlhf.utils import masked_mean, gumbel_sample
+from src.model import RWKV
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+# Reward Model - RWKV with a scalar head
+
+@beartype
+class RewardModel(nn.Module):
+    def __init__(
+        self,
+        rwkv: RWKV,
+        dropout = 0.1,
+        num_binned_output = 0.
+    ):
+        super().__init__()
+
+        # initialize the reward model from the pretrained RWKV model
+        self.rwkv = copy.deepcopy(rwkv)
+        self.rwkv.set_dropout(dropout) # todo(luxin)
+
+        # dimension of the output token embeddings
+        dim = rwkv.dim # todo(luxin)
+
+        # number of score levels; with 5, the levels are [0, 1, 2, 3, 4]
+        self.binned_output = num_binned_output > 1
+
+        # todo(luxin): prompt_embed and response_embed are both initialized to all zeros? shouldn't they be distinguished?
+        self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
+        self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
+        # self.response_embed = nn.Parameter(torch.ones(1, 1, dim))
+
+        if self.binned_output:
+            # if the number of score levels is greater than 1, treat scoring as multi-class classification
+            self.to_pred = nn.Linear(dim, num_binned_output)
+        else:
+            # otherwise, predict a single scalar score directly
+            self.to_pred = nn.Sequential(
+                nn.Linear(dim, 1, bias = False),
+                Rearrange('... 1 -> ...') # drop the trailing singleton dim
+            )
+
+    def load(self, path):
+        path = Path(path)
+        assert path.exists()
+        self.load_state_dict(torch.load(str(path)))
+
+    def finetune_parameters(self):
+        return [
+            *self.to_pred.parameters(),
+            *self.rwkv.parameters()
+        ]
+
+    def forward(
+        self,
+        x,
+        mask = None,
+        prompt_mask = None,
+        prompt_lengths = None,
+        labels = None,
+        sample = False,
+        sample_temperature = 1.
+    ):
+
+        # only one of prompt_mask and prompt_lengths may be given
+        assert not (exists(prompt_mask) and exists(prompt_lengths))
+
+        # derive prompt mask from prompt lengths
+        if exists(prompt_lengths):
+            batch, seq_len = x.shape
+            arange = torch.arange(seq_len, device = x.device)
+            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')
+
+        # reward model should have an understanding of which section is prompt, and which section is response
+        # based on each token's True / False value in prompt_mask, pick from prompt_embed or response_embed:
+        # True selects prompt_embed, False selects response_embed
+        extra_embed = None
+        if exists(prompt_mask):
+            extra_embed = torch.where(
+                rearrange(prompt_mask, 'b n -> b n 1'),
+                self.prompt_embed,
+                self.response_embed
+            )
+
+        # todo(luxin) get embeddings from rwkv
+        embeds = self.rwkv(
+            x,
+            extra_embed = extra_embed,
+            return_only_embedding = True
+        )
+
+        # average all token embeddings and feed the pooled vector to the scoring head
+        pooled = masked_mean(embeds, mask, dim = 1)
+        pred = self.to_pred(pooled)
+
+        if sample and self.binned_output:
+            assert not exists(labels)
+            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)
+
+        if not exists(labels):
+            return pred
+
+        # todo(luxin): the original author trains on a per-sample loss instead of the pairwise (two-sample) loss from the paper
+        if not self.binned_output:
+            return F.mse_loss(pred, labels)
+
+        return F.cross_entropy(pred, labels)
diff --git a/src/rlhf/utils.py b/src/rlhf/utils.py
new file mode 100644
index 0000000..8f632bd
--- /dev/null
+++ b/src/rlhf/utils.py
@@ -0,0 +1,67 @@
+import math
+import torch
+from torch import einsum, nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+def exists(val):
+    return val is not None
+
+# decorators
+
+def eval_decorator(fn):
+    def inner(self, *args, **kwargs):
+        was_training = self.training
+        self.eval()
+        out = fn(self, *args, **kwargs)
+        self.train(was_training)
+        return out
+    return inner
+
+# tensor helpers
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def masked_mean(seq, mask = None, dim = 1, keepdim = False):
+    if not exists(mask):
+        return seq.mean(dim = dim)
+
+    if seq.ndim == 3:
+        mask = rearrange(mask, 'b n -> b n 1')
+
+    masked_seq = seq.masked_fill(~mask, 0.)
+    numer = masked_seq.sum(dim = dim, keepdim = keepdim)
+    denom = mask.sum(dim = dim, keepdim = keepdim)
+
+    masked_mean = numer / denom.clamp(min = 1e-3)
+    masked_mean = masked_mean.masked_fill(denom == 0, 0.)
+    return masked_mean
+
+# sampling helpers
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., dim = -1):
+    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
+
+def top_p(logits, thres = 0.9):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+    sorted_indices_to_remove = cum_probs > (1 - thres)
+    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+    sorted_indices_to_remove[:, 0] = 0
+
+    sorted_logits[sorted_indices_to_remove] = float('-inf')
+    return sorted_logits.scatter(1, sorted_indices, sorted_logits)
+
+def top_k(logits, thres = 0.9):
+    k = math.ceil((1 - thres) * logits.shape[-1])
+    val, ind = torch.topk(logits, k)
+    probs = torch.full_like(logits, float('-inf'))
+    probs.scatter_(1, ind, val)
+    return probs
diff --git a/train_rm.py b/train_rm.py
new file mode 100644
index 0000000..84caa16
--- /dev/null
+++ b/train_rm.py
@@ -0,0 +1,24 @@
+import torch
+
+from src.rlhf.reward import RewardModel
+from src.model import RWKV
+
+rwkv_model = RWKV()
+
+reward_model = RewardModel(
+    rwkv_model,
+    num_binned_output = 5 # number of score levels; with 5, the levels are [0, 1, 2, 3, 4]
+)
+
+# mock data
+seq = torch.randint(0, 20000, (1, 100))
+# prompt_mask = torch.zeros(1, 100).bool() # which part of the sequence is prompt, which part is response
+prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
+labels = torch.randint(0, 5, (1,))
+
+# train
+loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
+loss.backward()
+
+# after much training
+reward = reward_model(seq, prompt_mask = prompt_mask)
\ No newline at end of file
--
GitLab
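
As the todo(luxin) comment in src/rlhf/reward.py notes, this RewardModel is trained with a per-sample loss (MSE for the scalar head, cross-entropy for the binned head) rather than the pairwise preference loss used in InstructGPT-style RLHF. A minimal sketch of that pairwise objective on top of the scalar head (num_binned_output = 0) follows; the function name and the chosen/rejected tensors are illustrative assumptions, not part of this patch.

# Hypothetical sketch (not in this patch): pairwise preference loss
# -log(sigmoid(r_chosen - r_rejected)), reusing the scalar-head RewardModel above.
import torch.nn.functional as F

def pairwise_reward_loss(reward_model, chosen_seq, rejected_seq, prompt_mask):
    # with no labels passed, RewardModel returns one scalar score per sequence
    chosen_reward = reward_model(chosen_seq, prompt_mask = prompt_mask)
    rejected_reward = reward_model(rejected_seq, prompt_mask = prompt_mask)
    # push the human-preferred response to score higher than the rejected one
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

With the same mock shapes as train_rm.py (two (1, 100) token tensors sharing the 50-token prompt), loss = pairwise_reward_loss(reward_model, chosen_seq, rejected_seq, prompt_mask) followed by loss.backward() would train the head on preference pairs.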
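
Separately, with num_binned_output = 5 as in train_rm.py, the final call reward_model(seq, prompt_mask = prompt_mask) returns the 5 bin logits rather than a single scalar (passing sample = True would instead gumbel-sample a discrete level). A small helper like the hedged sketch below, also not part of this patch, could collapse those logits into one expected score per sequence for ranking candidates:

# Hypothetical helper (not in this patch): expected score from binned logits
import torch

def expected_score(bin_logits):
    # bin_logits: (batch, num_bins) raw outputs of the binned RewardModel head
    probs = bin_logits.softmax(dim = -1)
    levels = torch.arange(bin_logits.shape[-1], device = bin_logits.device, dtype = probs.dtype)
    return (probs * levels).sum(dim = -1)  # (batch,) expected level in [0, num_bins - 1]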