########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, copy, types, gc, sys
import numpy as np

try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
    pass
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()

print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV')

########################################################################################################

args.RUN_DEVICE = "cuda"  # cuda // cpu
# fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
args.FLOAT_MODE = "fp16"
os.environ["RWKV_JIT_ON"] = '1'  # '1' or '0', please use torch 1.13+ and benchmark speed

CHAT_LANG = 'Chinese'  # English // Chinese // more to come

QA_PROMPT = True  # True: Q & A prompt // False: User & Bot prompt
# For Chinese Q&A, set QA_PROMPT=True (Q&A only; answers are better, but it cannot chit-chat).
# For Chinese chit-chat, set QA_PROMPT=False (chit-chat works, but you need a large model for it to be good).

# Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates)

if CHAT_LANG == 'English':
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210'
    args.MODEL_NAME = 'out_sft/rwkv-20'
elif CHAT_LANG == 'Chinese':
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-EngChn-test4-20230116'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-80'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/3-run1z/rwkv-170'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-0'
    args.MODEL_NAME = 'out_sft/rwkv-20'

args.ctx_len = 1024

CHAT_LEN_SHORT = 40   # replies shorter than this are discouraged from ending (see the newline_adj schedule in on_message)
CHAT_LEN_LONG = 150   # replies longer than this are pushed increasingly hard to end
FREE_GEN_LEN = 200    # the generation loop runs for at most FREE_GEN_LEN + 100 tokens

GEN_TEMP = 1.0        # default sampling temperature (override per message with -temp=)
GEN_TOP_P = 0.85      # default nucleus sampling top-p (override per message with -top_p=)

AVOID_REPEAT = ',。:?!'  # candidate tokens to ban from repeating (the ban itself is currently commented out in run_rnn)

########################################################################################################

os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')

import torch

# please tune these (test True/False for all of them); they can significantly improve speed.
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(True)
# torch._C._jit_override_can_fuse_on_gpu(True)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(False)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

from src.model_run import RWKV_RNN
from src.utils import TOKENIZER

tokenizer = TOKENIZER("20B_tokenizer.json")

args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
MODEL_NAME = args.MODEL_NAME

interface = ":"
user = "Q"
bot = "A"

init_prompt = f'''
Expert Questions & Helpful Answers

Ask Research Experts

'''

HELP_MSG = '''Commands:
type anything --> chat with the bot (asking it questions works best); use \\n for a newline
+             --> ask the bot for a different answer
+reset        --> reset the conversation

+++           --> continue the current answer
++            --> get a different answer

You can now chat with the bot (note: it does not understand Chinese very well; it is better at English). Please use +reset often to reset the bot's memory.
There is no "repetition penalty" yet, so the bot sometimes repeats itself; when that happens you must use + to get a normal answer, so its memory is not polluted.
'''

# Load Model

print(f'Loading model - {MODEL_NAME}')
model = RWKV_RNN(args)

model_tokens = []
model_state = None

AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
    dd = tokenizer.encode(i)
    assert len(dd) == 1
    AVOID_REPEAT_TOKENS += dd

########################################################################################################

def run_rnn(tokens, newline_adj=0):
    global model_tokens, model_state

    tokens = [int(x) for x in tokens]
    model_tokens += tokens
    out, model_state = model.forward(tokens, model_state)

    # print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')
    # out[0] = -999999999  # disable <|endoftext|>
    # out[187] += newline_adj  # adjust \n probability
    # if newline_adj > 0:
    #     out[15] += newline_adj / 2  # '.'
    # if model_tokens[-1] in AVOID_REPEAT_TOKENS:
    #     out[model_tokens[-1]] = -999999999
    out[0] += newline_adj  # here the adjustment is applied to token 0 (end of text); the generation loop stops when it is sampled
    return out

# snapshot / restore the RNN state and token history, so the chat can be reset, regenerated, or continued
all_state = {}

def save_all_stat(srv, name, last_out):
    n = f'{name}_{srv}'
    all_state[n] = {}
    all_state[n]['out'] = last_out
    all_state[n]['rnn'] = copy.deepcopy(model_state)
    all_state[n]['token'] = copy.deepcopy(model_tokens)

def load_all_stat(srv, name):
    global model_tokens, model_state
    n = f'{name}_{srv}'
    model_state = copy.deepcopy(all_state[n]['rnn'])
    model_tokens = copy.deepcopy(all_state[n]['token'])
    return all_state[n]['out']

########################################################################################################

# Run inference
print(f'\nRun prompt...')

out = run_rnn(tokenizer.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect()
torch.cuda.empty_cache()

srv_list = ['dummy_server']
for s in srv_list:
    save_all_stat(s, 'chat', out)

print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')

def reply_msg(msg):
    print(f'{bot}{interface} {msg}\n')

def on_message(message):
    global model_tokens, model_state

    srv = 'dummy_server'

    msg = message.replace('\\n', '\n').strip()
    # if len(msg) > 1000:
    #     reply_msg('your message is too long (max 1000 tokens)')
    #     return

    x_temp = GEN_TEMP
    x_top_p = GEN_TOP_P
    if "-temp=" in msg:
        x_temp = float(msg.split("-temp=")[1].split(" ")[0])
        msg = msg.replace("-temp=" + f'{x_temp:g}', "")
        # print(f"temp: {x_temp}")
    if "-top_p=" in msg:
        x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
        msg = msg.replace("-top_p=" + f'{x_top_p:g}', "")
        # print(f"top_p: {x_top_p}")
    if x_temp <= 0.2:
        x_temp = 0.2
    if x_temp >= 5:
        x_temp = 5
    if x_top_p <= 0:
        x_top_p = 0

    if msg == '+reset':
        out = load_all_stat('', 'chat_init')
        save_all_stat(srv, 'chat', out)
        reply_msg("Chat reset.")
        return

    elif msg.lower() == '+++':
        try:
            out = load_all_stat(srv, 'chat')
            save_all_stat(srv, 'chat_pre', out)
        except:
            return

    elif msg.lower() == '+':
        try:
            out = load_all_stat(srv, 'chat_pre')
        except:
            return

    else:
        # out = load_all_stat(srv, 'chat')
        new = f"{user}{interface} {msg}\n\n{bot}{interface}"
        out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
        save_all_stat(srv, 'chat_pre', out)

    begin = len(model_tokens)
    out_last = begin
    print(f'{bot}{interface}', end='', flush=True)
    for i in range(FREE_GEN_LEN + 100):
        if i <= 0:
            newline_adj = -999999999
        elif i <= CHAT_LEN_SHORT:
            newline_adj = (i - CHAT_LEN_SHORT) / 10
        elif i <= CHAT_LEN_LONG:
            newline_adj = 0
        else:
            newline_adj = (i - CHAT_LEN_LONG) * 0.25  # MUST END THE GENERATION

        token = tokenizer.sample_logits(
            out,
            model_tokens,
            args.ctx_len,
            temperature=x_temp,
            top_p=x_top_p,
        )
        if token == 0:
            break
        out = run_rnn([token], newline_adj=newline_adj)

        xxx = tokenizer.decode(model_tokens[out_last:])
        if '\ufffd' not in xxx:  # avoid utf-8 display issues
            print(xxx, end='', flush=True)
            out_last = begin + i + 1

        send_msg = tokenizer.decode(model_tokens[begin:])
        # if '\n\n' in send_msg:
        #     send_msg = send_msg.strip()
        #     break

        # send_msg = tokenizer.decode(model_tokens[begin:]).strip()
        # if send_msg.endswith(f'{user}{interface}'):  # warning: needs to fix state too !!!
        #     send_msg = send_msg[:-len(f'{user}{interface}')].strip()
        #     break
        # if send_msg.endswith(f'{bot}{interface}'):
        #     send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
        #     break

    print("\n")

    save_all_stat(srv, 'chat', out)

print(HELP_MSG)

while True:
    msg = input(f'{user}{interface} ')
    if len(msg.strip()) > 0:
        on_message(msg)
    else:
        print('Error: please say something')
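########################################################################################################
# Usage notes (a sketch, not part of the original script; concrete names here are assumptions based on
# the settings above, e.g. that this file is saved as chat.py and that RWKV_RNN resolves args.MODEL_NAME
# ('out_sft/rwkv-20') to a matching .pth checkpoint):
#
#   python chat.py 0
#
# The optional first argument selects CUDA_VISIBLE_DEVICES. The script also expects '20B_tokenizer.json'
# and the src/ package next to it. At the "Q:" prompt you can type a question, override sampling per
# message (e.g. "What is RWKV? -temp=1.2 -top_p=0.5"), or use the +reset / + / +++ commands from HELP_MSG.
########################################################################################################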