######################################################################################################## # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## import os, math, gc import torch import torch.nn as nn from torch.nn import functional as F import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from pathlib import Path from tqdm import tqdm from einops import pack from einops import unpack from src.rlhf.utils import exists from src.rlhf.utils import gumbel_sample from src.rlhf.utils import top_k from src.rlhf.utils import eval_decorator # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam try: print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"]) except: os.environ["RWKV_MY_TESTING"] = '' def __nop(ob): return ob MyModule = nn.Module MyFunction = __nop if os.environ["RWKV_JIT_ON"] == "1": MyModule = torch.jit.ScriptModule MyFunction = torch.jit.script_method ######################################################################################################## # CUDA Kernel ######################################################################################################## T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM! # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice from torch.utils.cpp_extension import load wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"]) class WKV(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, w, u, k, v): ctx.B = B ctx.T = T ctx.C = C assert T <= T_MAX assert B * C % min(C, 32) == 0 if "32" in os.environ["RWKV_FLOAT_MODE"]: w = -torch.exp(w.contiguous()) u = u.contiguous() k = k.contiguous() v = v.contiguous() else: w = -torch.exp(w.float().contiguous()) u = u.float().contiguous() k = k.float().contiguous() v = v.float().contiguous() ctx.save_for_backward(w, u, k, v) y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format) wkv_cuda.forward(B, T, C, w, u, k, v, y) if "32" in os.environ["RWKV_FLOAT_MODE"]: return y elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return y.half() elif os.environ["RWKV_FLOAT_MODE"] == "bf16": return y.bfloat16() @staticmethod def backward(ctx, gy): B = ctx.B T = ctx.T C = ctx.C assert T <= T_MAX assert B * C % min(C, 32) == 0 w, u, k, v = ctx.saved_tensors gw = torch.zeros((B, C), device=gy.device).contiguous() gu = torch.zeros((B, C), device=gy.device).contiguous() gk = torch.zeros((B, T, C), device=gy.device).contiguous() gv = torch.zeros((B, T, C), device=gy.device).contiguous() if "32" in os.environ["RWKV_FLOAT_MODE"]: wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv) else: wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) if "32" in os.environ["RWKV_FLOAT_MODE"]: return (None, None, None, gw, gu, gk, gv) elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) elif os.environ["RWKV_FLOAT_MODE"] == "bf16": return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) def RUN_CUDA(B, T, C, w, u, k, v): return WKV.apply(B, T, C, w, u, k, v) ######################################################################################################## # RWKV: RWKV Time-mix + RWKV Channel-mix ######################################################################################################## class RWKV_TimeMix(MyModule): def __init__(self, args, layer_id): super().__init__() self.args = args self.layer_id = layer_id self.ctx_len = args.ctx_len self.n_embd = args.n_embd with torch.no_grad(): # fancy init ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 ddd = torch.ones(1, 1, args.n_embd) for i in range(args.n_embd): ddd[0, 0, i] = i / args.n_embd # fancy time_decay decay_speed = torch.ones(args.dim_att) for h in range(args.dim_att): decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) self.time_decay = nn.Parameter(decay_speed) # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) # fancy time_first zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5 self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag) # fancy time_mix self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) self.key = nn.Linear(args.n_embd, args.dim_att, bias=False) self.value = nn.Linear(args.n_embd, args.dim_att, bias=False) self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False) self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) if 'a' in os.environ["RWKV_MY_TESTING"]: self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) d_qkv = args.n_embd // 16 self.qq = nn.Linear(args.n_embd, d_qkv, bias=False) self.kk = nn.Linear(args.n_embd, d_qkv, bias=False) self.vv = nn.Linear(args.n_embd, d_qkv, bias=False) self.oo = nn.Linear(d_qkv, args.n_embd, bias=False) with torch.no_grad(): self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) if 'a' not in os.environ["RWKV_MY_TESTING"]: @MyFunction def jit_func(self, x): xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) k = self.key(xk) v = self.value(xv) r = self.receptance(xr) sr = torch.sigmoid(r) return sr, k, v def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v = self.jit_func(x) rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) return self.output(rwkv) if 'a' in os.environ["RWKV_MY_TESTING"]: @MyFunction def QKV(self, q, k, v): att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) att = att.masked_fill(self.att_mask == 0, float('-inf')) att = F.softmax(att, dim = -1) x = att @ v return x @MyFunction def jit_funcQKV(self, x): xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq) xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk) xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv) k = self.key(xk) v = self.value(xv) r = self.receptance(xr) sr = torch.sigmoid(r) qq = self.qq(xqq) kk = self.kk(xkk) vv = self.vv(xvv) return sr, k, v, qq, kk, vv def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v, qq, kk, vv = self.jit_funcQKV(x) rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv)) return rwkv ######################################################################################################## class RWKV_ChannelMix(MyModule): def __init__(self, args, layer_id): super().__init__() self.args = args self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) with torch.no_grad(): # fancy init of time_mix ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 ddd = torch.ones(1, 1, args.n_embd) for i in range(args.n_embd): ddd[0, 0, i] = i / args.n_embd self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False) self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) @MyFunction def forward(self, x): xx = self.time_shift(x) xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) k = self.key(xk) k = torch.square(torch.relu(k)) kv = self.value(k) return torch.sigmoid(self.receptance(xr)) * kv class MishGLU(MyModule): def __init__(self, args, layer_id): super().__init__() self.args = args self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) with torch.no_grad(): ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) x = torch.ones(1, 1, args.n_embd) for i in range(args.n_embd): x[0, 0, i] = i / args.n_embd self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) @MyFunction def forward(self, x): xx = self.time_shift(x) xa = x * self.time_mix_k + xx * (1 - self.time_mix_k) xb = x * self.time_mix_r + xx * (1 - self.time_mix_r) a = self.aa(xa) b = self.bb(xb) return self.value(a * F.mish(b)) ######################################################################################################## # The RWKV Model with our blocks ######################################################################################################## class Block(nn.Module): def __init__(self, args, layer_id): super().__init__() self.args = args self.layer_id = layer_id self.ln1 = nn.LayerNorm(args.n_embd) self.ln2 = nn.LayerNorm(args.n_embd) if self.layer_id == 0: self.ln0 = nn.LayerNorm(args.n_embd) if args.my_pos_emb > 0: self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd))) self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd))) if self.layer_id == 0 and self.args.pre_ffn > 0: self.ffnPre = RWKV_ChannelMix(args, 0) else: self.att = RWKV_TimeMix(args, layer_id) if 'g' in os.environ["RWKV_MY_TESTING"]: self.ffn = MishGLU(args, layer_id) else: self.ffn = RWKV_ChannelMix(args, layer_id) if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: self.tiny_ln = nn.LayerNorm(args.n_embd) self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False) self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) def forward(self, x, x_emb=None): args = self.args B, T, C = x.size() if self.layer_id == 0: x = self.ln0(x) if args.my_pos_emb > 0: pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:] x = x + pos_emb if self.layer_id == 0 and args.pre_ffn > 0: x = x + self.ffnPre(self.ln1(x)) else: x = x + self.att(self.ln1(x)) x = x + self.ffn(self.ln2(x)) if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: xx = self.tiny_ln(x) q = self.tiny_q(xx)[:, :T, :] k = self.tiny_k(xx)[:, :T, :] c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5)) c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0) x = x + c @ self.tiny_v(x_emb) return x class L2Wrap(torch.autograd.Function): @staticmethod def forward(ctx, loss, y): ctx.save_for_backward(y) return loss @staticmethod def backward(ctx, grad_output): y = ctx.saved_tensors[0] # to encourage the logits to be close to 0 factor = 1e-4 / (y.shape[0] * y.shape[1]) maxx, ids = torch.max(y, -1, keepdim=True) gy = torch.zeros_like(y) gy.scatter_(-1, ids, maxx * factor) return (grad_output, gy) class RWKV(pl.LightningModule): def __init__(self, args): super().__init__() self.args = args if not hasattr(args, 'dim_att'): args.dim_att = args.n_embd if not hasattr(args, 'dim_ffn'): args.dim_ffn = args.n_embd * 4 if not hasattr(args, 'tiny_att_layer'): args.tiny_att_layer = -1 if not hasattr(args, 'tiny_att_dim'): args.tiny_att_dim = -1 self.emb = nn.Embedding(args.vocab_size, args.n_embd) self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)]) self.ln_out = nn.LayerNorm(args.n_embd) self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False) if args.head_qk > 0: self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False) self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False) self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) def load(self, path): path = Path(path) assert path.exists() self.load_state_dict(torch.load(str(path), map_location="cpu")) def configure_optimizers(self): args = self.args if args.layerwise_lr > 0: lr_1x = set() lr_2x = set() lr_3x = set() for n, p in self.named_parameters(): if "time_mix" in n: if args.my_pile_stage == 2: lr_2x.add(n) else: lr_1x.add(n) elif "time_decay" in n: if args.my_pile_stage == 2: lr_3x.add(n) else: lr_2x.add(n) elif "time_first" in n: lr_3x.add(n) else: lr_1x.add(n) lr_1x = sorted(list(lr_1x)) lr_2x = sorted(list(lr_2x)) lr_3x = sorted(list(lr_3x)) # print('1x', lr_1x) # print('2x', lr_2x) # print('3x', lr_3x) param_dict = {n: p for n, p in self.named_parameters()} if args.my_pile_stage == 2: optim_groups = [ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, ] else: optim_groups = [ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, ] else: optim_groups = [ {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, ] if self.deepspeed_offload: return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) @property def deepspeed_offload(self) -> bool: strategy = self.trainer.strategy if isinstance(strategy, DeepSpeedStrategy): cfg = strategy.config["zero_optimization"] return cfg.get("offload_optimizer") or cfg.get("offload_param") return False def forward(self, idx, extra_embed=None, rm_train=False, ppo_train=False): args = self.args B, T = idx.size() assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted." x = self.emb(idx) x_emb = x # 给 x 加入额外的 embedding,例如在训练 RM 的时候,区分 prompt 和 response if extra_embed is not None: x_emb = x_emb + extra_embed if args.tiny_att_dim > 0: for block in self.blocks: if args.grad_cp == 1: x = deepspeed.checkpointing.checkpoint(block, x, x_emb) else: x = block(x, x_emb) else: for block in self.blocks: if args.grad_cp == 1: x = deepspeed.checkpointing.checkpoint(block, x) else: x = block(x) embeds = self.ln_out(x) # 用于 RM 模型的编码 if rm_train is True: return embeds if args.head_qk > 0: q = self.head_q(embeds)[:, :T, :] k = self.head_k(embeds)[:, :T, :] c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk) c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) if "32" in os.environ["RWKV_FLOAT_MODE"]: c = c @ F.one_hot(idx, num_classes=args.vocab_size) elif os.environ["RWKV_FLOAT_MODE"] == "fp16": c = c @ F.one_hot(idx, num_classes=args.vocab_size).half() elif os.environ["RWKV_FLOAT_MODE"] == "bf16": c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16() logits = self.head(embeds) + c else: logits = self.head(embeds) # 用于 PPO 模型 if ppo_train is True: return logits, embeds return logits @torch.no_grad() @eval_decorator def generate( self, seq_len, prompt = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, pad_value = 0., eos_token = None, return_seq_without_prompt = True ): ''' 生成 response,用于 ppo 模型的训练 ''' prompt, leading_dims = pack([prompt], '* n') n, out = prompt.shape[-1], prompt.clone() sample_num_times = max(1, seq_len - prompt.shape[-1]) for _ in tqdm(range(sample_num_times), desc="gen responses"): logits, embeds = self.forward(out, ppo_train=True) logits, embeds = logits[:, -1], embeds[:, -1] if exists(filter_logits_fn): logits = filter_logits_fn(logits, thres = filter_thres) sample = gumbel_sample(logits, temperature = temperature, dim = -1) out, _ = pack([out, sample], 'b *') if exists(eos_token): is_eos_tokens = (out == eos_token) if is_eos_tokens.any(dim = -1).all(): # mask out everything after the eos tokens shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 out = out.masked_fill(mask, pad_value) break out, = unpack(out, leading_dims, '* n') if not return_seq_without_prompt: return out return out[..., n:] def training_step(self, batch, batch_idx): args = self.args if args.my_qa_mask != 1: idx, targets = batch logits = self(idx) loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) else: idx, targets, mask = batch mask = mask.view(-1) sum_mask = torch.sum(mask).item() # if sum_mask == 0: # return torch.tensor([0.0], requires_grad=True) logits = self(idx) if sum_mask == mask.shape[0]: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) # print('rank', self.global_rank, 'loss', loss.item()) else: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') # loss_raw = loss loss = torch.sum(loss * mask) / sum_mask # torch.set_printoptions(threshold=10000) # if True: #self.global_rank == 1: # tmp = '' # sss = 0 # ccc = 0 # for i in range(mask.shape[0]): # if mask[i] > 0: # tmp += str(idx.view(-1)[i].item()) + ',' # sss += loss_raw.view(-1)[i].float().item() # ccc += 1 # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx) return L2Wrap.apply(loss, logits) def training_step_end(self, batch_parts): all = self.all_gather(batch_parts) if self.trainer.is_global_zero: self.trainer.my_loss_all = all def generate_init_weight(self): print( f""" ############################################################################ # # Init model weight (slow for large models)... # ############################################################################ """ ) m = {} for n in self.state_dict(): p = self.state_dict()[n] shape = p.shape gain = 1.0 scale = 1.0 if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n: m[n] = p else: if n == "emb.weight": scale = -1 * self.args.lr_init else: if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']: if kk in n: scale = 0 if n == "head.weight": scale = 0.5 if "head_k." in n: scale = 0.1 if "head_q." in n: scale = 0 print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}") if self.args.accelerator.upper() == "GPU": m[n] = torch.empty((shape[0], shape[1]), device="cuda") else: m[n] = torch.empty((shape[0], shape[1])) if scale == 0: nn.init.zeros_(m[n]) elif scale < 0: nn.init.uniform_(m[n], a=scale, b=-scale) else: nn.init.orthogonal_(m[n], gain=gain * scale) m[n] = m[n].cpu() if os.environ["RWKV_FLOAT_MODE"] == "fp16": m[n] = m[n].half() elif os.environ["RWKV_FLOAT_MODE"] == "bf16": m[n] = m[n].bfloat16() # if n == "emb.weight": # print(m[n]) gc.collect() torch.cuda.empty_cache() return m