########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, math, gc
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from tqdm import tqdm
from einops import pack
from einops import unpack

from src.rlhf.utils import exists
from src.rlhf.utils import gumbel_sample
from src.rlhf.utils import top_k
from src.rlhf.utils import identity

# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam

# RWKV_MY_TESTING holds single-letter experimental feature flags (e.g. 'a', 'g') that the model
# code tests with `in`. Echo it when set; otherwise default it to '' so those tests never raise.
try:
    print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
except KeyError:
    os.environ["RWKV_MY_TESTING"] = ''


def __nop(ob):
    """Identity decorator, used in place of ``torch.jit.script_method`` when JIT is off."""
    return ob


# Base class / method decorator for the model modules: plain nn.Module unless RWKV_JIT_ON == "1",
# in which case TorchScript modules and script methods are used. An unset variable means "JIT off".
MyModule = nn.Module
MyFunction = __nop
if os.environ.get("RWKV_JIT_ON", "0") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = int(os.environ["RWKV_T_MAX"])  # TAKES LOTS OF VRAM!
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load

# Compile/load the custom WKV CUDA kernel; T_MAX is baked in at compile time via -DTmax.
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])


class WKV(torch.autograd.Function):
    """Autograd wrapper around the ``wkv_cuda`` forward/backward kernels.

    Unless ``RWKV_FLOAT_MODE`` contains "32", the inputs are upcast to float32 before
    calling the kernel. The output is then cast back to fp16 or bf16 according to the
    same environment variable.
    """

    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 32) == 0
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return y
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return y.half()
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 32) == 0
        w, u, k, v = ctx.saved_tensors
        # The kernel writes per-batch-row gradients for w and u (shape (B, C)); they are summed
        # over the batch below because w and u are shared across the batch.
        gw = torch.zeros((B, C), device=gy.device).contiguous()
        gu = torch.zeros((B, C), device=gy.device).contiguous()
        gk = torch.zeros((B, T, C), device=gy.device).contiguous()
        gv = torch.zeros((B, T, C), device=gy.device).contiguous()
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        # B, T, C are plain ints and get no gradient.
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())


def RUN_CUDA(B, T, C, w, u, k, v):
    """Run the WKV kernel (with autograd support) on k/v using decay ``w`` and bonus ``u``."""
    return WKV.apply(B, T, C, w, u, k, v)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


class RWKV_TimeMix(MyModule):
    """RWKV time-mixing sub-layer.

    Each input position is interpolated with the previous timestep using the learned
    ``time_mix_*`` weights. The result is projected to key/value/receptance and combined
    through the WKV CUDA kernel, gated by ``sigmoid(receptance)``.

    When 'a' is in ``RWKV_MY_TESTING``, a small causal softmax-attention branch
    (``qq``/``kk``/``vv``/``oo``) is added to the output.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.ctx_len = args.ctx_len
        self.n_embd = args.n_embd

        with torch.no_grad():  # fancy init
            # max(..., 1) guards the single-layer model, which would otherwise divide by zero.
            ratio_0_to_1 = layer_id / max(args.n_layer - 1, 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for h in range(args.dim_att):
                decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
            self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)

            # fancy time_mix
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

        # Pads one zero row at the top of the time axis and drops the last row: shifts the
        # sequence by one step, so position t sees the input from position t-1 (zeros at t=0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)

        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)

        if 'a' in os.environ["RWKV_MY_TESTING"]:
            self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
            d_qkv = args.n_embd // 16
            self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
            self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
            self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
            self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
            with torch.no_grad():
                self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)

    if 'a' not in os.environ["RWKV_MY_TESTING"]:
        @MyFunction
        def jit_func(self, x):
            # Mix x with the previous timestep to produce xk, xv, xr
            xx = self.time_shift(x)
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
            k = self.key(xk)
            v = self.value(xv)
            r = self.receptance(xr)
            sr = torch.sigmoid(r)
            return sr, k, v

        def forward(self, x):
            B, T, C = x.size()  # x = (Batch,Time,Channel)
            sr, k, v = self.jit_func(x)
            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
            return self.output(rwkv)

    if 'a' in os.environ["RWKV_MY_TESTING"]:
        @MyFunction
        def QKV(self, q, k, v):
            # Causal scaled dot-product attention. The mask buffer is (ctx_len, ctx_len); slice it
            # to the actual sequence length so inputs shorter than ctx_len broadcast correctly.
            T = q.size(-2)
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.att_mask[:T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim = -1)
            x = att @ v
            return x

        @MyFunction
        def jit_funcQKV(self, x):
            # Mix x with the previous timestep to produce xk, xv, xr (and the extra qq/kk/vv inputs)
            xx = self.time_shift(x)
            xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
            xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
            xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
            xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
            xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
            xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
            k = self.key(xk)
            v = self.value(xv)
            r = self.receptance(xr)
            sr = torch.sigmoid(r)
            qq = self.qq(xqq)
            kk = self.kk(xkk)
            vv = self.vv(xvv)
            return sr, k, v, qq, kk, vv

        def forward(self, x):
            B, T, C = x.size()  # x = (Batch,Time,Channel)
            sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
            rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
            return rwkv

########################################################################################################


class RWKV_ChannelMix(MyModule):
    """RWKV channel-mixing (feed-forward) sub-layer.

    Computes ``sigmoid(receptance(xr)) * value(relu(key(xk))**2)``, where xk and xr are the
    input interpolated with the previous timestep.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Shifts the sequence one step along the time axis (zeros at t=0).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.square(torch.relu(k))  # squared ReLU
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv


