提交 99595b68 编写于 作者: CSDN-Ada助手's avatar CSDN-Ada助手

修改sft

上级 be4440da
......@@ -49,9 +49,9 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \
使用指令数据集进行监督训练,精调语言模型,指令数据集格式为句子对。这部分数据需要由开发人员来进行编写,有的语料需要涉及到推理过程。
```
python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
python train_sft.py --load_model "rwkv-500.pth" --wandb "" --proj_dir "out_sft" \
--data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
--micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
......
......@@ -233,12 +233,17 @@ class S2SDataset(Dataset):
] # [vocab, vocab] for Pile model
self.tokenizer = TOKENIZER(WORD_NAME)
self.prompt_q = self.tokenizer.tokenizer.encode("Q:")
self.prompt_a = self.tokenizer.tokenizer.encode("A:")
self.prompt_sep = self.tokenizer.tokenizer.encode("\n\n")
pf = pd.read_csv(args.data_file)
data_list = []
for index, row in pf.iterrows():
question = row["question"]
answer = row["answer"]
data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode("\n"),
data_list.append((self.tokenizer.tokenizer.encode(question),
self.tokenizer.tokenizer.encode(answer)))
self.data = data_list
......@@ -248,15 +253,18 @@ class S2SDataset(Dataset):
def __getitem__(self, index):
ctx_len = self.args.ctx_len
req_len = ctx_len + 1
question, sep, answer = self.data[index]
text = question + sep + answer
question, answer = self.data[index]
question = self.prompt_q + question
answer = self.prompt_a + answer
text = question + self.prompt_sep + answer
text = text[:req_len]
text = text + [0] * (req_len - len(text))
x = torch.tensor(text[:-1], dtype=torch.long)
y = torch.tensor(text[1:], dtype=torch.long)
z = [0] * len(question) + [1] * (ctx_len - len(question))
q_len = len(question) + len(self.prompt_a) + len(self.prompt_sep)
z = [0] * q_len + [1] * (ctx_len - q_len)
z = torch.tensor(z, dtype=torch.long)
return x, y, z
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册