Commit 63b543cf authored by CSDN-Ada助手

fix bug

Parent d19266ad
@@ -43,10 +43,28 @@ python train.py --load_model "rwkv-80.pth" --wandb "" --proj_dir "out" \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1
```
## Instruction Tuning
Fine-tune the language model with supervised training on an instruction dataset. The instruction dataset consists of sentence pairs (a prompt and its response); this data has to be written by developers, and some of the corpus needs to cover the reasoning process. A sketch of the expected pair format follows the command below.
```
python train_sft.py --load_model "rwkv-100.pth" --wandb "" --proj_dir "out_sft" \
--data_file "data/prompts.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
--micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
--my_qa_mask 1
```
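
For reference, here is a minimal sketch of how the sentence pairs might be prepared for the dataset used here. The `question`/`answer` column names and the choice of Hugging Face's `PreTrainedTokenizerFast` with the GPT-NeoX `20B_tokenizer.json` (which matches `--vocab_size 50277`) are assumptions for illustration, not something this commit specifies:
```
# Illustrative only: turn data/prompts.csv into (question, answer)
# token-id pairs like the ones S2SDataset indexes in the diff below.
import csv
from transformers import PreTrainedTokenizerFast

# Assumed tokenizer file; the GPT-NeoX vocabulary has 50277 tokens,
# matching --vocab_size above.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="20B_tokenizer.json")

pairs = []
with open("data/prompts.csv", encoding="utf-8") as f:
    for row in csv.DictReader(f):            # assumed header: question,answer
        pairs.append((tokenizer.encode(row["question"]),
                      tokenizer.encode(row["answer"])))
```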
## TODO
### Instruction Tuning
### Reward Model
### RLHF (Reinforcement Learning from Human Feedback)
@@ -243,14 +243,17 @@ class S2SDataset(Dataset):
    def __getitem__(self, index):
        ctx_len = self.args.ctx_len
        req_len = ctx_len + 1
        question, answer = self.data[index]
        text = question + answer
-        text = text[:req_len]
+        text = text[:ctx_len]
        text = text + [0] * (ctx_len - len(text))
        x = torch.tensor(text, dtype=torch.long)
        answer = answer + [0] * (ctx_len - len(answer))
        y = torch.tensor(answer, dtype=torch.long)
-        z = [1] * len(question) + [0] * (req_len - len(question))
+        z = [1] * len(question) + [0] * (ctx_len - len(question))
        z = torch.tensor(z, dtype=torch.long)
        return x, y, z
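
The `z` tensor returned above marks question positions with 1 and everything else with 0; presumably this is what the `--my_qa_mask 1` flag feeds into the loss. The loss wiring itself is not part of this diff, so the following is only a sketch of how such a mask could gate a per-token cross-entropy; training on the `1 - z` positions (everything after the question) is an assumption here:
```
# Sketch only -- not code from this commit.
import torch
import torch.nn.functional as F

def qa_masked_loss(logits, y, z):
    # logits: (B, T, V) model outputs; y: (B, T) target ids;
    # z: (B, T), 1 on question tokens, 0 elsewhere (as built above).
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), y.reshape(-1), reduction="none"
    )
    keep = (1.0 - z.float()).reshape(-1)  # assumed: train only outside the question
    return (per_token * keep).sum() / keep.sum().clamp(min=1.0)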
@@ -117,7 +117,7 @@ class train_callback(pl.Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        args = self.args
        dataset = trainer.train_dataloader.dataset.datasets
-        assert "MyDataset" in str(dataset)
+        assert "MyDataset" in str(dataset) or "S2SDataset" in str(dataset)
        dataset.global_rank = trainer.global_rank
        dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
        dataset.world_size = trainer.world_size
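
The callback above pushes `global_rank`, `real_epoch`, and `world_size` into the dataset at the start of each epoch, and the widened assert lets `S2SDataset` receive them too. As a minimal sketch (illustrative, not this repository's code) of what a dataset can do with those fields, each rank can draw a disjoint, reproducible slice of the data:
```
# Illustrative only: deterministic, rank-disjoint index selection driven
# by the fields the callback assigns above.
import torch

def rank_local_indices(dataset_len, global_rank, world_size, real_epoch):
    g = torch.Generator().manual_seed(real_epoch)    # same order on every rank
    order = torch.randperm(dataset_len, generator=g)
    return order[global_rank::world_size].tolist()   # each rank keeps its stride
```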
@@ -135,6 +135,7 @@ if __name__ == "__main__":
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
@@ -264,7 +265,7 @@ if __name__ == "__main__":
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
-    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
+    # must set shuffle=True, persistent_workers=False (because worker is in another thread)
+    data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
    trainer.fit(model, data_loader)
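
A plausible reading of the comment swap above (inferred from the diff, not stated by the author): the original pretraining dataset drew its own sample indices, so the DataLoader had to keep `shuffle=False`, whereas the sentence-pair SFT dataset returns items by index and relies on the DataLoader's `shuffle=True` for randomization; `persistent_workers=False` stays required because the worker runs in another thread.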