From 63b543cfa6ce082711303229eddf1f22fb8a5308 Mon Sep 17 00:00:00 2001
From: chenlong
Date: Fri, 3 Mar 2023 16:53:22 +0800
Subject: [PATCH] fix bug

---
 README.md      | 20 +++++++++++++++++++-
 src/dataset.py |  9 ++++++---
 src/trainer.py |  2 +-
 train_sft.py   |  5 +++--
 4 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index c1db3a5..6ed1d0d 100644
--- a/README.md
+++ b/README.md
@@ -43,10 +43,28 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1
 ```
 
+
+## Instruction Tuning
+
+Run supervised training on an instruction dataset to fine-tune the language model. The instruction dataset is formatted as sentence pairs; this data has to be written by developers, and some of the corpus needs to include the reasoning process.
+
+```
+python train_sft.py --load_model "rwkv-100.pth" --wandb "" --proj_dir "out_sft" \
+--data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \
+--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
+--micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
+--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
+--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
+--my_qa_mask 1
+```
+
+
 ## TODO
 
-### Instruction Tuning
+### Reward Model
+
 
 ### RLHF (Reinforcement Learning with Human Feedback)
+
diff --git a/src/dataset.py b/src/dataset.py
index aa8a65d..8057d6e 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -243,14 +243,17 @@ class S2SDataset(Dataset):
 
     def __getitem__(self, index):
         ctx_len = self.args.ctx_len
-        req_len = ctx_len + 1
         question, answer = self.data[index]
         text = question + answer
-        text = text[:req_len]
+        text = text[:ctx_len]
+
+        text = text + [0] * (ctx_len - len(text))
         x = torch.tensor(text, dtype=torch.long)
+
+        answer = answer + [0] * (ctx_len - len(answer))
         y = torch.tensor(answer, dtype=torch.long)
-        z = [1] * len(question) + [0] * (req_len - len(question))
+        z = [1] * len(question) + [0] * (ctx_len - len(question))
         z = torch.tensor(z, dtype=torch.long)
         return x, y, z
diff --git a/src/trainer.py b/src/trainer.py
index 89407f0..e514d65 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -117,7 +117,7 @@ class train_callback(pl.Callback):
     def on_train_epoch_start(self, trainer, pl_module):
         args = self.args
         dataset = trainer.train_dataloader.dataset.datasets
-        assert "MyDataset" in str(dataset)
+        assert "MyDataset" in str(dataset) or "S2SDataset" in str(dataset)
         dataset.global_rank = trainer.global_rank
         dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
         dataset.world_size = trainer.world_size
diff --git a/train_sft.py b/train_sft.py
index fa06b1d..51275db 100644
--- a/train_sft.py
+++ b/train_sft.py
@@ -135,6 +135,7 @@ if __name__ == "__main__":
     warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
     warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
     # os.environ["WDS_SHOW_SEED"] = "1"
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
     args.enable_checkpointing = False
@@ -264,7 +265,7 @@ if __name__ == "__main__":
 
         trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
         trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
 
-    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
-    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
+    # must set shuffle=True, persistent_workers=False (because worker is in another thread)
+    data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
     trainer.fit(model, data_loader)
-- 
GitLab
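For illustration, the sample construction that the patched `S2SDataset.__getitem__` performs can be traced in isolation. The sketch below mirrors that hunk with toy token IDs and a small `ctx_len`; those values are assumptions for demonstration only, not part of the patch.

```python
import torch

# Toy stand-ins for one (question, answer) sample; real samples come from the
# tokenized data behind data/prompts.csv. Values are illustrative assumptions.
ctx_len = 8
question = [11, 12, 13]  # assumed token IDs
answer = [21, 22]        # assumed token IDs

# Mirror of the patched __getitem__: concatenate, truncate to ctx_len, pad with 0s.
text = (question + answer)[:ctx_len]
text = text + [0] * (ctx_len - len(text))
x = torch.tensor(text, dtype=torch.long)

# The answer is padded to ctx_len as well and returned as the target tensor.
answer_padded = answer + [0] * (ctx_len - len(answer))
y = torch.tensor(answer_padded, dtype=torch.long)

# z marks question positions with 1 and everything else with 0.
z = [1] * len(question) + [0] * (ctx_len - len(question))
z = torch.tensor(z, dtype=torch.long)

print(x.tolist())  # [11, 12, 13, 21, 22, 0, 0, 0]
print(y.tolist())  # [21, 22, 0, 0, 0, 0, 0, 0]
print(z.tolist())  # [1, 1, 1, 0, 0, 0, 0, 0]
```

With `--my_qa_mask 1`, `z` lets the training step distinguish question tokens from the rest of the sequence; how the mask is consumed is defined elsewhere in the training code, not in this patch.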