From 02e8a2d9e71967215a74b74aa9cc827617cfd69e Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 20 Mar 2023 16:22:17 +0800
Subject: [PATCH] opt ppo model

---
 src/model.py    |  6 ++++++
 src/rlhf/ppo.py | 20 +++++++++-----------
 train_ppo.py    | 29 +++++++++--------------------
 3 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/src/model.py b/src/model.py
index 58aee0d..cae2594 100644
--- a/src/model.py
+++ b/src/model.py
@@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from pathlib import Path
 from tqdm import tqdm
 from einops import pack
 
@@ -381,6 +382,11 @@ class RWKV(pl.LightningModule):
         self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
         self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
+    def load(self, path):
+        path = Path(path)
+        assert path.exists()
+        self.load_state_dict(torch.load(str(path), map_location="cpu"))
+
     def configure_optimizers(self):
         args = self.args
         if args.layerwise_lr > 0:
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index de40a8f..09e343d 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -45,19 +45,16 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
 class ActorCritic(nn.Module):
     def __init__(
         self,
-        rwkv: RWKV,
         args,
-        critic: Optional[RWKV] = None,
+        actor: RWKV,
+        critic: RWKV,
         pooled_values = False
     ):
         super().__init__()
-        self.actor = copy.deepcopy(rwkv)
+        self.actor = actor
 
         self.critic = critic
 
-        if not exists(self.critic):
-            self.critic = copy.deepcopy(rwkv)
-
         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
             nn.Linear(args.n_embd, 1),
@@ -242,20 +239,21 @@ class RLHF(pl.LightningModule):
     def __init__(
         self,
         args,
-        rwkv: RWKV,
+        actor: RWKV,
+        critic: RWKV,
         reward_model: RewardModel
     ):
         super().__init__()
 
         self.args = args
-        self.rwkv = rwkv
 
         # initialize actor_critic from RWKV
         actor_critic = ActorCritic(
-            rwkv = self.rwkv,
-            args = self.args,
+            args=self.args,
+            actor=actor,
+            critic=critic,
             pooled_values = args.critic_pooled_values
-        ).to(self.rwkv.device)
+        ).to(actor.device)
 
         self.actor_critic = actor_critic
 
diff --git a/train_ppo.py b/train_ppo.py
index 028cc03..e8c8bae 100644
--- a/train_ppo.py
+++ b/train_ppo.py
@@ -261,33 +261,22 @@ if __name__ == "__main__":
     # load the training prompt dataset
     prompts = load_prompt_data_4_ppo(args)
 
-    # load the RWKV model
-    rwkv = RWKV(args)
-
-    if len(args.load_sft_model) == 0:
-        rank_zero_info(f"SFT must load model, please input ")
-        exit(1)
+    # initialize the actor model from RWKV
+    actor = RWKV(args)
+    actor.load(args.load_sft_model)
 
-    rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
-    try:
-        load_dict = torch.load(args.load_sft_model, map_location="cpu")
-    except:
-        rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
-        exit(1)
-
-    if args.load_partial == 1:
-        load_keys = load_dict.keys()
-        for k in rwkv.state_dict():
-            if k not in load_keys:
-                load_dict[k] = rwkv.state_dict()[k]
-        rwkv.load_state_dict(load_dict)
+    # initialize the critic model from RWKV
+    critic = RWKV(args)
+    critic.load(args.load_sft_model)
 
     # load the reward_model
+    rwkv = RWKV(args)
+    rwkv.load(args.load_sft_model)
     reward_model = RewardModel(args, rwkv)
     reward_model.load(args.load_rm_model)
 
     # PPO model
-    rlhf_model = RLHF(args, rwkv, reward_model)
+    rlhf_model = RLHF(args, actor, critic, reward_model)
 
     # model training
     # trainer
-- 
GitLab
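
Note (illustration only, not part of the patch): after this change, train_ppo.py builds three independent RWKV instances from the same SFT checkpoint and passes the actor and critic to RLHF explicitly, instead of letting ActorCritic deepcopy a single shared model. A minimal sketch of the resulting wiring, using only names visible in this diff; `args` is assumed to come from the script's existing argument parsing, and the RewardModel import path is not shown in the patch:

    actor = RWKV(args)                      # policy that PPO optimizes
    actor.load(args.load_sft_model)

    critic = RWKV(args)                     # value model, no longer a deepcopy of the actor
    critic.load(args.load_sft_model)

    rwkv = RWKV(args)                       # backbone for the reward model
    rwkv.load(args.load_sft_model)
    reward_model = RewardModel(args, rwkv)
    reward_model.load(args.load_rm_model)   # checkpoint produced by reward-model training

    rlhf_model = RLHF(args, actor, critic, reward_model)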