diff --git a/src/model.py b/src/model.py
index c8ad6b6fa71811b7d9430871a38b7ff3ff409231..8a61e4b0c2c180789f9e10e021e283898c0abc9f 100644
--- a/src/model.py
+++ b/src/model.py
@@ -517,7 +517,7 @@ class RWKV(pl.LightningModule):
         sample_num_times = max(1, seq_len - prompt.shape[-1])
 
         for _ in wrapper_fn(range(sample_num_times)):
-            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
+            logits, embeds = self.forward(out, ppo_train=True)
             logits, embeds = logits[:, -1], embeds[:, -1]
 
             if exists(filter_logits_fn):
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 4e0c7d5896211e918ec3fad4fa292c0aa3a43211..50f9d76555e91b5dc993e40bc8f998478f5c7c14 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -30,7 +30,7 @@ from src.rlhf.reward import RewardModel
 from src.rlhf.optimizer import get_optimizer
 from src.rlhf.utils import masked_mean, eval_decorator
 
-# actor critic - rwkv with lora
+# actor critic
 
 PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
     'actions',
@@ -82,7 +82,6 @@ class ActorCritic(nn.Module):
             max_seq_len,
             prompt = state,
             eos_token = eos_token,
-            finetune_scope = self.actor_lora_scope,
             use_tqdm = True,
             **kwargs
         )
@@ -454,7 +453,7 @@ class RLHF(pl.LightningModule):
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
-
+    @torch.no_grad()
     def make_experience(self, prompts, eos_token=None, temperature=1):
         ''' Generate training data by interacting with the environment '''
diff --git a/train_ppo.py b/train_ppo.py
index fd8968be02b8bb95b7f5c2a0535e0d8eca3f9bd4..98bc3d55a889523823b6bd679f2919949e99de6b 100644
--- a/train_ppo.py
+++ b/train_ppo.py
@@ -283,7 +283,7 @@ if __name__ == "__main__":
     rwkv.load_state_dict(load_dict)
 
     # load the reward model
-    reward_model = RewardModel(args)
+    reward_model = RewardModel(args, rwkv)
     reward_model.load(args.load_rm_model)
 
     # PPO model
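
For reference, the hunk at `src/rlhf/ppo.py` line 454 only adds the `@torch.no_grad()` decorator to `make_experience`, so rollout generation no longer builds an autograd graph. Below is a minimal, self-contained sketch of that pattern; the `policy` module, tensor shapes, and function body are placeholders for illustration, not the repository's actual actor or `make_experience`.

```python
# Minimal sketch of decorating a rollout function with @torch.no_grad().
# "policy" and the shapes below are placeholders, not the repo's actor model.
import torch
import torch.nn as nn

policy = nn.Linear(8, 8)  # stand-in for the actor network

@torch.no_grad()
def make_experience(prompt: torch.Tensor) -> torch.Tensor:
    # Forward passes inside this function keep no computation graph,
    # which reduces memory while sampling rollouts for PPO.
    logits = policy(prompt)
    return logits.argmax(dim=-1)

actions = make_experience(torch.randn(2, 8))
assert not actions.requires_grad  # sampled actions carry no gradient history
```

The same effect could be achieved with a `with torch.no_grad():` block inside the function; the decorator simply keeps the whole experience-collection method gradient-free.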