Commit 580a99ad authored by CSDN-Ada助手

Fix the decoding issue

Parent f25b7ce4
@@ -49,9 +49,9 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \
 Use an instruction dataset for supervised training to fine-tune the language model; the instruction dataset consists of sentence pairs. This data has to be written by the developers, and some of the corpus needs to involve a reasoning process (a minimal example of such a file is sketched after this command block).
 ```
-python train_sft.py --load_model "rwkv-100.pth" --wandb "" --proj_dir "out_sft" \
+python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
 --data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \
---ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
+--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
 --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
......
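For reference, here is a minimal sketch (not part of this commit) of how such a sentence-pair file could be produced. It assumes `data/prompts.csv` is read into rows with "question" and "answer" columns, which is what the `S2SDataset` change further down in this diff iterates over; pandas is used purely for illustration.

```python
# Minimal sketch: write data/prompts.csv in the sentence-pair layout that
# S2SDataset reads (one "question" and one "answer" column per row). Any role
# markers or prompt templates the model should learn must already be baked
# into these strings.
import pandas as pd

pairs = [
    {"question": "企鹅会飞吗?", "answer": "企鹅是不会飞的,它们的翅膀主要用于游泳和平衡。"},
    {"question": "What year did the French Revolution start?", "answer": "It started in 1789 and lasted until 1799."},
]
pd.DataFrame(pairs).to_csv("data/prompts.csv", index=False)
```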
@@ -2,221 +2,213 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print('Loading...')
import numpy as np
import os, copy, types, gc, sys import os, copy, types, gc, sys
import torch import numpy as np
from src.utils import TOKENIZER
try: try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except: except:
pass pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200) np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()
os.environ["RWKV_JIT_ON"] = '1' print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV')
CHAT_LANG = 'Chinese' # English Chinese
from src.model_run import RWKV_RNN
WORD_NAME = [ ########################################################################################################
"20B_tokenizer.json",
"20B_tokenizer.json",
] # [vocab, vocab] for Pile model
UNKNOWN_CHAR = None args.RUN_DEVICE = "cuda" # cuda // cpu
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) # fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
args.FLOAT_MODE = "fp16"
args = types.SimpleNamespace() os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed
args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115' CHAT_LANG = 'Chinese' # English // Chinese // more to come
args.n_layer = 24
args.n_embd = 2048
args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' QA_PROMPT = True # True: Q & A prompt // False: User & Bot prompt
# args.n_layer = 32 # 中文问答设置QA_PROMPT=True(只能问答,问答效果更好,但不能闲聊) 中文聊天设置QA_PROMPT=False(可以闲聊,但需要大模型才适合闲聊)
# args.n_embd = 4096
# args.ctx_len = 1024
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023' # Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates)
# args.n_layer = 32
# args.n_embd = 2560
# args.ctx_len = 1024
if CHAT_LANG == 'English': if CHAT_LANG == 'English':
user = "User" # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019'
bot = "Bot" # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
interface = ":" # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210'
args.MODEL_NAME = 'out_sft/rwkv-20'
# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. elif CHAT_LANG == 'Chinese':
# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth. # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-EngChn-test4-20230116'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115'
init_prompt = f''' # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115'
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-80'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/3-run1z/rwkv-170'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-0'
args.MODEL_NAME = 'out_sft/rwkv-20'
{user}{interface} french revolution what year args.ctx_len = 1024
{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799. CHAT_LEN_SHORT = 40
CHAT_LEN_LONG = 150
FREE_GEN_LEN = 200
{user}{interface} 3+5=? GEN_TEMP = 1.0
GEN_TOP_P = 0.85
{bot}{interface} The answer is 8. AVOID_REPEAT = ',。:?!'
{user}{interface} guess i marry who ? ########################################################################################################
{bot}{interface} Only if you tell me more about yourself - what are your interests? os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
import torch
{user}{interface} solve for a: 9-a=2 # please tune these (test True/False for all of them). can significantly improve speed.
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(True)
# torch._C._jit_override_can_fuse_on_gpu(True)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(False)
{bot}{interface} The answer is a = 7, because 9 - 7 = 2. torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
from src.model_run import RWKV_RNN
from src.utils import TOKENIZER
{user}{interface} wat is lhc tokenizer = TOKENIZER("20B_tokenizer.json")
{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
MODEL_NAME = args.MODEL_NAME
''' interface = ":"
HELP_MSG = '''Commands: user = "Q"
say something --> chat with bot. use \\n for new line. bot = "A"
+alt --> alternate chat reply
+reset --> reset chat
+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+more --> continue last free generation (only for +gen / +qa)
+retry --> retry last free generation (only for +gen / +qa)
Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation.
'''
elif CHAT_LANG == 'Chinese':
args.MODEL_NAME = 'RWKV-4-Pile-1B5-EngChn-test4-20230115'
args.n_layer = 24
args.n_embd = 2048
args.ctx_len = 1024
user = "Q" init_prompt = f'''
bot = "A" Expert Questions & Helpful Answers
interface = ":"
init_prompt = ''' Ask Research Experts
Q: 企鹅会飞吗?
A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。 '''
Q: 西瓜是什么
A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。 HELP_MSG = '''指令:
''' 直接输入内容 --> 和机器人聊天(建议问机器人问题),用\\n代表换行
HELP_MSG = '''指令: + --> 让机器人换个回答
直接输入内容 --> 和机器人聊天,用\\n代表换行
+alt --> 让机器人换个回答
+reset --> 重置对话 +reset --> 重置对话
+++ --> 继续回答
++ --> 换个回答
+gen 某某内容 --> 续写任何中英文内容,用\\n代表换行 现在可以输入内容和机器人聊天(注意它不大懂中文,它更懂英文)。请经常使用 +reset 重置机器人记忆。
+qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行 目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。
+more --> 继续 +gen / +qa 的回答
+retry --> 换个 +gen / +qa 的回答
现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。
''' '''
# Load Model # Load Model
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE print(f'Loading model - {MODEL_NAME}')
MODEL_NAME = args.MODEL_NAME
print(f'loading... {MODEL_NAME}')
model = RWKV_RNN(args) model = RWKV_RNN(args)
model_tokens = [] model_tokens = []
model_state = None
AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
dd = tokenizer.encode(i)
assert len(dd) == 1
AVOID_REPEAT_TOKENS += dd
current_state = None
######################################################################################################## ########################################################################################################
def run_rnn(tokens, newline_adj = 0): def run_rnn(tokens, newline_adj=0):
global model_tokens, current_state global model_tokens, model_state
for i in range(len(tokens)):
model_tokens += [int(tokens[i])]
if i == len(tokens) - 1:
out, current_state = model.forward(model_tokens, current_state)
else:
current_state = model.forward(model_tokens, current_state, preprocess_only=True)
# print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]')
out[0] = -999999999 # disable <|endoftext|> tokens = [int(x) for x in tokens]
out[187] += newline_adj model_tokens += tokens
out, model_state = model.forward(tokens, model_state)
# print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')
# out[0] = -999999999 # disable <|endoftext|>
# out[187] += newline_adj # adjust \n probability
# if newline_adj > 0: # if newline_adj > 0:
# out[15] += newline_adj / 2 # '.' # out[15] += newline_adj / 2 # '.'
# if model_tokens[-1] in AVOID_REPEAT_TOKENS:
# out[model_tokens[-1]] = -999999999
out[0] += newline_adj
return out return out
all_state = {} all_state = {}
def save_all_stat(srv, name, last_out): def save_all_stat(srv, name, last_out):
n = f'{name}_{srv}' n = f'{name}_{srv}'
all_state[n] = {} all_state[n] = {}
all_state[n]['out'] = last_out all_state[n]['out'] = last_out
all_state[n]['rnn'] = copy.deepcopy(current_state) all_state[n]['rnn'] = copy.deepcopy(model_state)
all_state[n]['token'] = copy.deepcopy(model_tokens) all_state[n]['token'] = copy.deepcopy(model_tokens)
def load_all_stat(srv, name): def load_all_stat(srv, name):
global model_tokens, current_state global model_tokens, model_state
n = f'{name}_{srv}' n = f'{name}_{srv}'
current_state = copy.deepcopy(all_state[n]['rnn']) model_state = copy.deepcopy(all_state[n]['rnn'])
model_tokens = copy.deepcopy(all_state[n]['token']) model_tokens = copy.deepcopy(all_state[n]['token'])
return all_state[n]['out'] return all_state[n]['out']
######################################################################################################## ########################################################################################################
# Run inference # Run inference
print(f'\nRun prompt...') print(f'\nRun prompt...')
out = run_rnn(tokenizer.tokenizer.encode(init_prompt)) out = run_rnn(tokenizer.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
save_all_stat('', 'chat_init', out)
srv_list = ['dummy_server'] srv_list = ['dummy_server']
for s in srv_list: for s in srv_list:
save_all_stat(s, 'chat', out) save_all_stat(s, 'chat', out)
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n') print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')
def reply_msg(msg): def reply_msg(msg):
print(f'{bot}{interface} {msg}\n') print(f'{bot}{interface} {msg}\n')
def on_message(message): def on_message(message):
global model_tokens, current_state global model_tokens, model_state
srv = 'dummy_server' srv = 'dummy_server'
msg = message.replace('\\n','\n').strip() msg = message.replace('\\n', '\n').strip()
if len(msg) > 1000: # if len(msg) > 1000:
reply_msg('your message is too long (max 1000 tokens)') # reply_msg('your message is too long (max 1000 tokens)')
return # return
x_temp = 1.0 x_temp = GEN_TEMP
x_top_p = 0.85 x_top_p = GEN_TOP_P
if ("-temp=" in msg): if "-temp=" in msg:
x_temp = float(msg.split("-temp=")[1].split(" ")[0]) x_temp = float(msg.split("-temp=")[1].split(" ")[0])
msg = msg.replace("-temp="+f'{x_temp:g}', "") msg = msg.replace("-temp=" + f'{x_temp:g}', "")
# print(f"temp: {x_temp}") # print(f"temp: {x_temp}")
if ("-top_p=" in msg): if "-top_p=" in msg:
x_top_p = float(msg.split("-top_p=")[1].split(" ")[0]) x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
msg = msg.replace("-top_p="+f'{x_top_p:g}', "") msg = msg.replace("-top_p=" + f'{x_top_p:g}', "")
# print(f"top_p: {x_top_p}") # print(f"top_p: {x_top_p}")
if x_temp <= 0.2: if x_temp <= 0.2:
x_temp = 0.2 x_temp = 0.2
@@ -224,136 +216,73 @@ def on_message(message):
x_temp = 5 x_temp = 5
if x_top_p <= 0: if x_top_p <= 0:
x_top_p = 0 x_top_p = 0
if msg == '+reset': if msg == '+reset':
out = load_all_stat('', 'chat_init') out = load_all_stat('', 'chat_init')
save_all_stat(srv, 'chat', out) save_all_stat(srv, 'chat', out)
reply_msg("Chat reset.") reply_msg("Chat reset.")
return return
elif msg.lower() == '+++':
elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry': try:
if msg[:5].lower() == '+gen ':
new = '\n' + msg[5:].strip()
# print(f'### prompt ###\n[{new}]')
current_state = None
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
elif msg[:4].lower() == '+qa ':
out = load_all_stat('', 'chat_init')
real_msg = msg[4:].strip()
new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
# print(f'### qa ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new))
save_all_stat(srv, 'gen_0', out)
# new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:"
# print(f'### prompt ###\n[{new}]')
# current_state = None
# out = run_rnn(tokenizer.tokenizer.encode(new))
# save_all_stat(srv, 'gen_0', out)
elif msg.lower() == '+more':
try:
out = load_all_stat(srv, 'gen_1')
save_all_stat(srv, 'gen_0', out)
except:
return
elif msg.lower() == '+retry':
try:
out = load_all_stat(srv, 'gen_0')
except:
return
begin = len(model_tokens)
out_last = begin
for i in range(150):
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
if msg[:4].lower() == '+qa ':
out = run_rnn([token], newline_adj=-1)
else:
out = run_rnn([token])
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
print('\n')
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'gen_1', out)
else:
if msg.lower() == '+alt':
try:
out = load_all_stat(srv, 'chat_pre')
except:
return
else:
out = load_all_stat(srv, 'chat') out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
# print(f'### add ###\n[{new}]')
out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out) save_all_stat(srv, 'chat_pre', out)
except:
return
elif msg.lower() == '+':
try:
out = load_all_stat(srv, 'chat_pre')
except:
return
else:
# out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(FREE_GEN_LEN+100):
if i <= 0:
newline_adj = -999999999
elif i <= CHAT_LEN_SHORT:
newline_adj = (i - CHAT_LEN_SHORT) / 10
elif i <= CHAT_LEN_LONG:
newline_adj = 0
else:
newline_adj = (i - CHAT_LEN_LONG) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p=x_top_p,
)
if token == 0:
break
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx: # avoid utf-8 display issues
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.decode(model_tokens[begin:])
# if '\n\n' in send_msg:
# send_msg = send_msg.strip()
# break
# send_msg = tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
print("\n")
save_all_stat(srv, 'chat', out)
begin = len(model_tokens)
out_last = begin
print(f'{bot}{interface}', end='', flush=True)
for i in range(999):
if i <= 0:
newline_adj = -999999999
elif i <= 30:
newline_adj = (i - 30) / 10
elif i <= 130:
newline_adj = 0
else:
newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p_usual=x_top_p,
top_p_newline=x_top_p,
)
out = run_rnn([token], newline_adj=newline_adj)
xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
if '\ufffd' not in xxx:
print(xxx, end='', flush=True)
out_last = begin + i + 1
send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break
# send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
# if send_msg.endswith(f'{bot}{interface}'):
# send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
# break
# print(f'{model_tokens}')
# print(f'[{tokenizer.tokenizer.decode(model_tokens)}]')
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'chat', out)
print(HELP_MSG) print(HELP_MSG)
@@ -362,4 +291,4 @@ while True:
if len(msg.strip()) > 0: if len(msg.strip()) > 0:
on_message(msg) on_message(msg)
else: else:
print('Erorr: please say something') print('Error: please say something')
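The hunk above is where the commit's decoding fix shows up in the chat script: generated text now goes through the new `tokenizer.decode` (a `tokenizers.Tokenizer`, see the utils.py hunk at the end of this diff), and a chunk is only printed once it no longer contains the `\ufffd` replacement character, so partially generated multi-byte characters are never shown. Below is a self-contained sketch of that buffering idea, assuming the same `20B_tokenizer.json` file is available; it is an illustration, not code from the repo.

```python
# Decode the not-yet-printed tail of the token stream each step, and only
# flush it once it no longer contains U+FFFD, i.e. once it forms complete
# UTF-8 characters.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("20B_tokenizer.json")

stream = tok.encode("企鹅是不会飞的。").ids   # stand-in for tokens arriving one by one
printed_upto = 0
for i in range(len(stream)):
    piece = tok.decode(stream[printed_upto:i + 1])
    if '\ufffd' not in piece:              # complete characters only
        print(piece, end='', flush=True)
        printed_upto = i + 1
print()
```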
@@ -180,9 +180,9 @@ class MyDataset(Dataset):
             if args.data_type == "binidx":
                 dix = data.get(idx=0, offset=i, length=req_len).astype(int)
             elif args.data_type == "numpy":
-                dix = data[i : i + req_len]
+                dix = data[i: i + req_len]
             else:
-                dix = [self.stoi[s] for s in data[i : i + req_len]]
+                dix = [self.stoi[s] for s in data[i: i + req_len]]
 
             if args.my_qa_mask == 1:
                 if data == self.data_pile:
@@ -235,7 +235,8 @@ class S2SDataset(Dataset):
         for index, row in pf.iterrows():
             question = row["question"]
             answer = row["answer"]
-            data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode(answer)))
+            data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode("\n"),
+                              self.tokenizer.tokenizer.encode(answer)))
         self.data = data_list
 
     def __len__(self):
@@ -244,16 +245,15 @@ class S2SDataset(Dataset):
     def __getitem__(self, index):
         ctx_len = self.args.ctx_len
         req_len = ctx_len + 1
 
-        question, answer = self.data[index]
-        text = question + answer
+        question, sep, answer = self.data[index]
+        text = question + sep + answer
         text = text[:req_len]
         text = text + [0] * (req_len - len(text))
 
         x = torch.tensor(text[:-1], dtype=torch.long)
         y = torch.tensor(text[1:], dtype=torch.long)
-        z = [0] * (len(question) - 1) + [1] * (ctx_len - (len(question) - 1))
+        z = [0] * len(question) + [1] * (ctx_len - len(question))
         z = torch.tensor(z, dtype=torch.long)
 
         return x, y, z
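To make the effect of the new separator token and loss mask in `S2SDataset` concrete, here is a toy walk-through with made-up token ids (only 187, the id of "\n" in the 20B tokenizer as referenced elsewhere in this diff, is real) and a tiny `ctx_len`:

```python
# Toy reproduction of the new __getitem__ logic with ctx_len = 8.
import torch

ctx_len = 8
question, sep, answer = [11, 12, 13], [187], [21, 22]      # hypothetical ids; 187 = "\n"
text = (question + sep + answer)[:ctx_len + 1]
text = text + [0] * (ctx_len + 1 - len(text))               # pad to req_len = ctx_len + 1

x = torch.tensor(text[:-1], dtype=torch.long)   # tensor([ 11,  12,  13, 187,  21,  22,   0,   0])
y = torch.tensor(text[1:], dtype=torch.long)    # tensor([ 12,  13, 187,  21,  22,   0,   0,   0])
z = torch.tensor([0] * len(question) + [1] * (ctx_len - len(question)))
# z = tensor([0, 0, 0, 1, 1, 1, 1, 1]): predictions of the question tokens (and of the
# separator itself) are masked out, so the loss only covers the answer and the padding.
```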
@@ -2,33 +2,31 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import types import types, math, os, gc
import torch import torch
import math, os, gc
from torch.nn import functional as F from torch.nn import functional as F
import torch.nn as nn
from typing import List, Dict
MyModule = nn.Module torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
MyModule = torch.nn.Module
def __nop(ob): def __nop(ob):
return ob return ob
MyFunction = __nop
# # try torchdynamo
# import torchdynamo
# MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
# try torch jit --> faster for fp32, slower for fp16 (why?) MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
if int(os.environ["RWKV_JIT_ON"]) > 0:
MyModule = torch.jit.ScriptModule MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method MyFunction = torch.jit.script_method
RWKV_HEAD_QK_DIM = 0 print(f'\nRWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
DEBUG_TIME = False # True False - show trained time-coeffs RWKV_RESCALE_LAYER = 6 # set x = x/2 every X layer (to avoid FP16 overflow)
RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
############################################################################################################ ############################################################################################################
@@ -37,58 +35,67 @@ class RWKV_RNN(MyModule):
super().__init__() super().__init__()
self.args = args self.args = args
self.FLOAT_MODE = args.FLOAT_MODE if args.FLOAT_MODE == 'fp32':
self.FLOAT_MODE = torch.float
elif args.FLOAT_MODE == 'fp16':
self.FLOAT_MODE = torch.half
elif args.FLOAT_MODE == 'bf16':
self.FLOAT_MODE = torch.bfloat16
self.RUN_DEVICE = args.RUN_DEVICE self.RUN_DEVICE = args.RUN_DEVICE
with torch.no_grad(): with torch.no_grad():
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
# refine weights and send to correct device args.n_embd = w['emb.weight'].shape[1]
keys = list(w.keys()) args.n_layer = 0
if 'pos_emb_x' in keys: keys = list(w.keys()) # refine weights and send to correct device
w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
keys = list(w.keys())
print_need_newline = False print_need_newline = False
for x in keys: for x in keys:
block_id = 0 w[x].requires_grad = False
if 'blocks.' in x: if x == 'emb.weight' or 'ln0' in x:
block_id = int(x.split('.')[1]) continue
if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) block_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
if 'ffn.value.weight' in x: args.n_layer = max(args.n_layer, block_id + 1)
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if '.time_' in x: if '.time_' in x:
w[x] = w[x].squeeze() w[x] = w[x].squeeze()
if DEBUG_TIME: if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x:
print(x, w[x].numpy()) w[x] = w[x].t()
if '.time_decay' in x: if '.time_decay' in x:
w[x] = w[x].float() w[x] = w[x].float()
w[x] = -torch.exp(w[x]) w[x] = -torch.exp(w[x])
elif '.time_first' in x: elif '.time_first' in x:
w[x] = w[x].float() w[x] = w[x].float()
else: else:
if self.FLOAT_MODE == "fp32": w[x] = w[x].to(dtype=self.FLOAT_MODE)
w[x] = w[x].float()
elif self.FLOAT_MODE == "bf16":
w[x] = w[x].bfloat16()
elif self.FLOAT_MODE == "fp16":
w[x] = w[x].half()
w[x].requires_grad = False if args.FLOAT_MODE == 'fp16':
if args.RUN_DEVICE == 'cuda' and x != 'emb.weight': if 'att.output.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if 'ffn.value.weight' in x:
w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
if args.RUN_DEVICE == 'cuda':
w[x] = w[x].cuda() w[x] = w[x].cuda()
if ('blocks.' not in x) or ('blocks.0.' in x): shape = w[x].shape
shape = [i for i in shape if i != 1]
if len(shape) > 1:
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
else:
shape = f" {str(shape[0]).rjust(5)} "
if block_id == 0:
if print_need_newline: if print_need_newline:
print('\n', end = '') print('\n', end='')
print_need_newline = False print_need_newline = False
print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) print(x.ljust(32), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device, shape)
else: else:
print_need_newline = True print_need_newline = True
print('.', end = '', flush = True) print('.', end='', flush=True)
print(f'\nn_layer {args.n_layer} n_embd {args.n_embd} ctx_len {args.ctx_len}')
# store weights in self.w keys = list(w.keys()) # store weights in self.w
keys = list(w.keys())
self.w = types.SimpleNamespace() self.w = types.SimpleNamespace()
for x in keys: for x in keys:
xx = x.split('.') xx = x.split('.')
@@ -103,12 +110,20 @@ class RWKV_RNN(MyModule):
if i == len(xx) - 1: if i == len(xx) - 1:
setattr(here, xx[i], w[x]) setattr(here, xx[i], w[x])
elif not hasattr(here, xx[i]): elif not hasattr(here, xx[i]):
if xx[i+1].isdigit(): if xx[i + 1].isdigit():
setattr(here, xx[i], {}) setattr(here, xx[i], {})
else: else:
setattr(here, xx[i], types.SimpleNamespace()) setattr(here, xx[i], types.SimpleNamespace())
here = getattr(here, xx[i]) here = getattr(here, xx[i])
with torch.no_grad(): # precompute embedding
try:
x = self.LN(self.w.emb.weight, self.w.blocks[0].ln0)
except:
x = F.layer_norm(self.w.emb.weight.float(), (self.args.n_embd,),
weight=self.w.blocks[0].ln0.weight.float(), bias=self.w.blocks[0].ln0.bias.float())
self.w.emb.weight = x.to(dtype=self.FLOAT_MODE)
self.eval() self.eval()
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -119,119 +134,136 @@ class RWKV_RNN(MyModule):
# state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
@MyFunction @MyFunction
def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): def FF_one(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16": xx = state[5 * i + 0].to(dtype=self.FLOAT_MODE)
xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) xk = x * time_mix_k + xx * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) xr = x * time_mix_r + xx * (1 - time_mix_r)
state[5*i+0] = x.float() state[5 * i + 0] = x.float()
elif self.FLOAT_MODE == "fp16":
xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
state[5*i+0] = x.float()
else:
xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
state[5*i+0] = x
r = torch.sigmoid(rw @ xr)
k = torch.square(torch.relu(kw @ xk))
kv = vw @ k
r = torch.sigmoid(xr @ rw)
k = torch.square(torch.relu(xk @ kw))
kv = k @ vw
return r * kv return r * kv
@MyFunction @MyFunction
def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): def FF_seq(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
if self.FLOAT_MODE == "bf16": xx = torch.cat((state[5 * i + 0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1, :]))
xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) xk = x * time_mix_k + xx * (1 - time_mix_k)
xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) xr = x * time_mix_r + xx * (1 - time_mix_r)
xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r) state[5 * i + 0] = x[-1, :].float()
state[5*i+1] = x.float()
elif self.FLOAT_MODE == "fp16": r = torch.sigmoid(xr @ rw)
xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k) k = torch.square(torch.relu(xk @ kw))
xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v) kv = k @ vw
xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r) return r * kv
state[5*i+1] = x.float()
else: @MyFunction
xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) def SA_one(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) xx = state[5 * i + 1].to(dtype=self.FLOAT_MODE)
xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r) xk = x * time_mix_k + xx * (1 - time_mix_k)
state[5*i+1] = x xv = x * time_mix_v + xx * (1 - time_mix_v)
xr = x * time_mix_r + xx * (1 - time_mix_r)
r = torch.sigmoid(rw @ xr) state[5 * i + 1] = x.float()
k = kw @ xk
v = vw @ xv r = torch.sigmoid(xr @ rw)
k = (xk @ kw).float()
if '16' in self.FLOAT_MODE: v = (xv @ vw).float()
kk = k.float()
vv = v.float() aa = state[5 * i + 2]
else: bb = state[5 * i + 3]
kk = k pp = state[5 * i + 4]
vv = v ww = time_first + k
aa = state[5*i+2]
bb = state[5*i+3]
pp = state[5*i+4]
ww = time_first + kk
p = torch.maximum(pp, ww) p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p) e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p) e2 = torch.exp(ww - p)
a = e1 * aa + e2 * vv a = e1 * aa + e2 * v
b = e1 * bb + e2 b = e1 * bb + e2
ww = pp + time_decay ww = pp + time_decay
p = torch.maximum(ww, kk) p = torch.maximum(ww, k)
e1 = torch.exp(ww - p) e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p) e2 = torch.exp(k - p)
state[5*i+2] = e1 * aa + e2 * vv state[5 * i + 2] = e1 * aa + e2 * v
state[5*i+3] = e1 * bb + e2 state[5 * i + 3] = e1 * bb + e2
state[5*i+4] = p state[5 * i + 4] = p
if self.FLOAT_MODE == "bf16": wkv = (a / b).to(dtype=self.FLOAT_MODE)
wkv = (a / b).type(torch.bfloat16) return (r * wkv) @ ow
elif self.FLOAT_MODE == "fp16":
wkv = (a / b).half() @MyFunction
else: def SA_seq(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
wkv = a / b xx = torch.cat((state[5 * i + 1].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1, :]))
xk = x * time_mix_k + xx * (1 - time_mix_k)
return ow @ (r * wkv) xv = x * time_mix_v + xx * (1 - time_mix_v)
xr = x * time_mix_r + xx * (1 - time_mix_r)
def forward(self, ctx, state, preprocess_only = False): state[5 * i + 1] = x[-1, :].float()
r = torch.sigmoid(xr @ rw)
k = (xk @ kw).float()
v = (xv @ vw).float()
aa = state[5 * i + 2]
bb = state[5 * i + 3]
pp = state[5 * i + 4]
T = x.shape[0]
for t in range(T):
ww = time_first + k[t]
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
a = e1 * aa + e2 * v[t]
b = e1 * bb + e2
ww = pp + time_decay
p = torch.maximum(ww, k[t])
e1 = torch.exp(ww - p)
e2 = torch.exp(k[t] - p)
if t != T - 1:
aa = e1 * aa + e2 * v[t]
bb = e1 * bb + e2
pp = p
else:
state[5 * i + 2] = e1 * aa + e2 * v[t]
state[5 * i + 3] = e1 * bb + e2
state[5 * i + 4] = p
xx[t] = (a / b).to(dtype=self.FLOAT_MODE)
return (r * xx) @ ow
def forward(self, tokens, state, preprocess_only=False):
with torch.no_grad(): with torch.no_grad():
w = self.w w = self.w
args = self.args args = self.args
x = w.emb.weight[ctx[-1]] seq_mode = len(tokens) > 1
x = w.emb.weight[tokens] if seq_mode else w.emb.weight[tokens[-1]]
if self.RUN_DEVICE == 'cuda': if self.RUN_DEVICE == 'cuda':
x = x.cuda() x = x.cuda()
try:
pos_emb = w.pos_emb[len(ctx)-1]
x = x + pos_emb
except:
pass
if state == None: if state == None:
state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE) state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
for i in range(args.n_layer): for i in range(args.n_layer):
state[5*i+4] -= 1e30 state[5 * i + 4] -= 1e30
SA = self.SA_seq if seq_mode else self.SA_one
FF = self.FF_seq if seq_mode else self.FF_one
for i in range(args.n_layer): for i in range(args.n_layer):
if i == 0:
x = self.LN(x, w.blocks[i].ln0)
ww = w.blocks[i].att ww = w.blocks[i].att
x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i, x = x + SA(self.LN(x, w.blocks[i].ln1), state, i,
ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay, ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight) ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
ww = w.blocks[i].ffn ww = w.blocks[i].ffn
x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i, x = x + FF(self.LN(x, w.blocks[i].ln2), state, i,
ww.time_mix_k, ww.time_mix_r, ww.time_mix_k, ww.time_mix_r,
ww.key.weight, ww.value.weight, ww.receptance.weight) ww.key.weight, ww.value.weight, ww.receptance.weight)
if (i+1) % RWKV_RESCALE_LAYER == 0: if args.FLOAT_MODE == 'fp16':
x = x / 2 if (i + 1) % RWKV_RESCALE_LAYER == 0:
x = x / 2
if preprocess_only: if preprocess_only:
return state return state
x = self.LN(x, w.ln_out) x = self.LN(x[-1, :], w.ln_out) if seq_mode else self.LN(x, w.ln_out)
x = w.head.weight @ x x = w.head.weight @ x
return x.float(), state return x.float(), state
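As a usage sketch of the rewritten runner: the new `forward` takes a plain list of token ids and switches to the `*_seq` kernels whenever more than one token is passed, so a whole prompt can be ingested in a single call before sampling token by token. The loop below is a condensed version of what the chat script does, not code from the repo; `model`, `tokenizer` and `args` are assumed to be set up exactly as earlier in this diff.

```python
# Minimal generation loop against the new RWKV_RNN interface (sketch).
prompt = "Expert Questions & Helpful Answers\n\nAsk Research Experts\n\nQ: 企鹅会飞吗?\n\nA:"
out, state = model.forward(tokenizer.encode(prompt), None)   # seq mode: whole prompt at once

generated = []
for _ in range(100):
    token = tokenizer.sample_logits(out, generated, args.ctx_len,
                                    temperature=1.0, top_p=0.85)
    if token == 0:                                            # 0 = <|endoftext|>
        break
    generated.append(token)
    out, state = model.forward([token], state)                # one-token mode, state carried over
print(tokenizer.decode(generated))
```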
@@ -2,6 +2,7 @@ import json, time, random, os
 import numpy as np
 import torch
 from torch.nn import functional as F
+from tokenizers import Tokenizer
 
 time_slot = {}
 time_ref = time.time_ns()
@@ -14,8 +15,7 @@ def record_time(name):
     time_slot[name] = tt
 
-class TOKENIZER(object):
+class TOKENIZER():
     def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
         if 'list' in str(type(WORD_NAME)):
             self.charMode = False
@@ -26,6 +26,8 @@ class TOKENIZER():
                 from transformers import GPT2TokenizerFast
                 self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
             self.vocab_size = len(self.tokenizer)
+        elif 'str' in str(type(WORD_NAME)):
+            self.tokenizer = Tokenizer.from_file(WORD_NAME)
         else:
             self.charMode = True
             with open(WORD_NAME + '.json', "r", encoding="utf-16le") as result_file:
@@ -38,6 +40,12 @@ class TOKENIZER():
             self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
 
+    def encode(self, x):
+        return self.tokenizer.encode(x).ids
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
     def refine_context(self, context):
         context = context.strip().split('\n')
         for c in range(len(context)):
@@ -48,19 +56,8 @@ class TOKENIZER():
             context = '\n'
         return context
 
-    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
-        # out[self.UNKNOWN_CHAR] = -float('Inf')
-        lastChar = int(x[-1])
-
-        probs = F.softmax(out, dim=-1)
-
-        if self.charMode:
-            if self.itos[lastChar] == '\n':
-                top_p = top_p_newline
-            else:
-                top_p = top_p_usual
-        else:
-            top_p = top_p_usual
+    def sample_logits(self, logits, x, ctx_len, temperature=1.0, top_p=1.0):
+        probs = F.softmax(logits.float(), dim=-1)
 
         if os.environ["RWKV_RUN_DEVICE"] == "cpu":
             probs = probs.numpy()
@@ -81,7 +78,7 @@ class TOKENIZER():
            if temperature != 1.0:
                probs = probs.pow(1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
-           return out
+           return int(out)
 
 def MaybeIsPrime(number):
     if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
......
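The middle of `sample_logits` is collapsed in this view. For reference, a generic top-p (nucleus) cutoff in PyTorch looks like the sketch below; it illustrates the technique but is not code copied from the repo.

```python
# Keep the smallest set of tokens whose probabilities sum to at least top_p,
# optionally sharpen with temperature, then sample from what is left.
import torch
from torch.nn import functional as F

def nucleus_sample(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 0.85) -> int:
    probs = F.softmax(logits.float(), dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = int((cumulative < top_p).sum()) + 1            # tokens kept in the nucleus
    cutoff = sorted_probs[min(keep, probs.numel()) - 1]
    probs[probs < cutoff] = 0                             # drop everything outside the nucleus
    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)
    return int(torch.multinomial(probs, num_samples=1)[0])
```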