Commit 30ccef27 authored by u010280923

update ppo model

Parent ba2760dc
@@ -29,8 +29,6 @@ from src.rlhf.reward import RewardModel
 from src.rlhf.optimizer import get_optimizer
 from src.rlhf.utils import masked_mean, eval_decorator
-from accelerate import Accelerator
 # actor critic - rwkv with lora
 PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
@@ -254,15 +252,12 @@ def clipped_value_loss(values, rewards, old_values, clip):
 class RLHF(nn.Module):
     def __init__(
         self,
-        args,
-        accelerate_kwargs: dict = {}
+        args
     ):
         super().__init__()
         self.args = args
-        self.accelerate = Accelerator(**accelerate_kwargs)
         # 加载 RWKV 模型
         rwkv = RWKV(args)
@@ -299,19 +294,12 @@ class RLHF(nn.Module):
         reward_model.load(args.load_rm_model)
         self.reward_model = reward_model.eval()
-    def print(self, msg):
-        return self.accelerate.print(msg)
     def save(self, filepath = './checkpoint.pt'):
         torch.save(self.actor_critic.state_dict(), filepath)
     def load(self, filepath = './checkpoint.pt'):
         state_dict = torch.load(filepath)
         self.actor_critic.load_state_dict(state_dict)
-    @property
-    def device(self):
-        return self.accelerate.device
     def configure_optimizers(self):
         args = self.args
@@ -383,11 +371,7 @@ class RLHF(nn.Module):
         assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
         prompt = repeat(prompt, 'n -> b n', b = num_samples)
-        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
-        reward_model = self.accelerate.unwrap_model(self.reward_model)
-        actor_critic.eval()
+        self.actor_critic.eval()
         (
             actions,
             sequences,
@@ -395,7 +379,7 @@
             prompt_mask,
             action_logits,
             _
-        ) = actor_critic.generate(
+        ) = self.actor_critic.generate(
             prompt,
             *args,
             max_seq_len = max_seq_len,
@@ -403,7 +387,7 @@
             **kwargs
         )
-        rewards = reward_model(
+        rewards = self.reward_model(
             sequences,
             prompt_mask = prompt_mask,
             mask = mask,
......
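Taken together, the diff removes the embedded HuggingFace Accelerate integration from the RLHF module: the constructor no longer accepts `accelerate_kwargs`, the `print` and `device` helpers backed by `self.accelerate` are gone, and `generate` now calls `self.actor_critic` and `self.reward_model` directly instead of unwrapping them. The sketch below is not from this repository; it only illustrates, under the assumption that device placement and distributed wrapping move into the surrounding training script, how a plain `nn.Module` can still be prepared with Accelerate externally. `TinyPolicy` is a hypothetical stand-in for the repo's actor-critic.

```python
# Minimal sketch (hypothetical, not from this repo): with Accelerator removed
# from the model class, any device/distributed handling is done once, outside
# the module, by whatever trainer drives it.
import torch
import torch.nn as nn

class TinyPolicy(nn.Module):
    """Hypothetical stand-in for the repo's actor-critic module."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

model = TinyPolicy()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# If multi-GPU or mixed precision is still wanted, the wrapping now lives in
# the training script rather than inside the model, e.g.:
#   from accelerate import Accelerator
#   accelerator = Accelerator()
#   model, optimizer = accelerator.prepare(model, optimizer)

# Plain single-device step; the model itself stays framework-agnostic.
x = torch.randn(2, 8)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
```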