From a532f71e397ebecd9e3afe8a2e1fd613e26d9fbb Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 20 Mar 2023 14:03:16 +0800
Subject: [PATCH] opt ppo model

---
 README.md       |  2 +-
 src/rlhf/ppo.py | 31 +++++--------------------------
 train_ppo.py    | 29 ++++++++++++++++++++++++++++-
 3 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/README.md b/README.md
index 94f50da..5456dcc 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
 ### Reward Model
 
 ```
-python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
+python train_rm.py --load_sft_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
 --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
 --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
 --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 41137e0..4e0c7d5 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -20,6 +20,7 @@ from torch.optim import Adam
 from torch.utils.data import Dataset, DataLoader
 from torch.nn.utils.rnn import pad_sequence
 
+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info
 from pytorch_lightning.strategies import DeepSpeedStrategy
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
@@ -253,33 +254,13 @@ def clipped_value_loss(values, rewards, old_values, clip):
 class RLHF(pl.LightningModule):
     def __init__(
         self,
-        args
+        args,
+        rwkv: RWKV,
+        reward_model: RewardModel
     ):
         super().__init__()
 
         self.args = args
-
-        # Load the RWKV model
-        rwkv = RWKV(args)
-
-        if len(args.load_sft_model) == 0:
-            rank_zero_info(f"SFT must load model, please input ")
-            exit(1)
-
-        rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
-        try:
-            load_dict = torch.load(args.load_sft_model, map_location="cpu")
-        except:
-            rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
-            exit(1)
-
-        if args.load_partial == 1:
-            load_keys = load_dict.keys()
-            for k in rwkv.state_dict():
-                if k not in load_keys:
-                    load_dict[k] = rwkv.state_dict()[k]
-        rwkv.load_state_dict(load_dict)
-
         self.rwkv = rwkv
 
         # Initialize actor_critic with RWKV
@@ -291,9 +272,7 @@ class RLHF(pl.LightningModule):
 
         self.actor_critic = actor_critic
 
-        # Load reward_model and put it in evaluation mode
-        reward_model = RewardModel(args)
-        reward_model.load(args.load_rm_model)
+        # Put reward_model in evaluation mode
         self.reward_model = reward_model.eval()
 
     def save(self, filepath = './checkpoint.pt'):
diff --git a/train_ppo.py b/train_ppo.py
index 951f964..fd8968b 100644
--- a/train_ppo.py
+++ b/train_ppo.py
@@ -252,6 +252,8 @@ if __name__ == "__main__":
     from src.dataset import PPODataset, load_prompt_data_4_ppo
     from src.rlhf.ppo import RLHF
     from src.trainer import rlhf_train_callback
+    from src.model import RWKV
+    from src.rlhf.reward import RewardModel
 
     # Data for PPO training, gathered by interacting with the environment
     memory = []
@@ -259,8 +261,33 @@ if __name__ == "__main__":
     # Read in the training dataset
     prompts = load_prompt_data_4_ppo(args)
 
+    # Load the RWKV model
+    rwkv = RWKV(args)
+
+    if len(args.load_sft_model) == 0:
+        rank_zero_info(f"SFT must load model, please input ")
+        exit(1)
+
+    rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
+    try:
+        load_dict = torch.load(args.load_sft_model, map_location="cpu")
+    except:
+        rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
+        exit(1)
+
+    if args.load_partial == 1:
+        load_keys = load_dict.keys()
+        for k in rwkv.state_dict():
+            if k not in load_keys:
+                load_dict[k] = rwkv.state_dict()[k]
+    rwkv.load_state_dict(load_dict)
+
+    # Load reward_model
+    reward_model = RewardModel(args)
+    reward_model.load(args.load_rm_model)
+
     # PPO model
-    rlhf_model = RLHF(args)
+    rlhf_model = RLHF(args, rwkv, reward_model)
 
     # Model training
     # trainer
--
GitLab