Commit da480d58 authored by u010280923

add ppo model

Parent 5f0fb2fd
@@ -227,6 +227,42 @@ if __name__ == "__main__":
    from src.model import RWKV
    model = RWKV(args)

    if len(args.load_model) == 0:
        rank_zero_info("SFT must load a model, please set args.load_model")
        exit(1)

    rank_zero_info(f"########## Loading {args.load_model}... ##########")
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
    except:
        rank_zero_info(f"Bad checkpoint {args.load_model}")
        exit(1)

    if args.load_partial == 1:
        load_keys = load_dict.keys()
        for k in model.state_dict():
            if k not in load_keys:
                load_dict[k] = model.state_dict()[k]
    model.load_state_dict(load_dict)

    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[train_callback(args)],
    )

    if trainer.global_rank == 0:
        for n in model.state_dict():
            shape = model.state_dict()[n].shape
            shape = [i for i in shape if i != 1]
            if len(shape) > 1:
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
            else:
                print(f"{str(shape[0]).ljust(5)} {n}")

    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

    seq = torch.randint(0, 20000, (1, 100))
    model(seq)
......
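For reference, the load_partial branch in this hunk back-fills any parameters missing from the checkpoint with the model's freshly initialized weights before calling load_state_dict. A minimal standalone sketch of the same pattern (the helper name load_partial_state is hypothetical and not part of this commit):

```python
import torch

def load_partial_state(model: torch.nn.Module, ckpt_path: str) -> None:
    """Load a checkpoint, filling keys it lacks with the model's current weights."""
    load_dict = torch.load(ckpt_path, map_location="cpu")
    for key, value in model.state_dict().items():
        if key not in load_dict:
            # parameter not saved in the checkpoint: keep the freshly initialized value
            load_dict[key] = value
    model.load_state_dict(load_dict)
```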
@@ -20,10 +20,10 @@ from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

-from palm_rlhf_pytorch.palm import PaLM
-from palm_rlhf_pytorch.reward import RewardModel
-from palm_rlhf_pytorch.optimizer import get_optimizer
-from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
+from src.model import RWKV
+from src.rlhf.reward import RewardModel
+from src.rlhf.optimizer import get_optimizer
+from src.rlhf.utils import masked_mean, eval_decorator

from accelerate import Accelerator
@@ -42,8 +42,8 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
class ActorCritic(nn.Module):
    def __init__(
        self,
-       palm: PaLM,
-       critic_palm: Optional[PaLM] = None,
+       rwkv: RWKV,
+       critic_palm: Optional[RWKV] = None,
        pooled_values = False,
        actor_lora = True,
        critic_lora = True,
@@ -55,12 +55,12 @@ class ActorCritic(nn.Module):
        critic_dropout = 0.
    ):
        super().__init__()
-       self.actor_palm = palm
+       self.actor_palm = rwkv

        self.critic_palm = critic_palm

        if not exists(self.critic_palm):
-           self.critic_palm = copy.deepcopy(palm)
+           self.critic_palm = copy.deepcopy(rwkv)

        self.actor_palm.set_dropout(actor_dropout)
        self.critic_palm.set_dropout(critic_dropout)
@@ -79,7 +79,7 @@ class ActorCritic(nn.Module):
        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(
-           nn.Linear(palm.dim, 1),
+           nn.Linear(rwkv.dim, 1),
            Rearrange('... 1 -> ...')
        )
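The value head added above turns each token's hidden state into a scalar value estimate for the critic. A minimal illustration, assuming trunk outputs of shape (batch, seq_len, dim) and a hypothetical dim of 768 standing in for rwkv.dim:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 768  # hypothetical embedding size; the commit uses rwkv.dim
value_head = nn.Sequential(
    nn.Linear(dim, 1),          # one scalar per token position
    Rearrange('... 1 -> ...'),  # drop the trailing singleton dimension
)

hidden = torch.randn(2, 16, dim)  # (batch, seq_len, dim)
values = value_head(hidden)       # (batch, seq_len)
```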
@@ -289,7 +289,7 @@ class RLHFTrainer(nn.Module):
        prompts_path: Optional[str] = None,
        prompt_token_ids: Optional[torch.Tensor] = None,
        tokenizer: Callable = None,
-       palm: PaLM,
+       rwkv: RWKV,
        reward_model: RewardModel,
        actor_critic: Optional[ActorCritic] = None,
        actor_lr = 1e-4,
@@ -339,12 +339,11 @@ class RLHFTrainer(nn.Module):
        self.register_buffer('prompt_token_ids', prompt_token_ids)

        # models
-       self.palm = palm
+       self.rwkv = rwkv

        if not exists(actor_critic):
            actor_critic = ActorCritic(
-               palm = palm,
+               rwkv = rwkv,
                actor_lora = actor_lora,
                critic_lora = critic_lora,
                actor_lora_r = actor_lora_r,
@@ -352,7 +351,7 @@
                pooled_values = critic_pooled_values,
                actor_dropout = actor_dropout,
                critic_dropout = critic_dropout
-           ).to(palm.device)
+           ).to(rwkv.device)

        self.actor_critic = actor_critic
......
@@ -34,7 +34,6 @@ class RewardModel(nn.Module):
        # initialize the reward model from the pretrained model
        self.rwkv = copy.deepcopy(rwkv)
        self.rwkv.set_dropout(dropout)  # todo(luxin)

        # dimensionality of the output token embeddings
        dim = rwkv.dim  # todo(luxin)
......
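For orientation, train_rlhf.py below builds the reward model with num_binned_output = 5. A minimal sketch of a binned reward head over a pooled hidden state; the class and names here are hypothetical and do not reproduce the actual src/rlhf/reward.py implementation:

```python
import torch
from torch import nn

class BinnedRewardHead(nn.Module):
    """Map a pooled sequence representation to logits over discrete reward bins."""
    def __init__(self, dim: int, num_binned_output: int = 5):
        super().__init__()
        self.to_bins = nn.Linear(dim, num_binned_output)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        # pooled: (batch, dim) -> reward-bin logits: (batch, num_binned_output)
        return self.to_bins(pooled)

head = BinnedRewardHead(dim=768, num_binned_output=5)
logits = head(torch.randn(4, 768))  # (4, 5)
```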
'''
@File    : train_rlhf.py
@Time    : 2023/03/08 15:23:19
@Author  : Lu Xin
@Contact : luxin@csdn.net
'''

# here put the import lib
import torch
from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.ppo import RLHFTrainer

# load your pretrained RWKV
# todo(luxin): load the pretrained model produced by the SFT stage
rwkv_model = RWKV()
# rwkv_model.load('./path/to/pretrained/rwkv.pt')

# load your pretrained reward model
# todo(luxin): load the trained reward model
reward_model = RewardModel(
    rwkv_model,
    num_binned_output = 5
)
# reward_model.load('./path/to/pretrained/reward_model.pt')

# ready your list of prompts for reinforcement learning
# todo(luxin): read in the prompts dataset (these prompts should differ from the ones used in the SFT and RM stages)
prompts = torch.randint(0, 256, (50000, 512))  # 50k prompts

# pass it all to the trainer and train
# train the PPO model
trainer = RLHFTrainer(
    rwkv = rwkv_model,
    reward_model = reward_model,
    prompt_token_ids = prompts
)
trainer.train(num_episodes = 100)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one
answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10)  # (<= 2048,)
print(answer)
\ No newline at end of file
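The final trainer.generate(..., num_samples = 10) call samples several candidate answers and keeps the one the reward model scores highest. A minimal best-of-n sketch of that idea, where reward_fn is an assumed stand-in for the trained reward model:

```python
import torch

def best_of_n(candidates, reward_fn):
    """Return the candidate token sequence with the highest scalar reward."""
    scores = torch.tensor([float(reward_fn(seq.unsqueeze(0))) for seq in candidates])
    return candidates[int(scores.argmax())]

# usage sketch with a dummy reward function
dummy_reward = lambda seq: seq.float().mean()
samples = [torch.randint(0, 256, (32,)) for _ in range(3)]
best = best_of_n(samples, dummy_reward)
```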
'''
@File : train_rm.py
@Time : 2023/03/08 15:23:29
@Author : Lu Xin
@Contact : luxin@csdn.net
'''
# here put the import lib
import torch
from src.rlhf.reward import RewardModel
......