Commit ade381c4 authored by 每日一练社区

fix bug

Parent a532f71e
@@ -517,7 +517,7 @@ class RWKV(pl.LightningModule):
         sample_num_times = max(1, seq_len - prompt.shape[-1])
         for _ in wrapper_fn(range(sample_num_times)):
-            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
+            logits, embeds = self.forward(out, ppo_train=True)
             logits, embeds = logits[:, -1], embeds[:, -1]
             if exists(filter_logits_fn):
...
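For orientation, here is a minimal sketch of the sampling loop this hunk sits in, assuming (as the new call implies) that forward(tokens, ppo_train=True) returns a (logits, embeddings) pair; the temperature sampling and the loop scaffolding are illustrative and not taken from the repository:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_tokens(model, prompt, seq_len, temperature=1.0, eos_token=None):
    # prompt: (batch, prompt_len) token ids; one new token is appended per step
    out = prompt
    sample_num_times = max(1, seq_len - prompt.shape[-1])

    for _ in range(sample_num_times):
        # ppo_train=True is assumed to make forward() return (logits, embeds)
        logits, embeds = model(out, ppo_train=True)
        logits = logits[:, -1]                        # last-position logits: (batch, vocab)

        probs = F.softmax(logits / temperature, dim=-1)
        sample = torch.multinomial(probs, num_samples=1)
        out = torch.cat((out, sample), dim=-1)

        if eos_token is not None and (sample == eos_token).all():
            break

    return out
```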
@@ -30,7 +30,7 @@ from src.rlhf.reward import RewardModel
 from src.rlhf.optimizer import get_optimizer
 from src.rlhf.utils import masked_mean, eval_decorator
-# actor critic - rwkv with lora
+# actor critic
 PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
     'actions',
@@ -82,7 +82,6 @@ class ActorCritic(nn.Module):
             max_seq_len,
             prompt = state,
             eos_token = eos_token,
-            finetune_scope = self.actor_lora_scope,
             use_tqdm = True,
             **kwargs
         )
@@ -454,7 +453,7 @@ class RLHF(pl.LightningModule):
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+    @torch.no_grad()
     def make_experience(self, prompts, eos_token=None, temperature=1):
         ''' Generate training data by interacting with the environment
         '''
...
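As a rough sketch of the pattern this hunk points at: make_experience runs under torch.no_grad(), lets the actor-critic generate a completion for each prompt, and scores the result with the reward model. The generate signature follows the earlier hunk; the reward_model call and the sequence field on the returned value are assumptions for illustration only:

```python
import torch

@torch.no_grad()
def make_experience(actor_critic, reward_model, prompts, max_seq_len, eos_token=None):
    # Interact with the "environment": generate one completion per prompt and score it.
    # No gradients are needed here; PPO recomputes what it needs during the learn step.
    actor_critic.eval()

    experiences = []
    for prompt in prompts:
        result = actor_critic.generate(
            max_seq_len,
            prompt=prompt,
            eos_token=eos_token,
            use_tqdm=True,
        )
        reward = reward_model(result.sequence)   # hypothetical field holding the full token sequence
        experiences.append((result, reward))
    return experiences
```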
@@ -283,7 +283,7 @@ if __name__ == "__main__":
     rwkv.load_state_dict(load_dict)
     # load reward_model
-    reward_model = RewardModel(args)
+    reward_model = RewardModel(args, rwkv)
     reward_model.load(args.load_rm_model)
     # PPO model
...
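The last hunk passes the already-loaded rwkv model into RewardModel(args, rwkv), which suggests the reward model wraps a shared backbone instead of constructing its own. Purely as an illustration of that design (class and attribute names below are hypothetical, not from the repo), a scalar reward head over a shared trunk can look like this:

```python
import torch
import torch.nn as nn

class SimpleRewardModel(nn.Module):
    # Illustration only: wrap a shared language-model trunk and add a scalar reward head.
    def __init__(self, backbone, hidden_size):
        super().__init__()
        self.backbone = backbone                     # e.g. a pretrained RWKV trunk
        self.reward_head = nn.Linear(hidden_size, 1)

    def forward(self, tokens):
        # Assumes the backbone returns per-token hidden states of shape (batch, seq_len, hidden_size).
        hidden = self.backbone(tokens)
        reward = self.reward_head(hidden[:, -1])     # score the final position
        return reward.squeeze(-1)                    # (batch,)
```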