From ade381c443b31fd0ec75b5b864b9fa5e31e88529 Mon Sep 17 00:00:00 2001
From: zhangzc
Date: Mon, 20 Mar 2023 14:22:01 +0800
Subject: [PATCH] fix bug

---
 src/model.py    | 2 +-
 src/rlhf/ppo.py | 5 ++---
 train_ppo.py    | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/model.py b/src/model.py
index c8ad6b6..8a61e4b 100644
--- a/src/model.py
+++ b/src/model.py
@@ -517,7 +517,7 @@ class RWKV(pl.LightningModule):
         sample_num_times = max(1, seq_len - prompt.shape[-1])
 
         for _ in wrapper_fn(range(sample_num_times)):
-            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
+            logits, embeds = self.forward(out, ppo_train=True)
             logits, embeds = logits[:, -1], embeds[:, -1]
 
             if exists(filter_logits_fn):
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 4e0c7d5..50f9d76 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -30,7 +30,7 @@ from src.rlhf.reward import RewardModel
 from src.rlhf.optimizer import get_optimizer
 from src.rlhf.utils import masked_mean, eval_decorator
 
-# actor critic - rwkv with lora
+# actor critic
 
 PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
     'actions',
@@ -82,7 +82,6 @@ class ActorCritic(nn.Module):
             max_seq_len,
             prompt = state,
             eos_token = eos_token,
-            finetune_scope = self.actor_lora_scope,
             use_tqdm = True,
             **kwargs
         )
@@ -454,7 +453,7 @@ class RLHF(pl.LightningModule):
 
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
-    
+
     @torch.no_grad()
     def make_experience(self, prompts, eos_token=None, temperature=1):
         ''' generate training data by interacting with the environment '''
diff --git a/train_ppo.py b/train_ppo.py
index fd8968b..98bc3d5 100644
--- a/train_ppo.py
+++ b/train_ppo.py
@@ -283,7 +283,7 @@ if __name__ == "__main__":
         rwkv.load_state_dict(load_dict)
 
     # load the reward_model
-    reward_model = RewardModel(args)
+    reward_model = RewardModel(args, rwkv)
     reward_model.load(args.load_rm_model)
 
     # PPO model
--
GitLab