Commit ade381c4 authored by 每日一练社区

fix bug

Parent a532f71e
@@ -517,7 +517,7 @@ class RWKV(pl.LightningModule):
        sample_num_times = max(1, seq_len - prompt.shape[-1])

        for _ in wrapper_fn(range(sample_num_times)):
-            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
+            logits, embeds = self.forward(out, ppo_train=True)
            logits, embeds = logits[:, -1], embeds[:, -1]

            if exists(filter_logits_fn):
......
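For context, the hunk above rewrites the autoregressive sampling loop inside RWKV.generate so that forward is asked for logits and embeddings via ppo_train=True. Below is a minimal sketch of that decode-loop pattern, assuming a model whose forward returns (logits, embeds); the temperature sampling shown here is illustrative and not taken from this diff.

```python
import torch
import torch.nn.functional as F

def sample_loop(model, prompt, seq_len, temperature=1.0, eos_token=None):
    """Hypothetical sketch of the decode loop edited by the hunk above."""
    out = prompt
    sample_num_times = max(1, seq_len - prompt.shape[-1])
    for _ in range(sample_num_times):
        # Post-commit call shape: forward(out, ppo_train=True) yields both
        # per-position logits and hidden embeddings.
        logits, embeds = model.forward(out, ppo_train=True)
        logits, embeds = logits[:, -1], embeds[:, -1]   # newest position only
        # embeds would feed the critic elsewhere; unused in this simplified sketch.
        probs = F.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1)        # sample next token id
        out = torch.cat((out, next_token), dim=-1)
        if eos_token is not None and (next_token == eos_token).all():
            break
    return out
```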
@@ -30,7 +30,7 @@ from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator

-# actor critic - rwkv with lora
+# actor critic

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
@@ -82,7 +82,6 @@ class ActorCritic(nn.Module):
            max_seq_len,
            prompt = state,
            eos_token = eos_token,
-            finetune_scope = self.actor_lora_scope,
            use_tqdm = True,
            **kwargs
        )
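The ActorCritic hunk drops the LoRA-specific finetune_scope argument, so only the prompt, eos token and sampling options are forwarded to the actor. A hedged sketch of the resulting call shape, with actor and state standing in for the attributes visible in the diff context:

```python
def actor_generate(actor, state, max_seq_len, eos_token=None, **kwargs):
    # Post-commit call shape: no finetune_scope keyword any more, since the
    # LoRA scoping path was removed from ActorCritic.
    return actor.generate(
        max_seq_len,
        prompt=state,          # prompt tokens acting as the environment state
        eos_token=eos_token,
        use_tqdm=True,
        **kwargs,
    )
```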
@@ -454,7 +453,7 @@ class RLHF(pl.LightningModule):
        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

    @torch.no_grad()
    def make_experience(self, prompts, eos_token=None, temperature=1):
        ''' Generate training data by interacting with the environment
        '''
......
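make_experience is the PPO rollout step: under torch.no_grad() the current policy answers each prompt and the results are stored for the later PPO update. The sketch below is only an assumption about that flow; the Experience fields and the return signature of actor_critic.generate are illustrative, not taken from this repository.

```python
import torch
from collections import namedtuple

# Illustrative container; the real code may store more (masks, old log-probs, ...).
Experience = namedtuple('Experience', ['sequence', 'action_logits', 'value', 'reward'])

@torch.no_grad()
def make_experience(actor_critic, reward_model, prompts, eos_token=None, temperature=1):
    """Roll out the current policy on each prompt and score it with the reward model."""
    experiences = []
    for prompt in prompts:
        sequence, action_logits, value = actor_critic.generate(
            prompt, eos_token=eos_token, temperature=temperature
        )
        reward = reward_model(sequence)    # scalar score for the generated sequence
        experiences.append(Experience(sequence, action_logits, value, reward))
    return experiences
```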
@@ -283,7 +283,7 @@ if __name__ == "__main__":
    rwkv.load_state_dict(load_dict)

    # load the reward_model
-    reward_model = RewardModel(args)
+    reward_model = RewardModel(args, rwkv)
    reward_model.load(args.load_rm_model)

    # PPO model
......
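The last hunk passes the already-loaded rwkv backbone into RewardModel instead of letting the reward model construct its own copy. A small sketch of the post-commit setup, assuming the names visible in the diff (RewardModel, args.load_rm_model):

```python
from src.rlhf.reward import RewardModel

def build_reward_model(args, rwkv):
    # Post-commit construction: the reward model reuses the rwkv backbone
    # that was already loaded via rwkv.load_state_dict(load_dict), then its
    # trained reward-model weights are restored from args.load_rm_model.
    reward_model = RewardModel(args, rwkv)
    reward_model.load(args.load_rm_model)
    return reward_model
```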