########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import copy
import gc
import os
import sys
import types

import numpy as np

# The optional first command-line argument is copied into CUDA_VISIBLE_DEVICES.
# It is set here, before `import torch` further down in this block.
if len(sys.argv) > 1:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]

np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()

print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV')

########################################################################################################

args.RUN_DEVICE = "cuda"  # cuda // cpu
# fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
args.FLOAT_MODE = "fp16"

os.environ["RWKV_JIT_ON"] = '1'  # '1' or '0', please use torch 1.13+ and benchmark speed

args.MODEL_NAME = 'out_sft/rwkv-440'
args.ctx_len = 1024

# Reply-length knobs for the chat generation loop.
CHAT_LEN_SHORT = 10
CHAT_LEN_LONG = 200
FREE_GEN_LEN = 200

# Default sampling settings (a message may override them with -temp= / -top_p=).
GEN_TEMP = 1.0
GEN_TOP_P = 0.85

########################################################################################################

os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
print(f'\nLoading ChatRWKV - {args.RUN_DEVICE} - {args.FLOAT_MODE}')
import torch

# please tune these (test True/False for all of them). can significantly improve speed.
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(True)
# torch._C._jit_override_can_fuse_on_gpu(True)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(False)

# cuDNN autotuning plus TF32 for cuDNN and CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

from src.model_run import RWKV_RNN
from src.utils import TOKENIZER

tokenizer = TOKENIZER("20B_tokenizer.json")

# Remaining hyper-parameters handed to RWKV_RNN through `args`.
vars(args).update(
    vocab_size=50277,
    head_qk=0,
    pre_ffn=0,
    grad_cp=0,
    my_pos_emb=0,
)
MODEL_NAME = args.MODEL_NAME

# Speaker headers used to build each chat turn.
interface = ":"
user = "### Instruction"
bot = "### Response"

init_prompt = f'''
Below is an instruction that describes a task. Write a response that appropriately completes the request.

'''

# User-facing help text (in Chinese). It says:
#   Commands: type text directly to chat with the bot (asking it questions is
#   recommended), use \n for a newline; "+" makes the bot give a different
#   answer; "+reset" resets the conversation; "+++" continues the answer.
#   You can chat now (it understands English much better than Chinese).
#   Use +reset often to reset the bot's memory. There is currently no
#   repetition penalty, so the bot sometimes repeats itself; use "+" to get a
#   normal answer so the repetition does not pollute its memory.
HELP_MSG = '''指令:
直接输入内容 --> 和机器人聊天(建议问机器人问题),用\\n代表换行
+ --> 让机器人换个回答
+reset --> 重置对话
+++ --> 继续回答

现在可以输入内容和机器人聊天(注意它不大懂中文,它更懂英文)。请经常使用 +reset 重置机器人记忆。
目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。
'''

# Load Model
print(f'Loading model - {MODEL_NAME}')
model = RWKV_RNN(args)

# Live conversation state: every token fed to the model so far, and the
# recurrent state returned by the last model.forward call.
model_tokens = []
model_state = None

########################################################################################################

def run_rnn(tokens, newline_adj=0):
    """Feed `tokens` through the model and advance the global state.

    The tokens (coerced to int) are appended to `model_tokens`, and
    `model_state` is replaced by the state `model.forward` returns.
    `newline_adj` is added to index 0 of the returned output before it is
    handed back. Despite its name, that index is the token the chat loop
    treats as "stop", so this argument acts as a bias on ending the reply.
    """
    global model_tokens, model_state
    token_ids = list(map(int, tokens))
    model_tokens += token_ids
    logits, model_state = model.forward(token_ids, model_state)
    logits[0] += newline_adj
    return logits


all_state = {}  # snapshot key '<name>_<srv>' -> {'out', 'rnn', 'token'}


def save_all_stat(srv, name, last_out):
    """Snapshot the current conversation under the key '<name>_<srv>'.

    `last_out` is stored as-is. `model_state` and `model_tokens` are deep
    copied, so later in-place growth of the live state cannot alter the
    snapshot.
    """
    all_state[f'{name}_{srv}'] = dict(
        out=last_out,
        rnn=copy.deepcopy(model_state),
        token=copy.deepcopy(model_tokens),
    )


def load_all_stat(srv, name):
    """Restore snapshot '<name>_<srv>' into the global conversation state.

    Deep copies are restored so the snapshot stays reusable. Returns the saved
    output. Raises KeyError when no such snapshot exists.
    """
    global model_tokens, model_state
    snapshot = all_state[f'{name}_{srv}']
    model_state = copy.deepcopy(snapshot['rnn'])
    model_tokens = copy.deepcopy(snapshot['token'])
    return snapshot['out']

########################################################################################################
# Run inference
print(f'\nRun prompt...')

out = run_rnn(tokenizer.encode(init_prompt))
save_all_stat('', 'chat_init', out)

gc.collect()
torch.cuda.empty_cache()

srv_list = ['dummy_server']
for s in srv_list:
    save_all_stat(s, 'chat', out)

print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')

