From 30ccef278f093767f604a5daa8d8a1e8d76b1dc4 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Fri, 17 Mar 2023 17:44:27 +0800
Subject: [PATCH] update ppo model

---
 src/rlhf/ppo.py | 24 ++++--------------------
 1 file changed, 4 insertions(+), 20 deletions(-)

diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index faf072c..3fcbd7e 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -29,8 +29,6 @@ from src.rlhf.reward import RewardModel
 from src.rlhf.optimizer import get_optimizer
 from src.rlhf.utils import masked_mean, eval_decorator
 
-from accelerate import Accelerator
-
 # actor critic - rwkv with lora
 
 PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
@@ -254,15 +252,12 @@ def clipped_value_loss(values, rewards, old_values, clip):
 class RLHF(nn.Module):
     def __init__(
         self,
-        args,
-        accelerate_kwargs: dict = {}
+        args
     ):
         super().__init__()
 
         self.args = args
 
-        self.accelerate = Accelerator(**accelerate_kwargs)
-
         # 加载 RWKV 模型
         rwkv = RWKV(args)
 
@@ -299,19 +294,12 @@ class RLHF(nn.Module):
         reward_model.load(args.load_rm_model)
         self.reward_model = reward_model.eval()
 
-    def print(self, msg):
-        return self.accelerate.print(msg)
-
     def save(self, filepath = './checkpoint.pt'):
         torch.save(self.actor_critic.state_dict(), filepath)
 
     def load(self, filepath = './checkpoint.pt'):
         state_dict = torch.load(filepath)
         self.actor_critic.load_state_dict(state_dict)
-
-    @property
-    def device(self):
-        return self.accelerate.device
 
     def configure_optimizers(self):
         args = self.args
@@ -383,11 +371,7 @@ class RLHF(nn.Module):
         assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
         prompt = repeat(prompt, 'n -> b n', b = num_samples)
 
-        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
-        reward_model = self.accelerate.unwrap_model(self.reward_model)
-
-        actor_critic.eval()
-
+        self.actor_critic.eval()
         (
             actions,
             sequences,
             mask,
             prompt_mask,
             action_logits,
             _
-        ) = actor_critic.generate(
+        ) = self.actor_critic.generate(
             prompt,
             *args,
             max_seq_len = max_seq_len,
@@ -403,7 +387,7 @@
             **kwargs
         )
 
-        rewards = reward_model(
+        rewards = self.reward_model(
             sequences,
             prompt_mask = prompt_mask,
             mask = mask,
--
GitLab
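
Usage sketch (not part of the patch): with the Accelerator wrapper removed, device placement and model unwrapping are now the caller's responsibility rather than RLHF's. The snippet below assumes only what the diff itself shows, namely that RLHF is constructed from an args config object (which must at least carry load_rm_model) and exposes the load() helper above; the build_rlhf name, checkpoint path and device choice are hypothetical.

    import torch

    from src.rlhf.ppo import RLHF

    def build_rlhf(args, checkpoint='./checkpoint.pt'):
        # Pick a device manually; the module no longer exposes accelerate.device.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        rlhf = RLHF(args)        # constructor no longer takes accelerate_kwargs
        rlhf.load(checkpoint)    # restores the actor_critic state dict
        return rlhf.to(device)   # plain nn.Module.to() replaces Accelerator placement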