Commit 30ccef27 authored by u010280923

update ppo model

Parent ba2760dc
@@ -29,8 +29,6 @@ from src.rlhf.reward import RewardModel
 from src.rlhf.optimizer import get_optimizer
 from src.rlhf.utils import masked_mean, eval_decorator
-from accelerate import Accelerator
 # actor critic - rwkv with lora
 PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
@@ -254,15 +252,12 @@ def clipped_value_loss(values, rewards, old_values, clip):
 class RLHF(nn.Module):
     def __init__(
         self,
-        args,
-        accelerate_kwargs: dict = {}
+        args
     ):
         super().__init__()
         self.args = args
-        self.accelerate = Accelerator(**accelerate_kwargs)
         # 加载 RWKV 模型
         rwkv = RWKV(args)
@@ -299,19 +294,12 @@ class RLHF(nn.Module):
         reward_model.load(args.load_rm_model)
         self.reward_model = reward_model.eval()
-    def print(self, msg):
-        return self.accelerate.print(msg)
     def save(self, filepath = './checkpoint.pt'):
         torch.save(self.actor_critic.state_dict(), filepath)
     def load(self, filepath = './checkpoint.pt'):
         state_dict = torch.load(filepath)
         self.actor_critic.load_state_dict(state_dict)
-    @property
-    def device(self):
-        return self.accelerate.device
     def configure_optimizers(self):
         args = self.args
@@ -383,11 +371,7 @@ class RLHF(nn.Module):
         assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
         prompt = repeat(prompt, 'n -> b n', b = num_samples)
-        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
-        reward_model = self.accelerate.unwrap_model(self.reward_model)
-        actor_critic.eval()
+        self.actor_critic.eval()
         (
             actions,
             sequences,
@@ -395,7 +379,7 @@
             prompt_mask,
             action_logits,
             _
-        ) = actor_critic.generate(
+        ) = self.actor_critic.generate(
             prompt,
             *args,
             max_seq_len = max_seq_len,
@@ -403,7 +387,7 @@
             **kwargs
         )
-        rewards = reward_model(
+        rewards = self.reward_model(
             sequences,
             prompt_mask = prompt_mask,
             mask = mask,
...
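The net effect of this commit is that RLHF no longer keeps an Accelerator and no longer unwraps its sub-modules before generation; actor_critic and reward_model are used as plain nn.Modules. Below is a minimal, self-contained sketch of that pattern. The stub classes and the simplified generate signature are illustrative assumptions, not the repo's actual RWKV/RewardModel API.

import torch
import torch.nn as nn

class ActorCriticStub(nn.Module):
    # Illustrative stand-in; the repo's ActorCritic.generate returns
    # actions, sequences, masks, logits, etc.
    def generate(self, prompt, max_seq_len=128):
        # Pretend to extend the prompt with generated tokens.
        return torch.cat([prompt, torch.zeros_like(prompt)], dim=-1)

class RewardModelStub(nn.Module):
    # Illustrative stand-in; returns one scalar reward per sequence.
    def forward(self, sequences):
        return torch.randn(sequences.shape[0])

class RLHFSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # Sub-modules are held directly, with no Accelerator wrapper.
        self.actor_critic = ActorCriticStub()
        self.reward_model = RewardModelStub().eval()

    @torch.no_grad()
    def generate(self, prompt, max_seq_len=128):
        # Previously: actor_critic = self.accelerate.unwrap_model(self.actor_critic)
        # Now the sub-modules are called directly.
        self.actor_critic.eval()
        sequences = self.actor_critic.generate(prompt, max_seq_len=max_seq_len)
        rewards = self.reward_model(sequences)
        return sequences, rewards

prompt = torch.randint(0, 100, (1, 8))
sequences, rewards = RLHFSketch().generate(prompt)
print(sequences.shape, rewards.shape)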