From a0f6ec614fb9bc88b592e7ce92143dd0cc696922 Mon Sep 17 00:00:00 2001 From: chenlong Date: Thu, 20 Apr 2023 10:37:13 +0800 Subject: [PATCH] opt code --- chat.py | 55 +++++++------------------------------------------- clean_data.py | 2 +- src/dataset.py | 3 ++- 3 files changed, 10 insertions(+), 50 deletions(-) diff --git a/chat.py b/chat.py index 9967248..6ecd181 100644 --- a/chat.py +++ b/chat.py @@ -22,48 +22,22 @@ args.FLOAT_MODE = "fp16" os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed -CHAT_LANG = 'Chinese' # English // Chinese // more to come - -QA_PROMPT = True # True: Q & A prompt // False: User & Bot prompt -# 中文问答设置QA_PROMPT=True(只能问答,问答效果更好,但不能闲聊) 中文聊天设置QA_PROMPT=False(可以闲聊,但需要大模型才适合闲聊) - -# Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates) - -if CHAT_LANG == 'English': - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019' - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096' - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040' - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066' - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023' - # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340' - # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210' - args.MODEL_NAME = 'out_sft/rwkv-20' - -elif CHAT_LANG == 'Chinese': - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-EngChn-test4-20230116' - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115' - # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115' - # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-80' - # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/3-run1z/rwkv-170' - # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-0' - args.MODEL_NAME = 'out_sft/rwkv-20' +args.MODEL_NAME = 'out_sft/rwkv-440' args.ctx_len = 1024 -CHAT_LEN_SHORT = 40 -CHAT_LEN_LONG = 150 +CHAT_LEN_SHORT = 10 +CHAT_LEN_LONG = 200 FREE_GEN_LEN = 200 GEN_TEMP = 1.0 GEN_TOP_P = 0.85 -AVOID_REPEAT = ',。:?!' ######################################################################################################## os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE -print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}') +print(f'\nLoading ChatRWKV - {args.RUN_DEVICE} - {args.FLOAT_MODE}') import torch # please tune these (test True/False for all of them). can significantly improve speed. @@ -90,13 +64,12 @@ args.my_pos_emb = 0 MODEL_NAME = args.MODEL_NAME interface = ":" -user = "Q" -bot = "A" +user = "### Instruction" +bot = "### Response" init_prompt = f''' -Expert Questions & Helpful Answers +Below is an instruction that describes a task. Write a response that appropriately completes the request. -Ask Research Experts ''' @@ -107,7 +80,6 @@ HELP_MSG = '''指令: + --> 让机器人换个回答 +reset --> 重置对话 +++ --> 继续回答 -++ --> 换个回答 现在可以输入内容和机器人聊天(注意它不大懂中文,它更懂英文)。请经常使用 +reset 重置机器人记忆。 目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。 @@ -122,11 +94,6 @@ model = RWKV_RNN(args) model_tokens = [] model_state = None -AVOID_REPEAT_TOKENS = [] -for i in AVOID_REPEAT: - dd = tokenizer.encode(i) - assert len(dd) == 1 - AVOID_REPEAT_TOKENS += dd ######################################################################################################## @@ -138,14 +105,6 @@ def run_rnn(tokens, newline_adj=0): model_tokens += tokens out, model_state = model.forward(tokens, model_state) - # print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]') - - # out[0] = -999999999 # disable <|endoftext|> - # out[187] += newline_adj # adjust \n probability - # if newline_adj > 0: - # out[15] += newline_adj / 2 # '.' - # if model_tokens[-1] in AVOID_REPEAT_TOKENS: - # out[model_tokens[-1]] = -999999999 out[0] += newline_adj return out diff --git a/clean_data.py b/clean_data.py index 143b682..9380787 100644 --- a/clean_data.py +++ b/clean_data.py @@ -31,7 +31,7 @@ def clean_text(text): def clean_blog_data(): # 数据从odps读取 raw_data_path = "./data/raw_data.txt" - data_path = "./data/data.json" + data_path = "./data/data.jsonl" with open(raw_data_path) as file_r: with open(data_path, "w") as file_w: index = 1 diff --git a/src/dataset.py b/src/dataset.py index 993696e..0b82389 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -239,7 +239,8 @@ class S2SDataset(Dataset): for index, row in pf.iterrows(): input = row["input"] target = row["target"] - input_instruction = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction: {input}\n\n### Response:" + input_instruction = f"Below is an instruction that describes a task. Write a response that appropriately " \ + f"completes the request.\n\n### Instruction: {input}\n\n### Response:" input_tokens = self.tokenizer.tokenizer.encode(input_instruction) target_tokens = self.tokenizer.tokenizer.encode(target) if len(input_tokens) + len(target_tokens) > self.args.ctx_len: -- GitLab