提交 a0f6ec61 作者: CSDN-Ada助手

Optimize code

上级 723ff564
......@@ -22,48 +22,22 @@ args.FLOAT_MODE = "fp16"
os.environ["RWKV_JIT_ON"] = '1' # '1' or '0', please use torch 1.13+ and benchmark speed
CHAT_LANG = 'Chinese' # English // Chinese // more to come
QA_PROMPT = True # True: Q & A prompt // False: User & Bot prompt
# 中文问答设置QA_PROMPT=True(只能问答,问答效果更好,但不能闲聊) 中文聊天设置QA_PROMPT=False(可以闲聊,但需要大模型才适合闲聊)
# Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates)
if CHAT_LANG == 'English':
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210'
args.MODEL_NAME = 'out_sft/rwkv-20'
elif CHAT_LANG == 'Chinese':
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-EngChn-test4-20230116'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115'
# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-80'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/3-run1z/rwkv-170'
# args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-0'
args.MODEL_NAME = 'out_sft/rwkv-20'
args.MODEL_NAME = 'out_sft/rwkv-440'
args.ctx_len = 1024
CHAT_LEN_SHORT = 40
CHAT_LEN_LONG = 150
CHAT_LEN_SHORT = 10
CHAT_LEN_LONG = 200
FREE_GEN_LEN = 200
GEN_TEMP = 1.0
GEN_TOP_P = 0.85
AVOID_REPEAT = ',。:?!'
########################################################################################################
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
print(f'\nLoading ChatRWKV - {args.RUN_DEVICE} - {args.FLOAT_MODE}')
import torch
# please tune these (test True/False for all of them). can significantly improve speed.
......@@ -90,13 +64,12 @@ args.my_pos_emb = 0
MODEL_NAME = args.MODEL_NAME
interface = ":"
user = "Q"
bot = "A"
user = "### Instruction"
bot = "### Response"
init_prompt = f'''
Expert Questions & Helpful Answers
Below is an instruction that describes a task. Write a response that appropriately completes the request.
Ask Research Experts
'''
......@@ -107,7 +80,6 @@ HELP_MSG = '''指令:
+ --> 让机器人换个回答
+reset --> 重置对话
+++ --> 继续回答
++ --> 换个回答
现在可以输入内容和机器人聊天(注意它不大懂中文,它更懂英文)。请经常使用 +reset 重置机器人记忆。
目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。
......@@ -122,11 +94,6 @@ model = RWKV_RNN(args)
model_tokens = []
model_state = None
AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
dd = tokenizer.encode(i)
assert len(dd) == 1
AVOID_REPEAT_TOKENS += dd
########################################################################################################
......@@ -138,14 +105,6 @@ def run_rnn(tokens, newline_adj=0):
model_tokens += tokens
out, model_state = model.forward(tokens, model_state)
# print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')
# out[0] = -999999999 # disable <|endoftext|>
# out[187] += newline_adj # adjust \n probability
# if newline_adj > 0:
# out[15] += newline_adj / 2 # '.'
# if model_tokens[-1] in AVOID_REPEAT_TOKENS:
# out[model_tokens[-1]] = -999999999
out[0] += newline_adj
return out
......
......@@ -31,7 +31,7 @@ def clean_text(text):
def clean_blog_data():
# 数据从odps读取
raw_data_path = "./data/raw_data.txt"
data_path = "./data/data.json"
data_path = "./data/data.jsonl"
with open(raw_data_path) as file_r:
with open(data_path, "w") as file_w:
index = 1
......
......@@ -239,7 +239,8 @@ class S2SDataset(Dataset):
for index, row in pf.iterrows():
input = row["input"]
target = row["target"]
input_instruction = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction: {input}\n\n### Response:"
input_instruction = f"Below is an instruction that describes a task. Write a response that appropriately " \
f"completes the request.\n\n### Instruction: {input}\n\n### Response:"
input_tokens = self.tokenizer.tokenizer.encode(input_instruction)
target_tokens = self.tokenizer.tokenizer.encode(target)
if len(input_tokens) + len(target_tokens) > self.args.ctx_len:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册