From 515600340c09f161d3c5d07548addf68b19b9489 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Thu, 9 Mar 2023 18:23:38 +0800
Subject: [PATCH] opt reward model

---
 forward_demo.py                     |   6 +-
 src/model.py                        |   6 +-
 src/rlhf/reward.py                  |  70 +++-----
 src/rlhf/{rwkv.py => rwkv/model.py} |  17 +-
 src/rlhf/rwkv/utils.py              | 104 ++++++++++++
 train_rm.py                         | 254 ++++++++++++++++++++++++++--
 6 files changed, 395 insertions(+), 62 deletions(-)
 rename src/rlhf/{rwkv.py => rwkv/model.py} (96%)
 create mode 100644 src/rlhf/rwkv/utils.py

diff --git a/forward_demo.py b/forward_demo.py
index 07a1240..58920ec 100644
--- a/forward_demo.py
+++ b/forward_demo.py
@@ -30,7 +30,7 @@ os.environ['RWKV_JIT_ON'] = '1'
 os.environ["RWKV_CUDA_ON"] = '0' # if '1' then compile CUDA kernel for seq mode (much faster)
 
 # from rwkv.model import RWKV # pip install rwkv
-from src.rlhf.rwkv import RWKV
+from src.rlhf.rwkv.model import RWKV
 
 # model = RWKV(model='./model/rwkv-190.pth', strategy='cpu fp32')
 model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32')
@@ -46,7 +46,7 @@ model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu f
 # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019', strategy='cuda fp16 *0+ -> cpu fp32 *1')
 # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096', strategy='cuda:0 fp16 *25 -> cuda:1 fp16')
 
-out, state = model.forward([187, 510, 1563, 310, 247], None)
+out, state, token_embed = model.forward([187, 510, 1563, 310, 247], None)
 print(out.detach().cpu().numpy())   # get logits
 # out, state = model.forward([187, 510], None)
 # out, state = model.forward([1563], state)   # RNN has state (use deepcopy to clone states)
@@ -58,7 +58,7 @@ ipdb.set_trace()
 
 # print('\n')
 
-# from src.utils import PIPELINE, PIPELINE_ARGS
+# from src.rlhf.rwkv.utils import PIPELINE, PIPELINE_ARGS
 # pipeline = PIPELINE(model, "20B_tokenizer.json")
 
 # ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
diff --git a/src/model.py b/src/model.py
index 621a398..8097cdf 100644
--- a/src/model.py
+++ b/src/model.py
@@ -429,7 +429,7 @@ class RWKV(pl.LightningModule):
             return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False
 
-    def forward(self, idx):
+    def forward(self, idx, extra_embed=None):
         args = self.args
         B, T = idx.size()
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
@@ -437,6 +437,10 @@
         x = self.emb(idx)
         x_emb = x
 
+        # add an extra embedding to x, e.g. to tell prompt tokens from response tokens when training the RM
+        if extra_embed is not None:
+            x_emb = x_emb + extra_embed
+
         if args.tiny_att_dim > 0:
             for block in self.blocks:
                 if args.grad_cp == 1:
diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index 2f20dc2..2b68cd3 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -13,7 +13,7 @@ from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 from src.rlhf.utils import masked_mean, gumbel_sample
-from src.model import RWKV
+from src.rlhf.rwkv.model import RWKV
 
 # helper functions
 
@@ -26,35 +26,25 @@ def exists(val):
 class RewardModel(nn.Module):
     def __init__(
         self,
-        rwkv: RWKV,
-        dropout = 0.1,
-        num_binned_output = 0.
+        rwkv: RWKV
     ):
         super().__init__()
 
         # initialize the reward model from the pretrained model
-        self.rwkv = copy.deepcopy(rwkv)
+        self.rwkv = rwkv
 
         # dimension of the output token embeddings
-        dim = rwkv.dim # todo(luxin)
+        dim = rwkv.args.n_embd
 
