From da480d5892bf8b7bd983e99df327209ece2bff87 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Wed, 8 Mar 2023 15:50:13 +0800
Subject: [PATCH] add ppo model

---
 forward_demo.py    | 36 ++++++++++++++++++++++++++++++++++++
 src/rlhf/ppo.py    | 27 +++++++++++++--------------
 src/rlhf/reward.py |  1 -
 train_ppo.py       | 43 +++++++++++++++++++++++++++++++++++++++++++
 train_rm.py        |  8 ++++++++
 5 files changed, 100 insertions(+), 15 deletions(-)
 create mode 100644 train_ppo.py

diff --git a/forward_demo.py b/forward_demo.py
index 1608ff5..dd64212 100644
--- a/forward_demo.py
+++ b/forward_demo.py
@@ -227,6 +227,42 @@ if __name__ == "__main__":
     from src.model import RWKV
     model = RWKV(args)
 
+    if len(args.load_model) == 0:
+        rank_zero_info("SFT must load a pretrained model, please specify --load_model")
+        exit(1)
+
+    rank_zero_info(f"########## Loading {args.load_model}... ##########")
+    try:
+        load_dict = torch.load(args.load_model, map_location="cpu")
+    except:
+        rank_zero_info(f"Bad checkpoint {args.load_model}")
+        exit(1)
+
+    if args.load_partial == 1:
+        load_keys = load_dict.keys()
+        for k in model.state_dict():
+            if k not in load_keys:
+                load_dict[k] = model.state_dict()[k]
+    model.load_state_dict(load_dict)
+
+    trainer = Trainer.from_argparse_args(
+        args,
+        callbacks=[train_callback(args)],
+    )
+
+    if trainer.global_rank == 0:
+        for n in model.state_dict():
+            shape = model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1]
+            if len(shape) > 1:
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+            else:
+                print(f"{str(shape[0]).ljust(5)} {n}")
+
+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+
     seq = torch.randint(0, 20000, (1, 100))
     model(seq)
 
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index e853f43..e420015 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -20,10 +20,10 @@ from torch.nn.utils.rnn import pad_sequence
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
-from palm_rlhf_pytorch.palm import PaLM
-from palm_rlhf_pytorch.reward import RewardModel
-from palm_rlhf_pytorch.optimizer import get_optimizer
-from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
+from src.model import RWKV
+from src.rlhf.reward import RewardModel
+from src.rlhf.optimizer import get_optimizer
+from src.rlhf.utils import masked_mean, eval_decorator
 
 from accelerate import Accelerator
 
@@ -42,8 +42,8 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
 class ActorCritic(nn.Module):
     def __init__(
         self,
-        palm: PaLM,
-        critic_palm: Optional[PaLM] = None,
+        rwkv: RWKV,
+        critic_palm: Optional[RWKV] = None,
         pooled_values = False,
         actor_lora = True,
         critic_lora = True,
@@ -55,12 +55,12 @@ class ActorCritic(nn.Module):
         critic_dropout = 0.
     ):
         super().__init__()
-        self.actor_palm = palm
+        self.actor_palm = rwkv
 
         self.critic_palm = critic_palm
 
         if not exists(self.critic_palm):
-            self.critic_palm = copy.deepcopy(palm)
+            self.critic_palm = copy.deepcopy(rwkv)
 
         self.actor_palm.set_dropout(actor_dropout)
         self.critic_palm.set_dropout(critic_dropout)
@@ -79,7 +79,7 @@ class ActorCritic(nn.Module):
 
         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
-            nn.Linear(palm.dim, 1),
+            nn.Linear(rwkv.dim, 1),
             Rearrange('... 1 -> ...')
         )
 
@@ -289,7 +289,7 @@ class RLHFTrainer(nn.Module):
         prompts_path: Optional[str] = None,
         prompt_token_ids: Optional[torch.Tensor] = None,
         tokenizer: Callable = None,
-        palm: PaLM,
+        rwkv: RWKV,
         reward_model: RewardModel,
         actor_critic: Optional[ActorCritic] = None,
         actor_lr = 1e-4,
@@ -339,12 +339,11 @@ class RLHFTrainer(nn.Module):
         self.register_buffer('prompt_token_ids', prompt_token_ids)
 
         # models
-
-        self.palm = palm
+        self.rwkv = rwkv
 
         if not exists(actor_critic):
             actor_critic = ActorCritic(
-                palm = palm,
+                rwkv = rwkv,
                 actor_lora = actor_lora,
                 critic_lora = critic_lora,
                 actor_lora_r = actor_lora_r,
@@ -352,7 +351,7 @@ class RLHFTrainer(nn.Module):
                 pooled_values = critic_pooled_values,
                 actor_dropout = actor_dropout,
                 critic_dropout = critic_dropout
-            ).to(palm.device)
+            ).to(rwkv.device)
 
         self.actor_critic = actor_critic
 
diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index b264e2d..2f20dc2 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -34,7 +34,6 @@ class RewardModel(nn.Module):
 
         # initialize the reward model from the pretrained model
         self.rwkv = copy.deepcopy(rwkv)
-        self.rwkv.set_dropout(dropout)  # todo(luxin)
 
         # dimension of the output token embeddings
         dim = rwkv.dim  # todo(luxin)
diff --git a/train_ppo.py b/train_ppo.py
new file mode 100644
index 0000000..025d7d9
--- /dev/null
+++ b/train_ppo.py
@@ -0,0 +1,43 @@
+'''
+@File    :   train_ppo.py
+@Time    :   2023/03/08 15:23:19
+@Author  :   Lu Xin
+@Contact :   luxin@csdn.net
+'''
+
+# imports
+import torch
+from src.model import RWKV
+from src.rlhf.reward import RewardModel
+from src.rlhf.ppo import RLHFTrainer
+
+# load your pretrained RWKV
+# todo(luxin): load the pretrained model produced by the SFT stage
+rwkv_model = RWKV()
+# rwkv_model.load('./path/to/pretrained/rwkv.pt')
+
+# load your pretrained reward model
+# todo(luxin): load the trained reward model
+reward_model = RewardModel(
+    rwkv_model,
+    num_binned_output = 5
+)
+# reward_model.load('./path/to/pretrained/reward_model.pt')
+
+# ready your list of prompts for reinforcement learning
+# todo(luxin): load the prompt dataset (these prompts should differ from the ones used in the SFT and RM stages)
+prompts = torch.randint(0, 256, (50000, 512)) # 50k prompts
+
+# pass it all to the trainer and train
+# train the PPO model
+trainer = RLHFTrainer(
+    rwkv = rwkv_model,
+    reward_model = reward_model,
+    prompt_token_ids = prompts
+)
+trainer.train(num_episodes = 100)
+
+# then, if it succeeded...
+# generate say 10 samples and use the reward model to return the best one
+answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
+print(answer)
\ No newline at end of file
diff --git a/train_rm.py b/train_rm.py
index 84caa16..b4c4c61 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -1,3 +1,11 @@
+'''
+@File    :   train_rm.py
+@Time    :   2023/03/08 15:23:29
+@Author  :   Lu Xin
+@Contact :   luxin@csdn.net
+'''
+
+# imports
 import torch
 
 from src.rlhf.reward import RewardModel
--
GitLab
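
Note: the two .load(...) calls in train_ppo.py are commented out and left as todo(luxin). Below is a minimal, hypothetical sketch of how those checkpoint-loading steps might be filled in, assuming both checkpoints are plain PyTorch state_dicts loaded with the same torch.load / load_state_dict pattern this patch adds to forward_demo.py. The file paths are placeholders, and the argument-free RWKV() call simply mirrors train_ppo.py as written; neither is a confirmed interface of the repository.

# Hypothetical sketch (assumptions noted above), not taken from the patch itself.
# Assumes the SFT model and the reward model were each saved as a plain
# state_dict; paths below are placeholders.
import torch

from src.model import RWKV
from src.rlhf.reward import RewardModel

rwkv_model = RWKV()  # constructed as in train_ppo.py; the real model config is still a todo there
reward_model = RewardModel(rwkv_model, num_binned_output = 5)

sft_state = torch.load("./path/to/pretrained/rwkv.pt", map_location="cpu")
rwkv_model.load_state_dict(sft_state)  # weights from the SFT stage

rm_state = torch.load("./path/to/pretrained/reward_model.pt", map_location="cpu")
reward_model.load_state_dict(rm_state)  # weights from the reward-model (RM) stage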