diff --git a/README.md b/README.md index 6ed1d0dbbf1b36f02eccc373dc9711bb17efc2f0..e8b3018c49cdcc59a957d864d656698f8df96d70 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,9 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \ 使用指令数据集进行监督训练,精调语言模型,指令数据集格式为句子对。这部分数据需要由开发人员来进行编写,有的语料需要涉及到推理过程。 ``` -python train_sft.py --load_model "rwkv-100.pth" --wandb "" --proj_dir "out_sft" \ +python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \ --data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \ ---ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \ +--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \ diff --git a/chat.py b/chat.py index 5c198b385f324877a611966dfb136c3dc06e608b..9967248d93e4db1f419697535e7fdb0fa82bba00 100644 --- a/chat.py +++ b/chat.py @@ -2,221 +2,213 @@ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## -print('Loading...') - -import numpy as np import os, copy, types, gc, sys -import torch -from src.utils import TOKENIZER +import numpy as np + try: os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] except: pass -torch.backends.cudnn.benchmark = True -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True np.set_printoptions(precision=4, suppress=True, linewidth=200) +args = types.SimpleNamespace() -os.environ["RWKV_JIT_ON"] = '1' -CHAT_LANG = 'Chinese' # English Chinese - -from src.model_run import RWKV_RNN +print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV') -WORD_NAME = [ - "20B_tokenizer.json", - "20B_tokenizer.json", -] # [vocab, vocab] for Pile model +######################################################################################################## -UNKNOWN_CHAR = None -tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) +args.RUN_DEVICE = "cuda" # cuda // cpu +# fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU) +args.FLOAT_MODE = "fp16" -args = types.SimpleNamespace() -args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda' -args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate) -args.vocab_size = 50277 -args.head_qk = 0 -args.pre_ffn = 0 -args.grad_cp = 0 -args.my_pos_emb = 0 +os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed -args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115' -args.n_layer = 24 -args.n_embd = 2048 -args.ctx_len = 1024 +CHAT_LANG = 'Chinese' # English // Chinese // more to come -# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' -# args.n_layer = 32 -# args.n_embd = 4096 -# args.ctx_len = 1024 +QA_PROMPT = True # True: Q & A prompt // False: User & Bot prompt +# 中文问答设置QA_PROMPT=True(只能问答,问答效果更好,但不能闲聊) 中文聊天设置QA_PROMPT=False(可以闲聊,但需要大模型才适合闲聊) -# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023' -# args.n_layer = 32 -# args.n_embd = 2560 -# args.ctx_len = 1024 +# Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates) if CHAT_LANG == 'English': - user = "User" - bot = "Bot" - interface = ":" + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019' + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096' + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040' + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066' + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023' + # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340' + # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210' + args.MODEL_NAME = 'out_sft/rwkv-20' - # The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. - # The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth. - - init_prompt = f''' -The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. +elif CHAT_LANG == 'Chinese': + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-EngChn-test4-20230116' + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115' + # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115' + # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-80' + # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/3-run1z/rwkv-170' + # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-0' + args.MODEL_NAME = 'out_sft/rwkv-20' -{user}{interface} french revolution what year +args.ctx_len = 1024 -{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799. +CHAT_LEN_SHORT = 40 +CHAT_LEN_LONG = 150 +FREE_GEN_LEN = 200 -{user}{interface} 3+5=? +GEN_TEMP = 1.0 +GEN_TOP_P = 0.85 -{bot}{interface} The answer is 8. +AVOID_REPEAT = ',。:?!' -{user}{interface} guess i marry who ? +######################################################################################################## -{bot}{interface} Only if you tell me more about yourself - what are your interests? +os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE +print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}') +import torch -{user}{interface} solve for a: 9-a=2 +# please tune these (test True/False for all of them). can significantly improve speed. +# torch._C._jit_set_profiling_executor(True) +# torch._C._jit_set_profiling_mode(True) +# torch._C._jit_override_can_fuse_on_cpu(True) +# torch._C._jit_override_can_fuse_on_gpu(True) +# torch._C._jit_set_texpr_fuser_enabled(False) +# torch._C._jit_set_nvfuser_enabled(False) -{bot}{interface} The answer is a = 7, because 9 - 7 = 2. +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +from src.model_run import RWKV_RNN +from src.utils import TOKENIZER -{user}{interface} wat is lhc +tokenizer = TOKENIZER("20B_tokenizer.json") -{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. +args.vocab_size = 50277 +args.head_qk = 0 +args.pre_ffn = 0 +args.grad_cp = 0 +args.my_pos_emb = 0 +MODEL_NAME = args.MODEL_NAME -''' - HELP_MSG = '''Commands: -say something --> chat with bot. use \\n for new line. -+alt --> alternate chat reply -+reset --> reset chat - -+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line. -+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line. -+more --> continue last free generation (only for +gen / +qa) -+retry --> retry last free generation (only for +gen / +qa) - -Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. -This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation. -''' -elif CHAT_LANG == 'Chinese': - args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115' - args.n_layer = 24 - args.n_embd = 2048 - args.ctx_len = 1024 +interface = ":" +user = "Q" +bot = "A" - user = "Q" - bot = "A" - interface = ":" +init_prompt = f''' +Expert Questions & Helpful Answers - init_prompt = ''' -Q: 企鹅会飞吗? +Ask Research Experts -A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。 +''' -Q: 西瓜是什么 -A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。 +HELP_MSG = '''指令: -''' - HELP_MSG = '''指令: -直接输入内容 --> 和机器人聊天,用\\n代表换行 -+alt --> 让机器人换个回答 +直接输入内容 --> 和机器人聊天(建议问机器人问题),用\\n代表换行 ++ --> 让机器人换个回答 +reset --> 重置对话 ++++ --> 继续回答 +++ --> 换个回答 -+gen 某某内容 --> 续写任何中英文内容,用\\n代表换行 -+qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行 -+more --> 继续 +gen / +qa 的回答 -+retry --> 换个 +gen / +qa 的回答 +现在可以输入内容和机器人聊天(注意它不大懂中文,它更懂英文)。请经常使用 +reset 重置机器人记忆。 +目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。 -现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。 ''' # Load Model -os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE -MODEL_NAME = args.MODEL_NAME - -print(f'loading... {MODEL_NAME}') +print(f'Loading model - {MODEL_NAME}') model = RWKV_RNN(args) model_tokens = [] +model_state = None + +AVOID_REPEAT_TOKENS = [] +for i in AVOID_REPEAT: + dd = tokenizer.encode(i) + assert len(dd) == 1 + AVOID_REPEAT_TOKENS += dd -current_state = None ######################################################################################################## -def run_rnn(tokens, newline_adj = 0): - global model_tokens, current_state - for i in range(len(tokens)): - model_tokens += [int(tokens[i])] - if i == len(tokens) - 1: - out, current_state = model.forward(model_tokens, current_state) - else: - current_state = model.forward(model_tokens, current_state, preprocess_only=True) - - # print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]') +def run_rnn(tokens, newline_adj=0): + global model_tokens, model_state - out[0] = -999999999 # disable <|endoftext|> - out[187] += newline_adj + tokens = [int(x) for x in tokens] + model_tokens += tokens + out, model_state = model.forward(tokens, model_state) + + # print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]') + + # out[0] = -999999999 # disable <|endoftext|> + # out[187] += newline_adj # adjust \n probability # if newline_adj > 0: # out[15] += newline_adj / 2 # '.' + # if model_tokens[-1] in AVOID_REPEAT_TOKENS: + # out[model_tokens[-1]] = -999999999 + out[0] += newline_adj return out + all_state = {} + + def save_all_stat(srv, name, last_out): n = f'{name}_{srv}' all_state[n] = {} all_state[n]['out'] = last_out - all_state[n]['rnn'] = copy.deepcopy(current_state) + all_state[n]['rnn'] = copy.deepcopy(model_state) all_state[n]['token'] = copy.deepcopy(model_tokens) + def load_all_stat(srv, name): - global model_tokens, current_state + global model_tokens, model_state n = f'{name}_{srv}' - current_state = copy.deepcopy(all_state[n]['rnn']) + model_state = copy.deepcopy(all_state[n]['rnn']) model_tokens = copy.deepcopy(all_state[n]['token']) return all_state[n]['out'] + ######################################################################################################## # Run inference print(f'\nRun prompt...') -out = run_rnn(tokenizer.tokenizer.encode(init_prompt)) +out = run_rnn(tokenizer.encode(init_prompt)) +save_all_stat('', 'chat_init', out) gc.collect() torch.cuda.empty_cache() -save_all_stat('', 'chat_init', out) - srv_list = ['dummy_server'] for s in srv_list: save_all_stat(s, 'chat', out) -print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n') +print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n') + def reply_msg(msg): print(f'{bot}{interface} {msg}\n') + def on_message(message): - global model_tokens, current_state + global model_tokens, model_state srv = 'dummy_server' - msg = message.replace('\\n','\n').strip() - if len(msg) > 1000: - reply_msg('your message is too long (max 1000 tokens)') - return + msg = message.replace('\\n', '\n').strip() + # if len(msg) > 1000: + # reply_msg('your message is too long (max 1000 tokens)') + # return - x_temp = 1.0 - x_top_p = 0.85 - if ("-temp=" in msg): + x_temp = GEN_TEMP + x_top_p = GEN_TOP_P + if "-temp=" in msg: x_temp = float(msg.split("-temp=")[1].split(" ")[0]) - msg = msg.replace("-temp="+f'{x_temp:g}', "") + msg = msg.replace("-temp=" + f'{x_temp:g}', "") # print(f"temp: {x_temp}") - if ("-top_p=" in msg): + if "-top_p=" in msg: x_top_p = float(msg.split("-top_p=")[1].split(" ")[0]) - msg = msg.replace("-top_p="+f'{x_top_p:g}', "") + msg = msg.replace("-top_p=" + f'{x_top_p:g}', "") # print(f"top_p: {x_top_p}") if x_temp <= 0.2: x_temp = 0.2 @@ -224,136 +216,73 @@ def on_message(message): x_temp = 5 if x_top_p <= 0: x_top_p = 0 - + if msg == '+reset': out = load_all_stat('', 'chat_init') save_all_stat(srv, 'chat', out) reply_msg("Chat reset.") return - - elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry': - - if msg[:5].lower() == '+gen ': - new = '\n' + msg[5:].strip() - # print(f'### prompt ###\n[{new}]') - current_state = None - out = run_rnn(tokenizer.tokenizer.encode(new)) - save_all_stat(srv, 'gen_0', out) - - elif msg[:4].lower() == '+qa ': - out = load_all_stat('', 'chat_init') - - real_msg = msg[4:].strip() - new = f"{user}{interface} {real_msg}\n\n{bot}{interface}" - # print(f'### qa ###\n[{new}]') - - out = run_rnn(tokenizer.tokenizer.encode(new)) - save_all_stat(srv, 'gen_0', out) - - # new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:" - # print(f'### prompt ###\n[{new}]') - # current_state = None - # out = run_rnn(tokenizer.tokenizer.encode(new)) - # save_all_stat(srv, 'gen_0', out) - - elif msg.lower() == '+more': - try: - out = load_all_stat(srv, 'gen_1') - save_all_stat(srv, 'gen_0', out) - except: - return - - elif msg.lower() == '+retry': - try: - out = load_all_stat(srv, 'gen_0') - except: - return - - begin = len(model_tokens) - out_last = begin - for i in range(150): - token = tokenizer.sample_logits( - out, - model_tokens, - args.ctx_len, - temperature=x_temp, - top_p_usual=x_top_p, - top_p_newline=x_top_p, - ) - if msg[:4].lower() == '+qa ': - out = run_rnn([token], newline_adj=-1) - else: - out = run_rnn([token]) - - xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) - if '\ufffd' not in xxx: - print(xxx, end='', flush=True) - out_last = begin + i + 1 - print('\n') - # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() - # print(f'### send ###\n[{send_msg}]') - # reply_msg(send_msg) - save_all_stat(srv, 'gen_1', out) - - else: - if msg.lower() == '+alt': - try: - out = load_all_stat(srv, 'chat_pre') - except: - return - else: + elif msg.lower() == '+++': + try: out = load_all_stat(srv, 'chat') - new = f"{user}{interface} {msg}\n\n{bot}{interface}" - # print(f'### add ###\n[{new}]') - out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999) save_all_stat(srv, 'chat_pre', out) + except: + return + elif msg.lower() == '+': + try: + out = load_all_stat(srv, 'chat_pre') + except: + return + else: + # out = load_all_stat(srv, 'chat') + new = f"{user}{interface} {msg}\n\n{bot}{interface}" + out = run_rnn(tokenizer.encode(new), newline_adj=-999999999) + save_all_stat(srv, 'chat_pre', out) + + begin = len(model_tokens) + out_last = begin + print(f'{bot}{interface}', end='', flush=True) + for i in range(FREE_GEN_LEN+100): + if i <= 0: + newline_adj = -999999999 + elif i <= CHAT_LEN_SHORT: + newline_adj = (i - CHAT_LEN_SHORT) / 10 + elif i <= CHAT_LEN_LONG: + newline_adj = 0 + else: + newline_adj = (i - CHAT_LEN_LONG) * 0.25 # MUST END THE GENERATION + token = tokenizer.sample_logits( + out, + model_tokens, + args.ctx_len, + temperature=x_temp, + top_p=x_top_p, + ) + if token == 0: + break + out = run_rnn([token], newline_adj=newline_adj) + + xxx = tokenizer.decode(model_tokens[out_last:]) + if '\ufffd' not in xxx: # avoid utf-8 display issues + print(xxx, end='', flush=True) + out_last = begin + i + 1 + + send_msg = tokenizer.decode(model_tokens[begin:]) + # if '\n\n' in send_msg: + # send_msg = send_msg.strip() + # break + + # send_msg = tokenizer.decode(model_tokens[begin:]).strip() + # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!! + # send_msg = send_msg[:-len(f'{user}{interface}')].strip() + # break + # if send_msg.endswith(f'{bot}{interface}'): + # send_msg = send_msg[:-len(f'{bot}{interface}')].strip() + # break + + print("\n") + save_all_stat(srv, 'chat', out) - begin = len(model_tokens) - out_last = begin - print(f'{bot}{interface}', end='', flush=True) - for i in range(999): - if i <= 0: - newline_adj = -999999999 - elif i <= 30: - newline_adj = (i - 30) / 10 - elif i <= 130: - newline_adj = 0 - else: - newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION - token = tokenizer.sample_logits( - out, - model_tokens, - args.ctx_len, - temperature=x_temp, - top_p_usual=x_top_p, - top_p_newline=x_top_p, - ) - out = run_rnn([token], newline_adj=newline_adj) - - xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) - if '\ufffd' not in xxx: - print(xxx, end='', flush=True) - out_last = begin + i + 1 - - send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]) - if '\n\n' in send_msg: - send_msg = send_msg.strip() - break - - # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() - # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!! - # send_msg = send_msg[:-len(f'{user}{interface}')].strip() - # break - # if send_msg.endswith(f'{bot}{interface}'): - # send_msg = send_msg[:-len(f'{bot}{interface}')].strip() - # break - - # print(f'{model_tokens}') - # print(f'[{tokenizer.tokenizer.decode(model_tokens)}]') - - # print(f'### send ###\n[{send_msg}]') - # reply_msg(send_msg) - save_all_stat(srv, 'chat', out) print(HELP_MSG) @@ -362,4 +291,4 @@ while True: if len(msg.strip()) > 0: on_message(msg) else: - print('Erorr: please say something') + print('Error: please say something') diff --git a/src/dataset.py b/src/dataset.py index 29188ec151d5decd957593bae3efd83bf9894d71..8e0160b3a4517164246b31a749fdebc3303170d7 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -180,9 +180,9 @@ class MyDataset(Dataset): if args.data_type == "binidx": dix = data.get(idx=0, offset=i, length=req_len).astype(int) elif args.data_type == "numpy": - dix = data[i : i + req_len] + dix = data[i: i + req_len] else: - dix = [self.stoi[s] for s in data[i : i + req_len]] + dix = [self.stoi[s] for s in data[i: i + req_len]] if args.my_qa_mask == 1: if data == self.data_pile: @@ -235,7 +235,8 @@ class S2SDataset(Dataset): for index, row in pf.iterrows(): question = row["question"] answer = row["answer"] - data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode(answer))) + data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode("\n"), + self.tokenizer.tokenizer.encode(answer))) self.data = data_list def __len__(self): @@ -244,16 +245,15 @@ class S2SDataset(Dataset): def __getitem__(self, index): ctx_len = self.args.ctx_len req_len = ctx_len + 1 - question, answer = self.data[index] - text = question + answer + question, sep, answer = self.data[index] + text = question + sep + answer text = text[:req_len] text = text + [0] * (req_len - len(text)) x = torch.tensor(text[:-1], dtype=torch.long) - y = torch.tensor(text[1:], dtype=torch.long) - z = [0] * (len(question) - 1) + [1] * (ctx_len - (len(question) - 1)) + z = [0] * len(question) + [1] * (ctx_len - len(question)) z = torch.tensor(z, dtype=torch.long) return x, y, z diff --git a/src/model_run.py b/src/model_run.py index 2516e508ce477558b320ee0f0c67ed7d0438c674..2ad1d895cc0551785b1ee9bb333294d1a87c6a7b 100644 --- a/src/model_run.py +++ b/src/model_run.py @@ -2,33 +2,31 @@ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## -import types +import types, math, os, gc import torch -import math, os, gc from torch.nn import functional as F -import torch.nn as nn -from typing import List, Dict -MyModule = nn.Module +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + +MyModule = torch.nn.Module + + def __nop(ob): return ob -MyFunction = __nop -# # try torchdynamo -# import torchdynamo -# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output -# try torch jit --> faster for fp32, slower for fp16 (why?) -if os.environ["RWKV_JIT_ON"] == "1": +MyFunction = __nop + +if int(os.environ["RWKV_JIT_ON"]) > 0: MyModule = torch.jit.ScriptModule MyFunction = torch.jit.script_method -RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n') +print(f'\nRWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n') -DEBUG_TIME = False # True False - show trained time-coeffs +RWKV_RESCALE_LAYER = 6 # set x = x/2 every X layer (to avoid FP16 overflow) -RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer ############################################################################################################ @@ -37,58 +35,67 @@ class RWKV_RNN(MyModule): super().__init__() self.args = args - self.FLOAT_MODE = args.FLOAT_MODE + if args.FLOAT_MODE == 'fp32': + self.FLOAT_MODE = torch.float + elif args.FLOAT_MODE == 'fp16': + self.FLOAT_MODE = torch.half + elif args.FLOAT_MODE == 'bf16': + self.FLOAT_MODE = torch.bfloat16 self.RUN_DEVICE = args.RUN_DEVICE with torch.no_grad(): w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') - # refine weights and send to correct device - keys = list(w.keys()) - if 'pos_emb_x' in keys: - w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:] - keys = list(w.keys()) + args.n_embd = w['emb.weight'].shape[1] + args.n_layer = 0 + keys = list(w.keys()) # refine weights and send to correct device print_need_newline = False for x in keys: - block_id = 0 - if 'blocks.' in x: - block_id = int(x.split('.')[1]) - if 'att.output.weight' in x: - w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) - if 'ffn.value.weight' in x: - w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) - + w[x].requires_grad = False + if x == 'emb.weight' or 'ln0' in x: + continue + + block_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 + args.n_layer = max(args.n_layer, block_id + 1) + if '.time_' in x: w[x] = w[x].squeeze() - if DEBUG_TIME: - print(x, w[x].numpy()) + if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x: + w[x] = w[x].t() + if '.time_decay' in x: w[x] = w[x].float() w[x] = -torch.exp(w[x]) elif '.time_first' in x: w[x] = w[x].float() else: - if self.FLOAT_MODE == "fp32": - w[x] = w[x].float() - elif self.FLOAT_MODE == "bf16": - w[x] = w[x].bfloat16() - elif self.FLOAT_MODE == "fp16": - w[x] = w[x].half() + w[x] = w[x].to(dtype=self.FLOAT_MODE) - w[x].requires_grad = False - if args.RUN_DEVICE == 'cuda' and x != 'emb.weight': + if args.FLOAT_MODE == 'fp16': + if 'att.output.weight' in x: + w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) + if 'ffn.value.weight' in x: + w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) + + if args.RUN_DEVICE == 'cuda': w[x] = w[x].cuda() - if ('blocks.' not in x) or ('blocks.0.' in x): + shape = w[x].shape + shape = [i for i in shape if i != 1] + if len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" + else: + shape = f" {str(shape[0]).rjust(5)} " + if block_id == 0: if print_need_newline: - print('\n', end = '') + print('\n', end='') print_need_newline = False - print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) + print(x.ljust(32), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device, shape) else: print_need_newline = True - print('.', end = '', flush = True) + print('.', end='', flush=True) + print(f'\nn_layer {args.n_layer} n_embd {args.n_embd} ctx_len {args.ctx_len}') - # store weights in self.w - keys = list(w.keys()) + keys = list(w.keys()) # store weights in self.w self.w = types.SimpleNamespace() for x in keys: xx = x.split('.') @@ -103,12 +110,20 @@ class RWKV_RNN(MyModule): if i == len(xx) - 1: setattr(here, xx[i], w[x]) elif not hasattr(here, xx[i]): - if xx[i+1].isdigit(): + if xx[i + 1].isdigit(): setattr(here, xx[i], {}) else: setattr(here, xx[i], types.SimpleNamespace()) here = getattr(here, xx[i]) + with torch.no_grad(): # precompute embedding + try: + x = self.LN(self.w.emb.weight, self.w.blocks[0].ln0) + except: + x = F.layer_norm(self.w.emb.weight.float(), (self.args.n_embd,), + weight=self.w.blocks[0].ln0.weight.float(), bias=self.w.blocks[0].ln0.bias.float()) + self.w.emb.weight = x.to(dtype=self.FLOAT_MODE) + self.eval() gc.collect() torch.cuda.empty_cache() @@ -119,119 +134,136 @@ class RWKV_RNN(MyModule): # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp @MyFunction - def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): - if self.FLOAT_MODE == "bf16": - xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) - state[5*i+0] = x.float() - elif self.FLOAT_MODE == "fp16": - xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r) - state[5*i+0] = x.float() - else: - xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r) - state[5*i+0] = x - - r = torch.sigmoid(rw @ xr) - k = torch.square(torch.relu(kw @ xk)) - kv = vw @ k + def FF_one(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw): + xx = state[5 * i + 0].to(dtype=self.FLOAT_MODE) + xk = x * time_mix_k + xx * (1 - time_mix_k) + xr = x * time_mix_r + xx * (1 - time_mix_r) + state[5 * i + 0] = x.float() + r = torch.sigmoid(xr @ rw) + k = torch.square(torch.relu(xk @ kw)) + kv = k @ vw return r * kv @MyFunction - def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): - if self.FLOAT_MODE == "bf16": - xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r) - state[5*i+1] = x.float() - elif self.FLOAT_MODE == "fp16": - xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r) - state[5*i+1] = x.float() - else: - xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r) - state[5*i+1] = x - - r = torch.sigmoid(rw @ xr) - k = kw @ xk - v = vw @ xv - - if '16' in self.FLOAT_MODE: - kk = k.float() - vv = v.float() - else: - kk = k - vv = v - aa = state[5*i+2] - bb = state[5*i+3] - pp = state[5*i+4] - ww = time_first + kk + def FF_seq(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw): + xx = torch.cat((state[5 * i + 0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1, :])) + xk = x * time_mix_k + xx * (1 - time_mix_k) + xr = x * time_mix_r + xx * (1 - time_mix_r) + state[5 * i + 0] = x[-1, :].float() + + r = torch.sigmoid(xr @ rw) + k = torch.square(torch.relu(xk @ kw)) + kv = k @ vw + return r * kv + + @MyFunction + def SA_one(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): + xx = state[5 * i + 1].to(dtype=self.FLOAT_MODE) + xk = x * time_mix_k + xx * (1 - time_mix_k) + xv = x * time_mix_v + xx * (1 - time_mix_v) + xr = x * time_mix_r + xx * (1 - time_mix_r) + state[5 * i + 1] = x.float() + + r = torch.sigmoid(xr @ rw) + k = (xk @ kw).float() + v = (xv @ vw).float() + + aa = state[5 * i + 2] + bb = state[5 * i + 3] + pp = state[5 * i + 4] + ww = time_first + k p = torch.maximum(pp, ww) e1 = torch.exp(pp - p) e2 = torch.exp(ww - p) - a = e1 * aa + e2 * vv + a = e1 * aa + e2 * v b = e1 * bb + e2 ww = pp + time_decay - p = torch.maximum(ww, kk) + p = torch.maximum(ww, k) e1 = torch.exp(ww - p) - e2 = torch.exp(kk - p) - state[5*i+2] = e1 * aa + e2 * vv - state[5*i+3] = e1 * bb + e2 - state[5*i+4] = p - if self.FLOAT_MODE == "bf16": - wkv = (a / b).type(torch.bfloat16) - elif self.FLOAT_MODE == "fp16": - wkv = (a / b).half() - else: - wkv = a / b - - return ow @ (r * wkv) - - def forward(self, ctx, state, preprocess_only = False): + e2 = torch.exp(k - p) + state[5 * i + 2] = e1 * aa + e2 * v + state[5 * i + 3] = e1 * bb + e2 + state[5 * i + 4] = p + wkv = (a / b).to(dtype=self.FLOAT_MODE) + return (r * wkv) @ ow + + @MyFunction + def SA_seq(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): + xx = torch.cat((state[5 * i + 1].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1, :])) + xk = x * time_mix_k + xx * (1 - time_mix_k) + xv = x * time_mix_v + xx * (1 - time_mix_v) + xr = x * time_mix_r + xx * (1 - time_mix_r) + state[5 * i + 1] = x[-1, :].float() + + r = torch.sigmoid(xr @ rw) + k = (xk @ kw).float() + v = (xv @ vw).float() + + aa = state[5 * i + 2] + bb = state[5 * i + 3] + pp = state[5 * i + 4] + T = x.shape[0] + for t in range(T): + ww = time_first + k[t] + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + a = e1 * aa + e2 * v[t] + b = e1 * bb + e2 + ww = pp + time_decay + p = torch.maximum(ww, k[t]) + e1 = torch.exp(ww - p) + e2 = torch.exp(k[t] - p) + if t != T - 1: + aa = e1 * aa + e2 * v[t] + bb = e1 * bb + e2 + pp = p + else: + state[5 * i + 2] = e1 * aa + e2 * v[t] + state[5 * i + 3] = e1 * bb + e2 + state[5 * i + 4] = p + xx[t] = (a / b).to(dtype=self.FLOAT_MODE) + return (r * xx) @ ow + + def forward(self, tokens, state, preprocess_only=False): with torch.no_grad(): w = self.w args = self.args - x = w.emb.weight[ctx[-1]] + seq_mode = len(tokens) > 1 + + x = w.emb.weight[tokens] if seq_mode else w.emb.weight[tokens[-1]] if self.RUN_DEVICE == 'cuda': x = x.cuda() - try: - pos_emb = w.pos_emb[len(ctx)-1] - x = x + pos_emb - except: - pass if state == None: state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE) for i in range(args.n_layer): - state[5*i+4] -= 1e30 + state[5 * i + 4] -= 1e30 + + SA = self.SA_seq if seq_mode else self.SA_one + FF = self.FF_seq if seq_mode else self.FF_one for i in range(args.n_layer): - if i == 0: - x = self.LN(x, w.blocks[i].ln0) - ww = w.blocks[i].att - x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i, - ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay, - ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight) - + x = x + SA(self.LN(x, w.blocks[i].ln1), state, i, + ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay, + ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight) + ww = w.blocks[i].ffn - x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i, - ww.time_mix_k, ww.time_mix_r, - ww.key.weight, ww.value.weight, ww.receptance.weight) - - if (i+1) % RWKV_RESCALE_LAYER == 0: - x = x / 2 + x = x + FF(self.LN(x, w.blocks[i].ln2), state, i, + ww.time_mix_k, ww.time_mix_r, + ww.key.weight, ww.value.weight, ww.receptance.weight) + + if args.FLOAT_MODE == 'fp16': + if (i + 1) % RWKV_RESCALE_LAYER == 0: + x = x / 2 if preprocess_only: return state - x = self.LN(x, w.ln_out) + x = self.LN(x[-1, :], w.ln_out) if seq_mode else self.LN(x, w.ln_out) x = w.head.weight @ x return x.float(), state diff --git a/src/utils.py b/src/utils.py index 5d8806279559e7b818c96ed5981621eaee055189..19e7572f3406bf00af6d297afa5d02bc25f99955 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,6 +2,7 @@ import json, time, random, os import numpy as np import torch from torch.nn import functional as F +from tokenizers import Tokenizer time_slot = {} time_ref = time.time_ns() @@ -14,8 +15,7 @@ def record_time(name): time_slot[name] = tt - -class TOKENIZER(): +class TOKENIZER(object): def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): if 'list' in str(type(WORD_NAME)): self.charMode = False @@ -26,6 +26,8 @@ class TOKENIZER(): from transformers import GPT2TokenizerFast self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) self.vocab_size = len(self.tokenizer) + elif 'str' in str(type(WORD_NAME)): + self.tokenizer = Tokenizer.from_file(WORD_NAME) else: self.charMode = True with open(WORD_NAME + '.json', "r", encoding="utf-16le") as result_file: @@ -38,6 +40,12 @@ class TOKENIZER(): self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] + def encode(self, x): + return self.tokenizer.encode(x).ids + + def decode(self, x): + return self.tokenizer.decode(x) + def refine_context(self, context): context = context.strip().split('\n') for c in range(len(context)): @@ -48,19 +56,8 @@ class TOKENIZER(): context = '\n' return context - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): - # out[self.UNKNOWN_CHAR] = -float('Inf') - lastChar = int(x[-1]) - - probs = F.softmax(out, dim=-1) - - if self.charMode: - if self.itos[lastChar] == '\n': - top_p = top_p_newline - else: - top_p = top_p_usual - else: - top_p = top_p_usual + def sample_logits(self, logits, x, ctx_len, temperature=1.0, top_p=1.0): + probs = F.softmax(logits.float(), dim=-1) if os.environ["RWKV_RUN_DEVICE"] == "cpu": probs = probs.numpy() @@ -81,7 +78,7 @@ class TOKENIZER(): if temperature != 1.0: probs = probs.pow(1.0 / temperature) out = torch.multinomial(probs, num_samples=1)[0] - return out + return int(out) def MaybeIsPrime(number): if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):