# Generation stops as soon as this token id is sampled. run_rnn adds its
# `newline_adj` argument to logit index 0, i.e. to this same token, so that
# argument is a bias on ending the reply. This value must match that index.
# (Presumably <|endoftext|> in the 20B tokenizer -- TODO confirm.)
_STOP_TOKEN = 0
# Bias so negative that the stop token is effectively never sampled.
_FORBID_STOP = -999999999


def reply_msg(msg):
    """Print `msg` as a bot reply."""
    print(f'{bot}{interface} {msg}\n')


def _extract_option(msg, flag, default):
    """Remove the first "<flag><value>" option (e.g. "-temp=0.7") from `msg`.

    The value is the run of non-whitespace characters right after `flag`.
    Returns (msg_without_option, value), or (msg, default) unchanged when
    `flag` does not occur. Raises ValueError if the value is not a finite float.
    """
    import math

    head, found, tail = msg.partition(flag)
    if not found:
        return msg, default
    raw = tail.split(None, 1)[0] if tail[:1] and not tail[0].isspace() else ''
    try:
        value = float(raw)
    except ValueError:
        value = math.nan
    # Reject junk, NaN and inf: NaN would slip past the clamps in on_message.
    if not math.isfinite(value):
        raise ValueError(f'{flag} expects a finite number, got {raw!r}')
    # Cut out exactly the characters the user typed. Re-formatting the parsed
    # float (f'{value:g}') would miss inputs such as "1.50" or ".5".
    return head + tail[len(raw):], value


def _stop_bias(step):
    """Return the stop-token bias to apply after feeding the token sampled at `step`.

    Step 0 forbids stopping. Steps 1..CHAT_LEN_SHORT get a negative bias that
    rises to 0. Steps up to CHAT_LEN_LONG are neutral. After that the bias
    grows by 0.25 per step, pushing the reply to end.
    """
    if step <= 0:
        return _FORBID_STOP
    if step <= CHAT_LEN_SHORT:
        return (step - CHAT_LEN_SHORT) / 10
    if step <= CHAT_LEN_LONG:
        return 0
    return (step - CHAT_LEN_LONG) * 0.25  # MUST END THE GENERATION


def on_message(message):
    """Handle one line of user input.

    A literal `\\n` typed by the user becomes a real newline. Optional
    "-temp=<x>" and "-top_p=<x>" options override GEN_TEMP / GEN_TOP_P for this
    message; temperature is clamped to [0.2, 5] and top_p to >= 0. Commands:
      +reset  restore the state right after the initial prompt
      +++     continue the previous reply
      +       regenerate the previous reply
    Anything else is sent to the model as a new instruction. The reply is
    streamed to stdout, and the final state is saved as 'chat'.
    """
    srv = 'dummy_server'
    msg = message.replace('\\n', '\n').strip()

    try:
        msg, x_temp = _extract_option(msg, '-temp=', GEN_TEMP)
        msg, x_top_p = _extract_option(msg, '-top_p=', GEN_TOP_P)
    except ValueError as err:
        reply_msg(f'Invalid option: {err}')
        return
    msg = msg.strip()  # so that e.g. "+ -temp=0.5" is still recognised as "+"
    x_temp = min(max(x_temp, 0.2), 5)
    x_top_p = max(x_top_p, 0)

    if msg == '+reset':
        out = load_all_stat('', 'chat_init')
        save_all_stat(srv, 'chat', out)
        reply_msg("Chat reset.")
        return

    if msg == '+++':
        # Continue: resume from the state saved at the end of the last reply.
        try:
            out = load_all_stat(srv, 'chat')
        except KeyError:
            reply_msg('Nothing to continue yet.')
            return
        save_all_stat(srv, 'chat_pre', out)
    elif msg == '+':
        # Regenerate: rewind to the state saved just before the last reply.
        try:
            out = load_all_stat(srv, 'chat_pre')
        except KeyError:
            reply_msg('Nothing to regenerate yet.')
            return
    else:
        new = f"{user}{interface} {msg}\n\n{bot}{interface}"
        out = run_rnn(tokenizer.encode(new), newline_adj=_FORBID_STOP)
        save_all_stat(srv, 'chat_pre', out)

    begin = len(model_tokens)
    out_last = begin
    print(f'{bot}{interface}', end='', flush=True)
    for i in range(FREE_GEN_LEN + 100):
        token = tokenizer.sample_logits(
            out,
            model_tokens,
            args.ctx_len,
            temperature=x_temp,
            top_p=x_top_p,
        )
        if token == _STOP_TOKEN:
            break
        out = run_rnn([token], newline_adj=_stop_bias(i))

        # Hold output back while the decoded text contains U+FFFD. This avoids
        # printing broken UTF-8 when one character spans several tokens.
        pending = tokenizer.decode(model_tokens[out_last:])
        if '\ufffd' not in pending:
            print(pending, end='', flush=True)
            out_last = len(model_tokens)

    print("\n")
    save_all_stat(srv, 'chat', out)


print(HELP_MSG)
while True:
    try:
        msg = input(f'{user}{interface} ')
    except (EOFError, KeyboardInterrupt):
        # End of input (Ctrl-D) or Ctrl-C: leave the chat loop cleanly.
        print()
        break
    if msg.strip():
        on_message(msg)
    else:
        print('Error: please say something')