From 99595b68a10ae913b72da10283594c1bf8bce3e2 Mon Sep 17 00:00:00 2001 From: chenlong Date: Thu, 23 Mar 2023 14:00:13 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9sft?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 ++-- src/dataset.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 5456dcc..5984db4 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,9 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \ 使用指令数据集进行监督训练,精调语言模型,指令数据集格式为句子对。这部分数据需要由开发人员来进行编写,有的语料需要涉及到推理过程。 ``` -python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \ +python train_sft.py --load_model "rwkv-500.pth" --wandb "" --proj_dir "out_sft" \ --data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \ ---ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ +--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \ diff --git a/src/dataset.py b/src/dataset.py index c492408..be66b9f 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -233,12 +233,17 @@ class S2SDataset(Dataset): ] # [vocab, vocab] for Pile model self.tokenizer = TOKENIZER(WORD_NAME) + + self.prompt_q = self.tokenizer.tokenizer.encode("Q:") + self.prompt_a = self.tokenizer.tokenizer.encode("A:") + self.prompt_sep = self.tokenizer.tokenizer.encode("\n\n") + pf = pd.read_csv(args.data_file) data_list = [] for index, row in pf.iterrows(): question = row["question"] answer = row["answer"] - data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode("\n"), + data_list.append((self.tokenizer.tokenizer.encode(question), self.tokenizer.tokenizer.encode(answer))) self.data = data_list @@ -248,15 +253,18 @@ class S2SDataset(Dataset): def __getitem__(self, index): ctx_len = self.args.ctx_len req_len = ctx_len + 1 - question, sep, answer = self.data[index] - text = question + sep + answer + question, answer = self.data[index] + question = self.prompt_q + question + answer = self.prompt_a + answer + text = question + self.prompt_sep + answer text = text[:req_len] text = text + [0] * (req_len - len(text)) x = torch.tensor(text[:-1], dtype=torch.long) y = torch.tensor(text[1:], dtype=torch.long) - z = [0] * len(question) + [1] * (ctx_len - len(question)) + q_len = len(question) + len(self.prompt_a) + len(self.prompt_sep) + z = [0] * q_len + [1] * (ctx_len - q_len) z = torch.tensor(z, dtype=torch.long) return x, y, z -- GitLab