Commit 7bb01b99 authored by CSDN-Ada助手

SFT optimization

Parent 99595b68
@@ -51,7 +51,7 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \
 ```
 python train_sft.py --load_model "rwkv-500.pth" --wandb "" --proj_dir "out_sft" \
 --data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \
---ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
+--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
 --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
...
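The SFT command above reads its training pairs from `data/prompts.csv`. Based on the `S2SDataset` changes below, which load the file with `pd.read_csv` and use its `question` and `answer` columns, that file is assumed to be a plain two-column CSV; the helper below is a minimal sketch of producing one, and the example rows are placeholders, not data from this repository.

```
# Sketch: build a data/prompts.csv in the shape S2SDataset below expects.
# The column names "question" and "answer" come from the dataset code in this
# commit; the example rows themselves are made up for illustration.
import pandas as pd

rows = [
    {"question": "What is RWKV?",
     "answer": "An RNN architecture trained like a transformer."},
    {"question": "How do I start SFT training?",
     "answer": "Run train_sft.py with --data_file pointing at this CSV."},
]

pd.DataFrame(rows).to_csv("data/prompts.csv", index=False, encoding="utf-8")
```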
@@ -233,17 +233,14 @@ class S2SDataset(Dataset):
         ] # [vocab, vocab] for Pile model
         self.tokenizer = TOKENIZER(WORD_NAME)
-        self.prompt_q = self.tokenizer.tokenizer.encode("Q:")
-        self.prompt_a = self.tokenizer.tokenizer.encode("A:")
-        self.prompt_sep = self.tokenizer.tokenizer.encode("\n\n")
         pf = pd.read_csv(args.data_file)
         data_list = []
         for index, row in pf.iterrows():
             question = row["question"]
             answer = row["answer"]
             data_list.append((self.tokenizer.tokenizer.encode(question),
+                              self.tokenizer.tokenizer.encode("\n"),
                               self.tokenizer.tokenizer.encode(answer)))
         self.data = data_list
...
@@ -253,18 +253,15 @@ class S2SDataset(Dataset):
     def __getitem__(self, index):
         ctx_len = self.args.ctx_len
         req_len = ctx_len + 1
-        question, answer = self.data[index]
-        question = self.prompt_q + question
-        answer = self.prompt_a + answer
-        text = question + self.prompt_sep + answer
+        question, sep, answer = self.data[index]
+        text = question + sep + answer
         text = text[:req_len]
         text = text + [0] * (req_len - len(text))
         x = torch.tensor(text[:-1], dtype=torch.long)
         y = torch.tensor(text[1:], dtype=torch.long)
-        q_len = len(question) + len(self.prompt_a) + len(self.prompt_sep)
-        z = [0] * q_len + [1] * (ctx_len - q_len)
+        z = [0] * len(question) + [1] * (ctx_len - len(question))
         z = torch.tensor(z, dtype=torch.long)
         return x, y, z
...
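For clarity, here is a self-contained sketch of what the updated `__getitem__` produces. It mirrors the new code path in the diff: `__init__` now stores `(question, separator, answer)` token lists (the separator being the encoded newline), and `__getitem__` concatenates them, pads to `ctx_len + 1`, and returns the shifted input/target pair plus a mask `z` that is 0 over the question tokens, presumably so the training loss can be restricted to the separator and answer. The helper name `make_sample` and the toy token ids are made up for illustration; only the body mirrors the diff above.

```
import torch

def make_sample(question_ids, sep_ids, answer_ids, ctx_len=1024):
    """Sketch of the updated S2SDataset.__getitem__ logic.

    question_ids / sep_ids / answer_ids stand in for the encoded
    (question, separator, answer) tuple stored by __init__.
    """
    req_len = ctx_len + 1
    text = question_ids + sep_ids + answer_ids
    text = text[:req_len]
    text = text + [0] * (req_len - len(text))       # pad with token id 0
    x = torch.tensor(text[:-1], dtype=torch.long)   # model input
    y = torch.tensor(text[1:], dtype=torch.long)    # next-token targets
    # z is 0 over the question tokens and 1 elsewhere, so the training loop
    # can mask the question out of the loss and learn only on what follows it.
    z = [0] * len(question_ids) + [1] * (ctx_len - len(question_ids))
    z = torch.tensor(z, dtype=torch.long)
    return x, y, z

# Toy example with made-up token ids:
x, y, z = make_sample(question_ids=[11, 12, 13], sep_ids=[187],
                      answer_ids=[21, 22], ctx_len=8)
print(x.tolist())  # [11, 12, 13, 187, 21, 22, 0, 0]
print(y.tolist())  # [12, 13, 187, 21, 22, 0, 0, 0]
print(z.tolist())  # [0, 0, 0, 1, 1, 1, 1, 1]
```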