diff --git a/forward_demo.py b/forward_demo.py
index 1608ff52dcc51a25eb9ca411febeb70865ca0bca..dd64212ed4c8a3e786bad1eec85f8b5e73d8bb51 100644
--- a/forward_demo.py
+++ b/forward_demo.py
@@ -227,6 +227,42 @@ if __name__ == "__main__":
     from src.model import RWKV
     model = RWKV(args)
 
+    if len(args.load_model) == 0:
+        rank_zero_info("SFT must load a model, please specify --load_model")
+        exit(1)
+
+    rank_zero_info(f"########## Loading {args.load_model}... ##########")
+    try:
+        load_dict = torch.load(args.load_model, map_location="cpu")
+    except:
+        rank_zero_info(f"Bad checkpoint {args.load_model}")
+        exit(1)
+
+    if args.load_partial == 1:
+        load_keys = load_dict.keys()
+        for k in model.state_dict():
+            if k not in load_keys:
+                load_dict[k] = model.state_dict()[k]
+    model.load_state_dict(load_dict)
+
+    trainer = Trainer.from_argparse_args(
+        args,
+        callbacks=[train_callback(args)],
+    )
+
+    if trainer.global_rank == 0:
+        for n in model.state_dict():
+            shape = model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1]
+            if len(shape) > 1:
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+            else:
+                print(f"{str(shape[0]).ljust(5)} {n}")
+
+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+
     seq = torch.randint(0, 20000, (1, 100))
     model(seq)
 
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index e853f43a588ee4413739857d7876686c245fae52..e420015cbe0460c67e9c4d85cc062d48de94bf6f 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -20,10 +20,10 @@ from torch.nn.utils.rnn import pad_sequence
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
-from palm_rlhf_pytorch.palm import PaLM
-from palm_rlhf_pytorch.reward import RewardModel
-from palm_rlhf_pytorch.optimizer import get_optimizer
-from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
+from src.model import RWKV
+from src.rlhf.reward import RewardModel
+from src.rlhf.optimizer import get_optimizer
+from src.rlhf.utils import masked_mean, eval_decorator
 
 from accelerate import Accelerator
 
@@ -42,8 +42,8 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
 class ActorCritic(nn.Module):
     def __init__(
         self,
-        palm: PaLM,
-        critic_palm: Optional[PaLM] = None,
+        rwkv: RWKV,
+        critic_palm: Optional[RWKV] = None,
         pooled_values = False,
         actor_lora = True,
         critic_lora = True,
@@ -55,12 +55,12 @@ class ActorCritic(nn.Module):
         critic_dropout = 0.
     ):
         super().__init__()
-        self.actor_palm = palm
+        self.actor_palm = rwkv
 
         self.critic_palm = critic_palm
 
         if not exists(self.critic_palm):
-            self.critic_palm = copy.deepcopy(palm)
+            self.critic_palm = copy.deepcopy(rwkv)
 
         self.actor_palm.set_dropout(actor_dropout)
         self.critic_palm.set_dropout(critic_dropout)
@@ -79,7 +79,7 @@ class ActorCritic(nn.Module):
 
         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
-            nn.Linear(palm.dim, 1),
+            nn.Linear(rwkv.dim, 1),
             Rearrange('... 1 -> ...')
         )
 
@@ -289,7 +289,7 @@ class RLHFTrainer(nn.Module):
         prompts_path: Optional[str] = None,
         prompt_token_ids: Optional[torch.Tensor] = None,
         tokenizer: Callable = None,
-        palm: PaLM,
+        rwkv: RWKV,
         reward_model: RewardModel,
         actor_critic: Optional[ActorCritic] = None,
         actor_lr = 1e-4,
@@ -339,12 +339,11 @@ class RLHFTrainer(nn.Module):
         self.register_buffer('prompt_token_ids', prompt_token_ids)
 
         # models
-
-        self.palm = palm
+        self.rwkv = rwkv
 
         if not exists(actor_critic):
             actor_critic = ActorCritic(
-                palm = palm,
+                rwkv = rwkv,
                 actor_lora = actor_lora,
                 critic_lora = critic_lora,
                 actor_lora_r = actor_lora_r,
@@ -352,7 +351,7 @@ class RLHFTrainer(nn.Module):
                 pooled_values = critic_pooled_values,
                 actor_dropout = actor_dropout,
                 critic_dropout = critic_dropout
-            ).to(palm.device)
+            ).to(rwkv.device)
 
         self.actor_critic = actor_critic
 
diff --git a/src/rlhf/reward.py b/src/rlhf/reward.py
index b264e2db164192e6f61f4e307b891b3a60d8d6b0..2f20dc28659e8475eb0f904b6988e6af0f7f474f 100644
--- a/src/rlhf/reward.py
+++ b/src/rlhf/reward.py
@@ -34,7 +34,6 @@ class RewardModel(nn.Module):
 
         # initialize the reward model from the pretrained model
         self.rwkv = copy.deepcopy(rwkv)
-        self.rwkv.set_dropout(dropout) # todo(luxin)
 
         # dimension of the output token embeddings
         dim = rwkv.dim # todo(luxin)
diff --git a/train_ppo.py b/train_ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..025d7d95fc5c1f29dfa89ba237d480b54ac6b964
--- /dev/null
+++ b/train_ppo.py
@@ -0,0 +1,43 @@
+'''
+@File    :   train_ppo.py
+@Time    :   2023/03/08 15:23:19
+@Author  :   Lu Xin
+@Contact :   luxin@csdn.net
+'''
+
+# here put the import lib
+import torch
+from src.model import RWKV
+from src.rlhf.reward import RewardModel
+from src.rlhf.ppo import RLHFTrainer
+
+# load your pretrained RWKV
+# todo(luxin) load the pretrained model produced by the SFT stage
+rwkv_model = RWKV()
+# rwkv_model.load('./path/to/pretrained/rwkv.pt')
+
+# load your pretrained reward model
+# todo(luxin) load the trained reward model
+reward_model = RewardModel(
+    rwkv_model,
+    num_binned_output = 5
+)
+# reward_model.load('./path/to/pretrained/reward_model.pt')
+
+# ready your list of prompts for reinforcement learning
+# todo(luxin) load the prompt dataset (these prompts should differ from those used in the SFT and RM stages)
+prompts = torch.randint(0, 256, (50000, 512)) # 50k prompts
+
+# pass it all to the trainer and train
+# train the PPO model
+trainer = RLHFTrainer(
+    rwkv = rwkv_model,
+    reward_model = reward_model,
+    prompt_token_ids = prompts
+)
+trainer.train(num_episodes = 100)
+
+# then, if it succeeded...
+# generate say 10 samples and use the reward model to return the best one
+answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
+print(answer)
\ No newline at end of file
diff --git a/train_rm.py b/train_rm.py
index 84caa16671dbec6daac704c27fb9c1487e605eea..b4c4c61e312eff565f4f379c32ba4ff848c4cde7 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -1,3 +1,11 @@
+'''
+@File    :   train_rm.py
+@Time    :   2023/03/08 15:23:29
+@Author  :   Lu Xin
+@Contact :   luxin@csdn.net
+'''
+
+# here put the import lib
 import torch
 
 from src.rlhf.reward import RewardModel