-        # number of score bins; if 5, the score levels are [0, 1, 2, 3, 4]
-        self.binned_output = num_binned_output > 1
-
-        # todo(luxin): prompt_embed and response_embed are both initialized to all zeros? Shouldn't they differ?
+        # used to tell the prompt part of the input from the response part; trained as model parameters, initialized to all zeros
         self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
         self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
-        # self.response_embed = nn.Parameter(torch.ones(1, 1, dim))
-
-        if self.binned_output:
-            # if there is more than one score bin, treat it as a multi-class classification problem
-            self.to_pred = nn.Linear(dim, num_binned_output)
-        else:
-            # otherwise, treat it as a simple binary classification problem
-            self.to_pred = nn.Sequential(
-                nn.Linear(dim, 1, bias = False),
-                Rearrange('... 1 -> ...') # drop the last dimension
-            )
+
+        # reward score head
+        self.pred_reward = nn.Sequential(
+            nn.Linear(dim, 1),
+            Rearrange('... 1 -> ...') # drop the last dimension
+        )
 
     def load(self, path):
         path = Path(path)
@@ -72,13 +62,10 @@ class RewardModel(nn.Module):
         x,
         mask = None,
         prompt_mask = None,
-        prompt_lengths = None,
-        labels = None,
-        sample = False,
-        sample_temperature = 1.
+        prompt_lengths = None
     ):
 
-        # only one of prompt_mask and prompt_lengths may be given
+        # prompt_mask and prompt_lengths are mutually exclusive
        assert not (exists(prompt_mask) and exists(prompt_lengths))
 
         # derive prompt mask from prompt lengths
@@ -98,26 +85,19 @@ class RewardModel(nn.Module):
             self.response_embed
         )
 
-        # todo(luxin) get embeddings from rwkv
-        embeds = self.rwkv(
+        # get the embedding of the last token
+        _, _, last_token_embeds = self.rwkv(
             x,
-            extra_embed = extra_embed,
-            return_only_embedding = True
+            state=None,
+            extra_embed=extra_embed
         )
 
         # average the token embeddings and feed the result to the scoring head
-        pooled = masked_mean(embeds, mask, dim = 1)
-        pred = self.to_pred(pooled)
-
-        if sample and self.binned_output:
-            assert not exists(labels)
-            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)
-
-        if not exists(labels):
-            return pred
-
-        # todo(luxin): the author does not use the pairwise (two-sample) loss from the paper, only a per-sample loss
-        if not self.binned_output:
-            return F.mse_loss(pred, labels)
-
-        return F.cross_entropy(pred, labels)
+        pooled = masked_mean(last_token_embeds, mask, dim = 1)
+        reward = self.pred_reward(pooled)
+
+        return reward
diff --git a/src/rlhf/rwkv.py b/src/rlhf/rwkv/model.py
similarity index 96%
rename from src/rlhf/rwkv.py
rename to src/rlhf/rwkv/model.py
index b44f2a0..f55e033 100644
--- a/src/rlhf/rwkv.py
+++ b/src/rlhf/rwkv/model.py
@@ -4,6 +4,7 @@
 
 import types, gc, os, time
 import torch
+import copy
 from torch.nn import functional as F
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
@@ -364,12 +365,13 @@
 
     ########################################################################################################
 
-    def forward(self, tokens, state, full_output=False):
+    def forward(self, tokens, state, full_output=False, extra_embed=None):
         with torch.no_grad():
             w = self.w
             args = self.args
 
             if state == None:
+                # initialize the recurrent state
                 state = [None] * args.n_layer * 5
                 for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
                     dd = self.strategy[i]
@@ -382,10 +384,15 @@
                     state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev)
 
             seq_mode = len(tokens) > 1
 
             # input: look up each token's embedding by its index
             x = w['emb.weight'][tokens if seq_mode else tokens[0]]
 
+            if extra_embed is not None:
+                x = x + extra_embed
+
+            # inference: run the N Blocks (Attention, Feed Forward)
             for i in range(args.n_layer):
                 bbb = f'blocks.{i}.'
                 att = f'blocks.{i}.att.'
@@ -404,8 +411,10 @@
                     ATT = self.att_one
                     FFN = self.ffn_one
 
