From 7bb01b99c46a2ba5895bfb67ece907eb6f58772e Mon Sep 17 00:00:00 2001 From: chenlong Date: Thu, 23 Mar 2023 21:38:03 +0800 Subject: [PATCH] =?UTF-8?q?sft=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- src/dataset.py | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 5984db4..96054cf 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \ ``` python train_sft.py --load_model "rwkv-500.pth" --wandb "" --proj_dir "out_sft" \ --data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \ ---ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ +--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \ diff --git a/src/dataset.py b/src/dataset.py index be66b9f..2a0e78a 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -233,17 +233,14 @@ class S2SDataset(Dataset): ] # [vocab, vocab] for Pile model self.tokenizer = TOKENIZER(WORD_NAME) - - self.prompt_q = self.tokenizer.tokenizer.encode("Q:") - self.prompt_a = self.tokenizer.tokenizer.encode("A:") - self.prompt_sep = self.tokenizer.tokenizer.encode("\n\n") - pf = pd.read_csv(args.data_file) data_list = [] + for index, row in pf.iterrows(): question = row["question"] answer = row["answer"] data_list.append((self.tokenizer.tokenizer.encode(question), + self.tokenizer.tokenizer.encode("\n"), self.tokenizer.tokenizer.encode(answer))) self.data = data_list @@ -253,18 +250,15 @@ class S2SDataset(Dataset): def __getitem__(self, index): ctx_len = self.args.ctx_len req_len = ctx_len + 1 - question, answer = self.data[index] - question = self.prompt_q + question - answer = self.prompt_a + answer - text = question + self.prompt_sep + answer + question, sep, answer = self.data[index] + text = question + sep + answer text = text[:req_len] text = text + [0] * (req_len - len(text)) x = torch.tensor(text[:-1], dtype=torch.long) y = torch.tensor(text[1:], dtype=torch.long) - q_len = len(question) + len(self.prompt_a) + len(self.prompt_sep) - z = [0] * q_len + [1] * (ctx_len - q_len) + z = [0] * len(question) + [1] * (ctx_len - len(question)) z = torch.tensor(z, dtype=torch.long) return x, y, z -- GitLab