Commit 02e8a2d9 authored by u010280923

opt ppo model

Parent 749886a8
......@@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from pathlib import Path
from tqdm import tqdm
from einops import pack
......@@ -381,6 +382,11 @@ class RWKV(pl.LightningModule):
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
def load(self, path):
path = Path(path)
assert path.exists()
self.load_state_dict(torch.load(str(path), map_location="cpu"))
def configure_optimizers(self):
args = self.args
if args.layerwise_lr > 0:
......
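For reference, a minimal sketch of how the new `RWKV.load` helper is meant to be used (the method name and signature follow the diff; the `load_checkpoint` wrapper and the checkpoint path below are hypothetical). Note that `map_location` is a parameter of `torch.load`, not of `nn.Module.load_state_dict`:

```python
from pathlib import Path

import torch

def load_checkpoint(module, path):
    # mirrors RWKV.load: resolve the path, verify it exists, then load weights onto CPU
    path = Path(path)
    assert path.exists(), f"checkpoint not found: {path}"
    module.load_state_dict(torch.load(str(path), map_location="cpu"))

# actor = RWKV(args)
# load_checkpoint(actor, "out/rwkv-sft.pth")  # hypothetical SFT checkpoint path
```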
......@@ -45,19 +45,16 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
class ActorCritic(nn.Module):
def __init__(
self,
rwkv: RWKV,
args,
critic: Optional[RWKV] = None,
actor: RWKV,
critic: RWKV,
pooled_values = False
):
super().__init__()
self.actor = copy.deepcopy(rwkv)
self.actor = actor
self.critic = critic
if not exists(self.critic):
self.critic = copy.deepcopy(rwkv)
self.pooled_values = pooled_values
self.value_head = nn.Sequential(
nn.Linear(args.n_embd, 1),
......@@ -242,20 +239,21 @@ class RLHF(pl.LightningModule):
def __init__(
self,
args,
rwkv: RWKV,
actor: RWKV,
critic: RWKV,
reward_model: RewardModel
):
super().__init__()
self.args = args
self.rwkv = rwkv
# initialize actor_critic with the RWKV models
actor_critic = ActorCritic(
rwkv = self.rwkv,
args = self.args,
args=self.args,
actor=actor,
critic=critic,
pooled_values = args.critic_pooled_values
).to(self.rwkv.device)
).to(actor.device)
self.actor_critic = actor_critic
......
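To make the new wiring easier to follow, here is a condensed sketch of how ActorCritic and RLHF fit together after this change (class and argument names come from the diff; the rest is assumed, and the value head is truncated). The caller now constructs the actor and critic RWKV instances itself, so ActorCritic no longer deep-copies a single model:

```python
import torch.nn as nn

class ActorCritic(nn.Module):
    def __init__(self, args, actor, critic, pooled_values=False):
        super().__init__()
        self.actor = actor      # policy network optimized by PPO
        self.critic = critic    # value network with its own weights (no copy.deepcopy)
        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(nn.Linear(args.n_embd, 1))  # truncated in the diff

# Inside RLHF.__init__ the pair is then assembled roughly as:
# self.actor_critic = ActorCritic(
#     args=self.args,
#     actor=actor,
#     critic=critic,
#     pooled_values=args.critic_pooled_values,
# ).to(actor.device)
```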
......@@ -261,33 +261,22 @@ if __name__ == "__main__":
# read in the training dataset
prompts = load_prompt_data_4_ppo(args)
# load the RWKV model
rwkv = RWKV(args)
if len(args.load_sft_model) == 0:
rank_zero_info(f"SFT must load model, please input ")
exit(1)
# initialize the actor model with RWKV
actor = RWKV(args)
actor.load(args.load_sft_model)
rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
try:
load_dict = torch.load(args.load_sft_model, map_location="cpu")
except:
rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
exit(1)
if args.load_partial == 1:
load_keys = load_dict.keys()
for k in rwkv.state_dict():
if k not in load_keys:
load_dict[k] = rwkv.state_dict()[k]
rwkv.load_state_dict(load_dict)
# initialize the critic model with RWKV
critic = RWKV(args)
critic.load(args.load_sft_model)
# load the reward_model
rwkv = RWKV(args)
rwkv.load(args.load_sft_model)
reward_model = RewardModel(args, rwkv)
reward_model.load(args.load_rm_model)
# PPO model
rlhf_model = RLHF(args, rwkv, reward_model)
rlhf_model = RLHF(args, actor, critic, reward_model)
# model training
# trainer
......
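Putting the new main-script pieces together, the PPO setup now looks roughly like this (the `args` attributes `load_sft_model` and `load_rm_model` and the variable names come from the diff; the flow is a sketch, not the verbatim script):

```python
# actor and critic both start from the same SFT checkpoint but keep separate weights
actor = RWKV(args)
actor.load(args.load_sft_model)

critic = RWKV(args)
critic.load(args.load_sft_model)

# the reward model wraps its own RWKV backbone and then loads the RM-stage checkpoint
rwkv = RWKV(args)
rwkv.load(args.load_sft_model)
reward_model = RewardModel(args, rwkv)
reward_model.load(args.load_rm_model)

rlhf_model = RLHF(args, actor, critic, reward_model)
```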