+                # Tensor dtype and/or device conversion
                 x = x.to(dtype=atype, device=dev)
 
+                # Attention layer
                 kw = self.get_w(f'{att}key.weight', atype)
                 vw = self.get_w(f'{att}value.weight', atype)
                 rw = self.get_w(f'{att}receptance.weight', atype)
@@ -424,6 +433,7 @@
                 if wtype == torch.uint8 or dd.stream:
                     del kw, vw, rw, ow
 
+                # Feed Forward layer
                 kw = self.get_w(f'{ffn}key.weight', atype)
                 vw = self.get_w(f'{ffn}value.weight', atype)
                 rw = self.get_w(f'{ffn}receptance.weight', atype)
@@ -443,15 +453,18 @@
                 if (i+1) % self.RESCALE_LAYER == 0:
                     x = x / 2
 
+            # keep every token's embedding, or only the last token's
             dd = self.strategy[args.n_layer]
             x = x[-1,:] if (seq_mode and (not full_output)) else x
             x = x.to(dtype=dd.atype, device=dd.device)
 
+            # apply LayerNorm to the token embedding; shape is unchanged
             x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])
+            token_embed = copy.deepcopy(x)
 
             if w['head.weight'].dtype != torch.uint8:
                 x = x @ w['head.weight']
             else:
                 x = x @ self.get_w('head.weight', dd.atype)
 