class MishGLU(MyModule):
    """GLU-style feed-forward: ``value(aa(xa) * mish(bb(xb)))`` with time-shift mixing.

    Used instead of RWKV_ChannelMix when 'g' is in ``RWKV_MY_TESTING``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)

            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        a = self.aa(xa)
        b = self.bb(xb)
        return self.value(a * F.mish(b))

########################################################################################################
# The RWKV Model with our blocks
########################################################################################################


class Block(nn.Module):
    """One RWKV layer: pre-LayerNorm time-mix plus pre-LayerNorm channel-mix, each residual.

    Layer 0 also applies ``ln0`` (and an optional learned positional embedding when
    ``args.my_pos_emb > 0``). When ``args.pre_ffn > 0``, layer 0 uses a channel-mix
    (``ffnPre``) instead of the time-mix. The layer ``args.tiny_att_layer`` can add a tiny
    causal attention that reads values from ``x_emb``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(torch.zeros((1, args.my_pos_emb, args.n_embd)))
                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb, 1, args.n_embd)))

        if self.layer_id == 0 and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        else:
            self.att = RWKV_TimeMix(args, layer_id)

        if 'g' in os.environ["RWKV_MY_TESTING"]:
            self.ffn = MishGLU(args, layer_id)
        else:
            self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
                x = x + pos_emb

        if self.layer_id == 0 and args.pre_ffn > 0:
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            xx = self.tiny_ln(x)
            q = self.tiny_q(xx)[:, :T, :]
            k = self.tiny_k(xx)[:, :T, :]
            # Causally masked (future positions zeroed) scaled scores; no softmax is applied.
            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + c @ self.tiny_v(x_emb)
        return x


class L2Wrap(torch.autograd.Function):
    """Identity on the loss in forward; in backward also pushes each position's max logit toward 0.

    Backward returns the incoming gradient for ``loss`` unchanged. For ``y`` it returns a
    gradient that is non-zero only at the argmax of the last dimension, equal to
    ``max * 1e-4 / (y.shape[0] * y.shape[1])``.
    """

    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)


