提交 580a99ad 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

修复解码问题

上级 f25b7ce4
......@@ -49,9 +49,9 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \
使用指令数据集进行监督训练,精调语言模型,指令数据集格式为句子对。这部分数据需要由开发人员来进行编写,有的语料需要涉及到推理过程。
```
python train_sft.py --load_model "rwkv-100.pth" --wandb "" --proj_dir "out_sft" \
python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
--data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
......
......@@ -2,221 +2,213 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print('Loading...')
import numpy as np
import os, copy, types, gc, sys
import torch
from src.utils import TOKENIZER
import numpy as np
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
os.environ["RWKV_JIT_ON"] = '1'
CHAT_LANG = 'Chinese' # English Chinese
from src.model_run import RWKV_RNN
print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV')
WORD_NAME = [
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
########################################################################################################
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
args.RUN_DEVICE = "cuda" # cuda // cpu
# fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
args.FLOAT_MODE = "fp16"
args = types.SimpleNamespace()
args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed
args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115'
args.n_layer = 24
args.n_embd = 2048
args.ctx_len = 1024
CHAT_LANG = 'Chinese' # English // Chinese // more to come
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# args.n_layer = 32
# args.n_embd = 4096
# args.ctx_len = 1024
QA_PROMPT = True # True: Q & A prompt // False: User & Bot prompt
# 中文问答设置QA_PROMPT=True(只能问答,问答效果更好,但不能闲聊) 中文聊天设置QA_PROMPT=False(可以闲聊,但需要大模型才适合闲聊)
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
# args.n_layer = 32
# args.n_embd = 2560
# args.ctx_len = 1024
# Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates)
if CHAT_LANG == 'English':
user = "User"
bot = "Bot"
interface = ":"
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210'
args.MODEL_NAME = 'out_sft/rwkv-20'
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth.
init_prompt = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.
elif CHAT_LANG == 'Chinese':
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-EngChn-test4-20230116'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-80'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/3-run1z/rwkv-170'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-0'
args.MODEL_NAME = 'out_sft/rwkv-20'
{user}{interface} french revolution what year
args.ctx_len = 1024
{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.
CHAT_LEN_SHORT = 40
CHAT_LEN_LONG = 150
FREE_GEN_LEN = 200
{user}{interface} 3+5=?
GEN_TEMP = 1.0
GEN_TOP_P = 0.85
{bot}{interface} The answer is 8.
AVOID_REPEAT = ',。:?!'
{user}{interface} guess i marry who ?
########################################################################################################
{bot}{interface} Only if you tell me more about yourself - what are your interests?
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
import torch
{user}{interface} solve for a: 9-a=2
# please tune these (test True/False for all of them). can significantly improve speed.
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(True)
# torch._C._jit_override_can_fuse_on_gpu(True)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(False)
{bot}{interface} The answer is a = 7, because 9 - 7 = 2.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
from src.model_run import RWKV_RNN
from src.utils import TOKENIZER
{user}{interface} wat is lhc
tokenizer = TOKENIZER("20B_tokenizer.json")
{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
MODEL_NAME = args.MODEL_NAME
'''
HELP_MSG = '''Commands:
say something --> chat with bot. use \\n for new line.
+alt --> alternate chat reply
+reset --> reset chat
+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+more --> continue last free generation (only for +gen / +qa)
+retry --> retry last free generation (only for +gen / +qa)
Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation.
'''
elif CHAT_LANG == 'Chinese':
args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115'
args.n_layer = 24
args.n_embd = 2048
args.ctx_len = 1024
interface = ":"
user = "Q"
bot = "A"
user = "Q"
bot = "A"
interface = ":"
init_prompt = f'''
Expert Questions & Helpful Answers
init_prompt = '''
Q: 企鹅会飞吗?
Ask Research Experts
A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。
'''
Q: 西瓜是什么
A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。
HELP_MSG = '''指令:
'''
HELP_MSG = '''指令:
直接输入内容 --> 和机器人聊天,用\\n代表换行
+alt --> 让机器人换个回答
直接输入内容 --> 和机器人聊天(建议问机器人问题),用\\n代表换行
+ --> 让机器人换个回答
+reset --> 重置对话
+++ --> 继续回答
++ --> 换个回答
+gen 某某内容 --> 续写任何中英文内容,用\\n代表换行
+qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行
+more --> 继续 +gen / +qa 的回答
+retry --> 换个 +gen / +qa 的回答
现在可以输入内容和机器人聊天(注意它不大懂中文,它更懂英文)。请经常使用 +reset 重置机器人记忆。
目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。
现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。
'''
# Load Model
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME
print(f'loading... {MODEL_NAME}')
print(f'Loading model - {MODEL_NAME}')
model = RWKV_RNN(args)
model_tokens = []
model_state = None
AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
dd = tokenizer.encode(i)
assert len(dd) == 1
AVOID_REPEAT_TOKENS += dd
current_state = None
########################################################################################################
def run_rnn(tokens, newline_adj = 0):
global model_tokens, current_state
for i in range(len(tokens)):
model_tokens += [int(tokens[i])]
if i == len(tokens) - 1:
out, current_state = model.forward(model_tokens, current_state)
else:
current_state = model.forward(model_tokens, current_state, preprocess_only=True)
# print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]')
def run_rnn(tokens, newline_adj=0):
global model_tokens, model_state
out[0] = -999999999 # disable <|endoftext|>
out[187] += newline_adj
tokens = [int(x) for x in tokens]
model_tokens += tokens
out, model_state = model.forward(tokens, model_state)
# print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')
# out[0] = -999999999 # disable <|endoftext|>
# out[187] += newline_adj # adjust \n probability
# if newline_adj > 0:
# out[15] += newline_adj / 2 # '.'
# if model_tokens[-1] in AVOID_REPEAT_TOKENS:
# out[model_tokens[-1]] = -999999999
out[0] += newline_adj
return out
all_state = {}
def save_all_stat(srv, name, last_out):
n = f'{name}_{srv}'
all_state[n] = {}
all_state[n]['out'] = last_out
all_state[n]['rnn'] = copy.deepcopy(current_state)
all_state[n]['rnn'] = copy.deepcopy(model_state)
all_state[n]['token'] = copy.deepcopy(model_tokens)
def load_all_stat(srv, name):
global model_tokens, current_state
global model_tokens, model_state
n = f'{name}_{srv}'
current_state = copy.deepcopy(all_state[n]['rnn'])
model_state = copy.deepcopy(all_state[n]['rnn'])
model_tokens = copy.deepcopy(all_state[n]['token'])
return all_state[n]['out']
########################################################################################################
# Run inference
print(f'\nRun prompt...')
out = run_rnn(tokenizer.tokenizer.encode(init_prompt))
out = run_rnn(tokenizer.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect()
torch.cuda.empty_cache()
save_all_stat('', 'chat_init', out)
srv_list = ['dummy_server']
for s in srv_list:
save_all_stat(s, 'chat', out)
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')
def reply_msg(msg):
print(f'{bot}{interface} {msg}\n')
def on_message(message):
global model_tokens, current_state
global model_tokens, model_state
srv = 'dummy_server'
msg = message.replace('\\n','\n').strip()
if len(msg) > 1000:
reply_msg('your message is too long (max 1000 tokens)')
return
msg = message.replace('\\n', '\n').strip()
# if len(msg) > 1000:
# reply_msg('your message is too long (max 1000 tokens)')
# return
x_temp = 1.0
x_top_p = 0.85
if ("-temp=" in msg):
x_temp = GEN_TEMP
x_top_p = GEN_TOP_P
if "-temp=" in msg:
x_temp = float(msg.split("-temp=")[1].split(" ")[0])
msg = msg.replace("-temp="+f'{x_temp:g}', "")
msg = msg.replace("-temp=" + f'{x_temp:g}', "")
# print(f"temp: {x_temp}")
if ("-top_p=" in msg):
if "-top_p=" in msg:
x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
msg = msg.replace("-top_p="+f'{x_top_p:g}', "")
msg = msg.replace("-top_p=" + f'{x_top_p:g}', "")
# print(f"top_p: {x_top_p}")
if x_temp <= 0.2:
x_temp = 0.2
......@@ -224,136 +216,73 @@ def on_message(message):
x_temp = 5
if x_top_p <= 0:
x_top_p = 0
if msg == '+reset':
out = load_all_stat('', 'chat_init')
save_all_stat(srv, 'chat', out)
reply_msg("Chat reset.")
return
elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry':
if msg[:5].lower() == '+gen ':
new = '\n' + msg[5:].strip()
# print(f'### prompt ###\n[{new}]')
current_state = None
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
elif msg[:4].lower() == '+qa ':
out = load_all_stat('', 'chat_init')
real_msg = msg[4:].strip()
new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
# print(f'### qa ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
# new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:"
# print(f'### prompt ###\n[{new}]')
# current_state = None
# out = run_rnn(tokenizer.tokenizer.encode(new))
# save_all_stat(srv, 'gen_0', out)
elif msg.lower() == '+more':
try:
out = load_all_stat(srv, 'gen_1')
save_all_stat(srv, 'gen_0', out)
except:
return
elif msg.lower() == '+retry':
try:
out = load_all_stat(srv, 'gen_0')
except:
return
begin = len(model_tokens)
out_last = begin
for i in range(150):
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
if msg[:4].lower() == '+qa ':
out = run_rnn([token], newline_adj=-1)
else:
out = run_rnn([token])
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
print('\n')
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'gen_1', out)
else:
if msg.lower() == '+alt':
try:
out = load_all_stat(srv, 'chat_pre')
except:
return
else:
elif msg.lower() == '+++':
try:
out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
# print(f'### add ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)
except:
return
elif msg.lower() == '+':
try:
out = load_all_stat(srv, 'chat_pre')
except:
return
else:
# out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(FREE_GEN_LEN+100):
if i <= 0:
newline_adj = -999999999
elif i <= CHAT_LEN_SHORT:
newline_adj = (i - CHAT_LEN_SHORT) / 10
elif i <= CHAT_LEN_LONG:
newline_adj = 0
else:
newline_adj = (i - CHAT_LEN_LONG) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p=x_top_p,
)
if token == 0:
break
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx: # avoid utf-8 display issues
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.decode(model_tokens[begin:])
# if '\n\n' in send_msg:
# send_msg = send_msg.strip()
# break
# send_msg = tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
print("\n")
save_all_stat(srv, 'chat', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(999):
if i <= 0:
newline_adj = -999999999
elif i <= 30:
newline_adj = (i - 30) / 10
elif i <= 130:
newline_adj = 0
else:
newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
# print(f'{model_tokens}')
# print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'chat', out)
print(HELP_MSG)
......@@ -362,4 +291,4 @@ while True:
if len(msg.strip()) > 0:
on_message(msg)
else:
print('Erorr: please say something')
print('Error: please say something')
......@@ -180,9 +180,9 @@ class MyDataset(Dataset):
if args.data_type == "binidx":
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
elif args.data_type == "numpy":
dix = data[i : i + req_len]
dix = data[i: i + req_len]
else:
dix = [self.stoi[s] for s in data[i : i + req_len]]
dix = [self.stoi[s] for s in data[i: i + req_len]]
if args.my_qa_mask == 1:
if data == self.data_pile:
......@@ -235,7 +235,8 @@ class S2SDataset(Dataset):
for index, row in pf.iterrows():
question = row["question"]
answer = row["answer"]
data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode(answer)))
data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode("\n"),
self.tokenizer.tokenizer.encode(answer)))
self.data = data_list
def __len__(self):
......@@ -244,16 +245,15 @@ class S2SDataset(Dataset):
def __getitem__(self, index):
ctx_len = self.args.ctx_len
req_len = ctx_len + 1
question, answer = self.data[index]
text = question + answer
question, sep, answer = self.data[index]
text = question + sep + answer
text = text[:req_len]
text = text + [0] * (req_len - len(text))
x = torch.tensor(text[:-1], dtype=torch.long)
y = torch.tensor(text[1:], dtype=torch.long)
z = [0] * (len(question) - 1) + [1] * (ctx_len - (len(question) - 1))
z = [0] * len(question) + [1] * (ctx_len - len(question))
z = torch.tensor(z, dtype=torch.long)
return x, y, z
......@@ -2,33 +2,31 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types
import types, math, os, gc
import torch
import math, os, gc
from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict
MyModule = nn.Module
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
MyModule = torch.nn.Module
def __nop(ob):
return ob
MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster for fp32, slower for fp16 (why?)
if os.environ["RWKV_JIT_ON"] == "1":
MyFunction = __nop
if int(os.environ["RWKV_JIT_ON"]) > 0:
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
print(f'\nRWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
DEBUG_TIME = False # True False - show trained time-coeffs
RWKV_RESCALE_LAYER = 6 # set x = x/2 every X layer (to avoid FP16 overflow)
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
############################################################################################################
......@@ -37,58 +35,67 @@ class RWKV_RNN(MyModule):
super().__init__()
self.args = args
self.FLOAT_MODE = args.FLOAT_MODE
if args.FLOAT_MODE == 'fp32':
self.FLOAT_MODE = torch.float
elif args.FLOAT_MODE == 'fp16':
self.FLOAT_MODE = torch.half
elif args.FLOAT_MODE == 'bf16':
self.FLOAT_MODE = torch.bfloat16
self.RUN_DEVICE = args.RUN_DEVICE
with torch.no_grad():
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# refine weights and send to correct device
keys = list(w.keys())
if 'pos_emb_x' in keys:
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
keys = list(w.keys())
args.n_embd = w['emb.weight'].shape[1]
args.n_layer = 0
keys = list(w.keys()) # refine weights and send to correct device
print_need_newline = False
for x in keys:
block_id = 0
if 'blocks.' in x:
block_id = int(x.split('.')[1])
if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if 'ffn.value.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
w[x].requires_grad = False
if x == 'emb.weight' or 'ln0' in x:
continue
block_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
args.n_layer = max(args.n_layer, block_id + 1)
if '.time_' in x:
w[x] = w[x].squeeze()
if DEBUG_TIME:
print(x, w[x].numpy())
if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x:
w[x] = w[x].t()
if '.time_decay' in x:
w[x] = w[x].float()
w[x] = -torch.exp(w[x])
elif '.time_first' in x:
w[x] = w[x].float()
else:
if self.FLOAT_MODE == "fp32":
w[x] = w[x].float()
elif self.FLOAT_MODE == "bf16":
w[x] = w[x].bfloat16()
elif self.FLOAT_MODE == "fp16":
w[x] = w[x].half()
w[x] = w[x].to(dtype=self.FLOAT_MODE)
w[x].requires_grad = False
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
if args.FLOAT_MODE == 'fp16':
if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if 'ffn.value.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if args.RUN_DEVICE == 'cuda':
w[x] = w[x].cuda()
if ('blocks.' not in x) or ('blocks.0.' in x):
shape = w[x].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
else:
shape = f" {str(shape[0]).rjust(5)} "
if block_id == 0:
if print_need_newline:
print('\n', end = '')
print('\n', end='')
print_need_newline = False
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device)
print(x.ljust(32), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device, shape)
else:
print_need_newline = True
print('.', end = '', flush = True)
print('.', end='', flush=True)
print(f'\nn_layer {args.n_layer} n_embd {args.n_embd} ctx_len {args.ctx_len}')
# store weights in self.w
keys = list(w.keys())
keys = list(w.keys()) # store weights in self.w
self.w = types.SimpleNamespace()
for x in keys:
xx = x.split('.')
......@@ -103,12 +110,20 @@ class RWKV_RNN(MyModule):
if i == len(xx) - 1:
setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]):
if xx[i+1].isdigit():
if xx[i + 1].isdigit():
setattr(here, xx[i], {})
else:
setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i])
with torch.no_grad(): # precompute embedding
try:
x = self.LN(self.w.emb.weight, self.w.blocks[0].ln0)
except:
x = F.layer_norm(self.w.emb.weight.float(), (self.args.n_embd,),
weight=self.w.blocks[0].ln0.weight.float(), bias=self.w.blocks[0].ln0.bias.float())
self.w.emb.weight = x.to(dtype=self.FLOAT_MODE)
self.eval()
gc.collect()
torch.cuda.empty_cache()
......@@ -119,119 +134,136 @@ class RWKV_RNN(MyModule):
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction
def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+0] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
state[5*i+0] = x.float()
else:
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk))
kv = vw @ k
def FF_one(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
xx = state[5 * i + 0].to(dtype=self.FLOAT_MODE)
xk = x * time_mix_k + xx * (1 - time_mix_k)
xr = x * time_mix_r + xx * (1 - time_mix_r)
state[5 * i + 0] = x.float()
r = torch.sigmoid(xr @ rw)
k = torch.square(torch.relu(xk @ kw))
kv = k @ vw
return r * kv
@MyFunction
def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
if self.FLOAT_MODE == "bf16":
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
state[5*i+1] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
state[5*i+1] = x.float()
else:
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
state[5*i+1] = x
r = torch.sigmoid(rw @ xr)
k = kw @ xk
v = vw @ xv
if '16' in self.FLOAT_MODE:
kk = k.float()
vv = v.float()
else:
kk = k
vv = v
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + kk
def FF_seq(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
xx = torch.cat((state[5 * i + 0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1, :]))
xk = x * time_mix_k + xx * (1 - time_mix_k)
xr = x * time_mix_r + xx * (1 - time_mix_r)
state[5 * i + 0] = x[-1, :].float()
r = torch.sigmoid(xr @ rw)
k = torch.square(torch.relu(xk @ kw))
kv = k @ vw
return r * kv
@MyFunction
def SA_one(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
xx = state[5 * i + 1].to(dtype=self.FLOAT_MODE)
xk = x * time_mix_k + xx * (1 - time_mix_k)
xv = x * time_mix_v + xx * (1 - time_mix_v)
xr = x * time_mix_r + xx * (1 - time_mix_r)
state[5 * i + 1] = x.float()
r = torch.sigmoid(xr @ rw)
k = (xk @ kw).float()
v = (xv @ vw).float()
aa = state[5 * i + 2]
bb = state[5 * i + 3]
pp = state[5 * i + 4]
ww = time_first + k
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * vv
a = e1 * aa + e2 * v
b = e1 * bb + e2
ww = pp + time_decay
p = torch.maximum(ww, kk)
p = torch.maximum(ww, k)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
state[5*i+2] = e1 * aa + e2 * vv
state[5*i+3] = e1 * bb + e2
state[5*i+4] = p
if self.FLOAT_MODE == "bf16":
wkv = (a / b).type(torch.bfloat16)
elif self.FLOAT_MODE == "fp16":
wkv = (a / b).half()
else:
wkv = a / b
return ow @ (r * wkv)
def forward(self, ctx, state, preprocess_only = False):
e2 = torch.exp(k - p)
state[5 * i + 2] = e1 * aa + e2 * v
state[5 * i + 3] = e1 * bb + e2
state[5 * i + 4] = p
wkv = (a / b).to(dtype=self.FLOAT_MODE)
return (r * wkv) @ ow
@MyFunction
def SA_seq(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
xx = torch.cat((state[5 * i + 1].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1, :]))
xk = x * time_mix_k + xx * (1 - time_mix_k)
xv = x * time_mix_v + xx * (1 - time_mix_v)
xr = x * time_mix_r + xx * (1 - time_mix_r)
state[5 * i + 1] = x[-1, :].float()
r = torch.sigmoid(xr @ rw)
k = (xk @ kw).float()
v = (xv @ vw).float()
aa = state[5 * i + 2]
bb = state[5 * i + 3]
pp = state[5 * i + 4]
T = x.shape[0]
for t in range(T):
ww = time_first + k[t]
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * v[t]
b = e1 * bb + e2
ww = pp + time_decay
p = torch.maximum(ww, k[t])
e1 = torch.exp(ww - p)
e2 = torch.exp(k[t] - p)
if t != T - 1:
aa = e1 * aa + e2 * v[t]
bb = e1 * bb + e2
pp = p
else:
state[5 * i + 2] = e1 * aa + e2 * v[t]
state[5 * i + 3] = e1 * bb + e2
state[5 * i + 4] = p
xx[t] = (a / b).to(dtype=self.FLOAT_MODE)
return (r * xx) @ ow
def forward(self, tokens, state, preprocess_only=False):
with torch.no_grad():
w = self.w
args = self.args
x = w.emb.weight[ctx[-1]]
seq_mode = len(tokens) > 1
x = w.emb.weight[tokens] if seq_mode else w.emb.weight[tokens[-1]]
if self.RUN_DEVICE == 'cuda':
x = x.cuda()
try:
pos_emb = w.pos_emb[len(ctx)-1]
x = x + pos_emb
except:
pass
if state == None:
state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
for i in range(args.n_layer):
state[5*i+4] -= 1e30
state[5 * i + 4] -= 1e30
SA = self.SA_seq if seq_mode else self.SA_one
FF = self.FF_seq if seq_mode else self.FF_one
for i in range(args.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
ww = w.blocks[i].att
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
x = x + SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
ww = w.blocks[i].ffn
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if (i+1) % RWKV_RESCALE_LAYER == 0:
x = x / 2
x = x + FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight)
if args.FLOAT_MODE == 'fp16':
if (i + 1) % RWKV_RESCALE_LAYER == 0:
x = x / 2
if preprocess_only:
return state
x = self.LN(x, w.ln_out)
x = self.LN(x[-1, :], w.ln_out) if seq_mode else self.LN(x, w.ln_out)
x = w.head.weight @ x
return x.float(), state
......@@ -2,6 +2,7 @@ import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F
from tokenizers import Tokenizer
time_slot = {}
time_ref = time.time_ns()
......@@ -14,8 +15,7 @@ def record_time(name):
time_slot[name] = tt
class TOKENIZER():
class TOKENIZER(object):
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
if 'list' in str(type(WORD_NAME)):
self.charMode = False
......@@ -26,6 +26,8 @@ class TOKENIZER():
from transformers import GPT2TokenizerFast
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
self.vocab_size = len(self.tokenizer)
elif 'str' in str(type(WORD_NAME)):
self.tokenizer = Tokenizer.from_file(WORD_NAME)
else:
self.charMode = True
with open(WORD_NAME + '.json', "r", encoding="utf-16le") as result_file:
......@@ -38,6 +40,12 @@ class TOKENIZER():
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
def encode(self, x):
return self.tokenizer.encode(x).ids
def decode(self, x):
return self.tokenizer.decode(x)
def refine_context(self, context):
context = context.strip().split('\n')
for c in range(len(context)):
......@@ -48,19 +56,8 @@ class TOKENIZER():
context = '\n'
return context
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
# out[self.UNKNOWN_CHAR] = -float('Inf')
lastChar = int(x[-1])
probs = F.softmax(out, dim=-1)
if self.charMode:
if self.itos[lastChar] == '\n':
top_p = top_p_newline
else:
top_p = top_p_usual
else:
top_p = top_p_usual
def sample_logits(self, logits, x, ctx_len, temperature=1.0, top_p=1.0):
probs = F.softmax(logits.float(), dim=-1)
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
probs = probs.numpy()
......@@ -81,7 +78,7 @@ class TOKENIZER():
if temperature != 1.0:
probs = probs.pow(1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return out
return int(out)
def MaybeIsPrime(number):
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册