diff --git a/README.md b/README.md
index 5456dcc58359a9138b0fa6b65bcf5c63688095ce..5984db4109b12676c06808a2e4d82bd383c7d48a 100644
--- a/README.md
+++ b/README.md
@@ -49,9 +49,9 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \
 Use an instruction dataset for supervised training to fine-tune the language model. The instruction dataset consists of sentence pairs; this data needs to be written by the developers, and some of the corpus needs to include the reasoning process.
 
 ```
-python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
+python train_sft.py --load_model "rwkv-500.pth" --wandb "" --proj_dir "out_sft" \
 --data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \
---ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
+--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
 --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
diff --git a/src/dataset.py b/src/dataset.py
index c492408de35f22a4c8f831461fb84e71fd29a130..be66b9f945ce05a636e5af2025fb50d077735b26 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -233,12 +233,17 @@ class S2SDataset(Dataset):
         ]  # [vocab, vocab] for Pile model
 
         self.tokenizer = TOKENIZER(WORD_NAME)
+
+        self.prompt_q = self.tokenizer.tokenizer.encode("Q:")
+        self.prompt_a = self.tokenizer.tokenizer.encode("A:")
+        self.prompt_sep = self.tokenizer.tokenizer.encode("\n\n")
+
         pf = pd.read_csv(args.data_file)
         data_list = []
         for index, row in pf.iterrows():
             question = row["question"]
             answer = row["answer"]
-            data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode("\n"),
+            data_list.append((self.tokenizer.tokenizer.encode(question),
                               self.tokenizer.tokenizer.encode(answer)))
         self.data = data_list
 
@@ -248,15 +253,18 @@ class S2SDataset(Dataset):
     def __getitem__(self, index):
         ctx_len = self.args.ctx_len
         req_len = ctx_len + 1
-        question, sep, answer = self.data[index]
-        text = question + sep + answer
+        question, answer = self.data[index]
+        question = self.prompt_q + question
+        answer = self.prompt_a + answer
+        text = question + self.prompt_sep + answer
         text = text[:req_len]
         text = text + [0] * (req_len - len(text))
 
         x = torch.tensor(text[:-1], dtype=torch.long)
         y = torch.tensor(text[1:], dtype=torch.long)
 
-        z = [0] * len(question) + [1] * (ctx_len - len(question))
+        q_len = len(question) + len(self.prompt_a) + len(self.prompt_sep)
+        z = [0] * q_len + [1] * (ctx_len - q_len)
         z = torch.tensor(z, dtype=torch.long)
 
         return x, y, z
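
Note on how the (x, y, z) samples produced by the patched S2SDataset.__getitem__ are meant to be consumed: z is 0 over the tokenized "Q: <question>\n\nA:" prefix and 1 over the remaining ctx_len positions (answer tokens plus any trailing zero padding), so the training loss can be restricted to the answer span. The loss code in train_sft.py is not part of this patch; the sketch below is a minimal illustration under that assumption, and the function name masked_lm_loss and its signature are hypothetical, not the project's actual implementation.

import torch
import torch.nn.functional as F

def masked_lm_loss(logits: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # logits: [B, T, vocab] model outputs; y: [B, T] next-token targets;
    # z: [B, T] mask from S2SDataset (0 over the question prefix, 1 elsewhere).
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # flatten batch and time dims
        y.reshape(-1),
        reduction="none",
    )
    mask = z.reshape(-1).float()
    # Average the per-token loss only over positions the mask keeps, so the
    # model is not trained to reproduce the "Q: ... A:" prompt itself.
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)

Masking the prompt this way is the usual choice for Q/A-style SFT: the question still conditions the prediction through the forward pass, but the gradient comes only from the unmasked positions. Note that with z as defined in the patch, trailing padding positions (token id 0) also carry weight 1, so a loss like the sketch above would additionally need a padding mask to ignore them entirely.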