class RWKV(pl.LightningModule):
    """RWKV language model as a LightningModule, with hooks for RM encoding and PPO generation.

    Missing optional fields on ``args`` are defaulted in place:

    - ``dim_att`` defaults to ``n_embd``.
    - ``dim_ffn`` defaults to ``4 * n_embd``.
    - ``tiny_att_layer`` and ``tiny_att_dim`` default to -1.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, 'dim_att'):
            args.dim_att = args.n_embd
        if not hasattr(args, 'dim_ffn'):
            args.dim_ffn = args.n_embd * 4
        if not hasattr(args, 'tiny_att_layer'):
            args.tiny_att_layer = -1
        if not hasattr(args, 'tiny_att_dim'):
            args.tiny_att_dim = -1

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def configure_optimizers(self):
        """Build the Adam optimizer (DeepSpeedCPUAdam when offloading, FusedAdam otherwise).

        With ``args.layerwise_lr > 0``, parameters are split into three groups by name:

        - ``time_mix``: 1x scale, or 2x in pile stage 2.
        - ``time_decay``: 2x scale, or 3x-group in pile stage 2.
        - ``time_first``: 3x group.

        Each group carries a ``my_lr_scale`` key. Where that scale is applied to the learning
        rate is not visible here (presumably a trainer callback). Weight decay is 0 everywhere.
        """
        args = self.args
        if args.layerwise_lr > 0:
            lr_1x = set()
            lr_2x = set()
            lr_3x = set()
            for n, p in self.named_parameters():
                if "time_mix" in n:
                    if args.my_pile_stage == 2:
                        lr_2x.add(n)
                    else:
                        lr_1x.add(n)
                elif "time_decay" in n:
                    if args.my_pile_stage == 2:
                        lr_3x.add(n)
                    else:
                        lr_2x.add(n)
                elif "time_first" in n:
                    lr_3x.add(n)
                else:
                    lr_1x.add(n)
            # Sorted so the group order is deterministic across runs/ranks.
            lr_1x = sorted(lr_1x)
            lr_2x = sorted(lr_2x)
            lr_3x = sorted(lr_3x)
            param_dict = {n: p for n, p in self.named_parameters()}
            if args.my_pile_stage == 2:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 2e-3 / args.lr_init
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 3e-3 / args.lr_init
                ]
            else:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]
        else:
            optim_groups = [
                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
            ]

        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        """True when the DeepSpeed ZeRO config enables optimizer or parameter offload."""
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            # The config values are dicts (or absent); coerce to match the declared bool type.
            return bool(cfg.get("offload_optimizer") or cfg.get("offload_param"))
        return False

    def forward(self, idx, extra_embed=None, rm_train=False, ppo_train=False):
        """Run the model on token ids ``idx`` of shape (B, T), with T <= ``args.ctx_len``.

        Args:
            idx: token ids, shape (B, T).
            extra_embed: optional tensor added to the token embeddings, e.g. to distinguish
                prompt from response tokens when training the reward model. Assumed to
                broadcast against (B, T, n_embd) — confirm at the call site.
            rm_train: if True, return the final-LayerNorm hidden states (reward-model encoding).
            ppo_train: if True, return ``(logits, hidden_states)`` (used by PPO).

        Returns:
            Logits of shape (B, T, vocab_size), unless ``rm_train`` or ``ppo_train`` is set.
        """
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        # Add the extra embedding to x itself so every block sees it (not only the tiny-attention
        # branch, which is the only consumer of x_emb).
        if extra_embed is not None:
            x = x + extra_embed
        x_emb = x

        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        embeds = self.ln_out(x)

        # Encoding used by the reward model (RM)
        if rm_train is True:
            return embeds

        if args.head_qk > 0:
            q = self.head_q(embeds)[:, :T, :]
            k = self.head_k(embeds)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            # F.one_hot returns an int64 tensor; cast it to the model's float type for the matmul.
            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).float()
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            logits = self.head(embeds) + c
        else:
            logits = self.head(embeds)

        # Used by the PPO model
        if ppo_train is True:
            return logits, embeds

        return logits

    @torch.no_grad()
    def generate(
        self,
        seq_len,
        prompt = None,
        temperature = 1.,
        filter_logits_fn = top_k,
        filter_thres = 0.9,
        pad_value = 0.,
        eos_token = None,
        return_seq_without_prompt = True,
        use_tqdm = False
    ):
        '''Generate a response by autoregressive sampling; used for PPO model training.

        Samples ``max(1, seq_len - prompt_len)`` tokens, one per step. Each step filters the
        last-position logits with ``filter_logits_fn`` (if given) and draws with
        ``gumbel_sample`` at ``temperature``.

        If ``eos_token`` is given and every sequence contains it, generation stops early and
        every token after the first EOS is replaced by ``pad_value``. NOTE(review): the EOS
        scan covers the prompt tokens too, so an EOS inside the prompt also triggers masking.

        Returns the generated tokens only, or the full sequence (prompt included) when
        ``return_seq_without_prompt`` is False.
        '''
        prompt, leading_dims = pack([prompt], '* n')

        n, out = prompt.shape[-1], prompt.clone()

        wrapper_fn = identity if not use_tqdm else tqdm
        sample_num_times = max(1, seq_len - prompt.shape[-1])

        for _ in wrapper_fn(range(sample_num_times)):
            logits, _ = self.forward(out, ppo_train=True)
            logits = logits[:, -1]

            if exists(filter_logits_fn):
                logits = filter_logits_fn(logits, thres = filter_thres)

            sample = gumbel_sample(logits, temperature = temperature, dim = -1)
            out, _ = pack([out, sample], 'b *')

            if exists(eos_token):
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, pad_value)
                    break

        out, = unpack(out, leading_dims, '* n')

        if not return_seq_without_prompt:
            return out

        return out[..., n:]

    def training_step(self, batch, batch_idx):
        """Compute the next-token cross-entropy loss for one batch.

        When ``args.my_qa_mask != 1`` the batch is ``(idx, targets)``. Otherwise it is
        ``(idx, targets, mask)``, and per-token losses are weighted by ``mask`` and divided by
        ``mask.sum()``. The loss is wrapped in L2Wrap so backward also nudges the max logits
        toward 0.
        """
        args = self.args
        if args.my_qa_mask != 1:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                # Nothing is masked out: plain mean cross-entropy.
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
                if sum_mask == 0:
                    # Fully masked batch: dividing by zero would yield a NaN loss and NaN
                    # gradients. The masked sum is exactly 0 and keeps the graph connected.
                    loss = torch.sum(loss * mask)
                else:
                    loss = torch.sum(loss * mask) / sum_mask

        return L2Wrap.apply(loss, logits)

    def training_step_end(self, batch_parts):
        """Gather step outputs from all devices; rank 0 stores them as ``trainer.my_loss_all``."""
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = gathered

    def generate_init_weight(self):
        """Return a freshly initialised state dict (name -> CPU tensor) for this model.

        Parameters are handled as follows:

        - LayerNorm, ``time_*``, mask, positional-embedding and any < 2-D tensors are copied
          as-is.
        - ``emb.weight`` is drawn from U(-lr_init, lr_init).
        - Listed output projections, and ``head_q``, are zeroed.
        - Other matrices are orthogonal with a shape-dependent gain: 0.5 scale for
          ``head.weight``, 0.1 for ``head_k``.

        Tensors are cast to fp16/bf16 per ``RWKV_FLOAT_MODE``.
        """
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        zero_init_keys = (
            ".att.key.", ".att.receptance.", ".att.output.", ".ffn.value.", ".ffn.receptance.",
            ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.',
        )
        m = {}
        # Build the state dict once; calling state_dict() per key rebuilds it every iteration.
        for n, p in self.state_dict().items():
            shape = p.shape

            gain = 1.0
            scale = 1.0
            # len(shape) < 2 keeps 1-D tensors not caught by the name patterns (e.g. tiny_ln)
            # as-is; the matrix init below indexes shape[1] and would fail on them.
            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n or len(shape) < 2:
                m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])
                    if any(kk in n for kk in zero_init_keys):
                        scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                # scale == 0: zeros; scale < 0 (emb only): uniform in [scale, -scale]; else orthogonal.
                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

        gc.collect()
        torch.cuda.empty_cache()
        return m