-            return x.float(), state
+            return x.float(), state, token_embed.float()
diff --git a/src/rlhf/rwkv/utils.py b/src/rlhf/rwkv/utils.py
new file mode 100644
index 0000000..5233e9c
--- /dev/null
+++ b/src/rlhf/rwkv/utils.py
@@ -0,0 +1,104 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import json, time, random, os
+import numpy as np
+import torch
+from torch.nn import functional as F
+from tokenizers import Tokenizer
+
+class PIPELINE_ARGS():
+    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, token_ban=[], token_stop=[]):
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
+        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
+        self.token_ban = token_ban # ban the generation of some tokens
+        self.token_stop = token_stop # stop generation whenever you see any token here
+
+class PIPELINE():
+    def __init__(self, model, WORD_NAME):
+        self.model = model
+        self.tokenizer = Tokenizer.from_file(WORD_NAME)
+
+    def refine_context(self, context):
+        context = context.strip().split('\n')
+        for c in range(len(context)):
+            context[c] = context[c].strip().strip('\u3000').strip('\r')
+        context = list(filter(lambda c: c != '', context))
+        context = '\n' + ('\n'.join(context)).strip()
+        if context == '':
+            context = '\n'
+        return context
+
+    def encode(self, x):
+        return self.tokenizer.encode(x).ids
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
+        probs = F.softmax(logits.float(), dim=-1)
+        top_k = int(top_k)
+        if probs.device == torch.device('cpu'):
+            probs = probs.numpy()
+            sorted_ids = np.argsort(probs)
+            sorted_probs = probs[sorted_ids][::-1]
+            cumulative_probs = np.cumsum(sorted_probs)
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
+            probs[probs < cutoff] = 0
+            if top_k < len(probs) and top_k > 0:
+                probs[sorted_ids[:-top_k]] = 0
+            if temperature != 1.0:
+                probs = probs ** (1.0 / temperature)
+            probs = probs / np.sum(probs)
+            out = np.random.choice(a=len(probs), p=probs)
+            return int(out)
+        else:
+            sorted_ids = torch.argsort(probs)
+            sorted_probs = probs[sorted_ids]
+            sorted_probs = torch.flip(sorted_probs, dims=(0,))
+            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
+            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
+            probs[probs < cutoff] = 0
+            if top_k < len(probs) and top_k > 0:
+                probs[sorted_ids[:-top_k]] = 0
+            if temperature != 1.0:
+                probs = probs ** (1.0 / temperature)
+            out = torch.multinomial(probs, num_samples=1)[0]
+            return int(out)
+
+    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
+        all_tokens = []
+        out_last = 0
+        out_str = ''
+        occurrence = {}
+        for i in range(token_count):
+
+            # forward & adjust prob.
+            out, state = self.model.forward(self.encode(ctx) if i == 0 else [token], state)
+            for n in args.token_ban:
+                out[n] = -float('inf')
+            for n in occurrence:
+                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+            # sampler
+            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
+            if token in args.token_stop:
+                break
+            all_tokens += [token]
+            if token not in occurrence:
+                occurrence[token] = 1
+            else:
+                occurrence[token] += 1
+
+            # output
+            tmp = self.decode(all_tokens[out_last:])
+            if '\ufffd' not in tmp: # is valid utf-8 string?
+                if callback:
+                    callback(tmp)
+                out_str += tmp
+                out_last = i + 1
+        return out_str
diff --git a/train_rm.py b/train_rm.py
index b4c4c61..ed6ecea 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -6,27 +6,259 @@
 '''
 
 # here put the import lib
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+
+    rank_zero_info("########## work in progress ##########")
+
+    ########################################################################################################
+    #
+    # example: train a simple L12-D768 RWKV on dummy data
+    #
+    # python train.py --load_model "" --wandb "" --proj_dir "out" \
+    # --data_file "" --data_type "dummy" --vocab_size 0 \
+    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
+    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
+
+    # example: train a simple L6-D512 RWKV from scratch on enwik8
+    #
+    # python train.py --load_model "" --wandb "" --proj_dir "out" \
+    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
+    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
+    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
+
+    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
+    #
+    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
"numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \ + # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1 + + parser = ArgumentParser() + + parser.add_argument("--load_model", default="", type=str) # full path, with .pth + parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb + parser.add_argument("--proj_dir", default="out", type=str) + parser.add_argument("--random_seed", default="-1", type=int) + + parser.add_argument("--data_file", default="", type=str) + parser.add_argument("--data_type", default="utf-8", type=str) + parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + + parser.add_argument("--ctx_len", default=1024, type=int) + parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps + parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". 
+    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
+    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
+    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"
+
+    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
+    parser.add_argument("--n_layer", default=6, type=int)
+    parser.add_argument("--n_embd", default=512, type=int)
+    parser.add_argument("--dim_att", default=0, type=int)
+    parser.add_argument("--dim_ffn", default=0, type=int)
+    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
+    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
+    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
+    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer
+
+    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
+    parser.add_argument("--lr_final", default=1e-5, type=float)
+    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
+    parser.add_argument("--beta1", default=0.9, type=float)
+    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
+    parser.add_argument("--adam_eps", default=1e-8, type=float)
+
+    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
+    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
+    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
+    parser.add_argument("--my_pile_edecay", default=0, type=int)
+    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
+    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
+    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
+
+    parser.add_argument("--my_img_version", default=0, type=str)
+    parser.add_argument("--my_img_size", default=0, type=int)
+    parser.add_argument("--my_img_bit", default=0, type=int)
+    parser.add_argument("--my_img_clip", default='x', type=str)
+    parser.add_argument("--my_img_clip_scale", default=1, type=float)
+    parser.add_argument("--my_img_l1_scale", default=0, type=float)
+    parser.add_argument("--my_img_encoder", default='x', type=str)
+    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
+    parser.add_argument("--my_sample_len", default=0, type=int)
+    parser.add_argument("--my_ffn_shift", default=1, type=int)
+    parser.add_argument("--my_att_shift", default=1, type=int)
+    parser.add_argument("--my_pos_emb", default=0, type=int)
+    parser.add_argument("--load_partial", default=0, type=int)
+    parser.add_argument("--magic_prime", default=0, type=int)
+    parser.add_argument("--my_qa_mask", default=0, type=int)
+    parser.add_argument("--my_testing", default='', type=str)
+
+    parser = Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    ########################################################################################################
+
+    import os, warnings, math, datetime, sys, time
+    import numpy as np
+    import torch
+    from torch.utils.data import DataLoader
+    import deepspeed
+    import pytorch_lightning as pl
+    from pytorch_lightning import seed_everything
+
+    if args.random_seed >= 0:
+        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
+        seed_everything(args.random_seed)
+
+    np.set_printoptions(precision=4, suppress=True, linewidth=200)
+    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
+    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
+    # os.environ["WDS_SHOW_SEED"] = "1"
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
+    args.enable_checkpointing = False
+    args.replace_sampler_ddp = False
+    args.logger = False
+    args.gradient_clip_val = 1.0
+    args.num_sanity_val_steps = 0
+    args.check_val_every_n_epoch = int(1e20)
+    args.log_every_n_steps = int(1e20)
+    args.max_epochs = -1  # continue forever
+    args.betas = (args.beta1, args.beta2)
+    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
+    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
+    os.environ["RWKV_MY_TESTING"] = args.my_testing
+    if args.dim_att <= 0:
+        args.dim_att = args.n_embd
+    if args.dim_ffn <= 0:
+        args.dim_ffn = args.n_embd * 4
+
+    args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
+    if not os.path.exists(args.proj_dir):
+        os.makedirs(args.proj_dir)
+
+    samples_per_epoch = args.epoch_steps * args.real_bsz
+    tokens_per_epoch = samples_per_epoch * args.ctx_len
+    rank_zero_info(
+        f"""
+############################################################################
+#
+# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
+#
+# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
+#
+# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
+#
+# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
+#
+# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
+#
+# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
+#
+# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
+# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
+# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
+#
+############################################################################
+"""
+    )
+    rank_zero_info(str(vars(args)) + "\n")
+
+    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
+
+    if args.lr_final == 0 or args.lr_init == 0:
+        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
+
+    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
+    os.environ["RWKV_FLOAT_MODE"] = args.precision
+    if args.precision == "fp32":
+        rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
+    if args.precision == "fp16":
+        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
+
+    os.environ["RWKV_JIT_ON"] = "1"
+    if "deepspeed_stage_3" in args.strategy:
+        os.environ["RWKV_JIT_ON"] = "0"
+
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.enabled = True
+    if args.precision == "fp32":
+        torch.backends.cudnn.allow_tf32 = False
+        torch.backends.cuda.matmul.allow_tf32 = False
+    else:
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if "32" in args.precision:
+        args.precision = 32
+    elif args.precision == "fp16":
+        args.precision = 16
+    else:
+        args.precision = "bf16"
+
+    ########################################################################################################
+
+
+
+
+
+
+
 import torch
 
 from src.rlhf.reward import RewardModel
-from src.model import RWKV
+from src.rlhf.rwkv.model import RWKV
 
-rwkv_model = RWKV()
+model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
+strategy = "cpu fp32"
+rwkv_model = RWKV(model, strategy)
+dim = rwkv_model.args.n_embd
 
 reward_model = RewardModel(
-    rwkv_model,
-    num_binned_output = 5  # number of score bins; if 5, the score levels are [0, 1, 2, 3, 4]
+    rwkv_model
 )
 
 # mock data
-seq = torch.randint(0, 20000, (1, 100))
-# prompt_mask = torch.zeros(1, 100).bool() # which part of the sequence is prompt, which part is response
+prompt = torch.randint(0, dim, (1, 50))
+prefer_response = torch.randint(0, dim, (1, 50))
+alter_response = torch.randint(0, dim, (1, 50))
+
+prefer_pair = torch.concat((prompt, prefer_response), dim=1)
+alter_pair = torch.concat((prompt, alter_response), dim=1)
+
+# which part of the sequence is prompt, which part is response
 prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
-labels = torch.randint(0, 5, (1,))
+# labels = torch.randint(0, 5, (1,))
 
 # train
-loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
-loss.backward()
+# loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
+# loss.backward()
+
+# inference
+prefer_reward = reward_model(prefer_pair, prompt_mask = prompt_mask)
+alter_reward = reward_model(alter_pair, prompt_mask = prompt_mask)
 
-# after much training
-reward = reward_model(seq, prompt_mask = prompt_mask)
\ No newline at end of file
+print("Preferred response reward:", prefer_reward)
+print("Alternate response reward:", alter_reward) \ No newline at end of file -- GitLab