Commit 02e8a2d9 authored by u010280923

opt ppo model

Parent 749886a8
@@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from pathlib import Path
 from tqdm import tqdm
 from einops import pack
@@ -381,6 +382,11 @@ class RWKV(pl.LightningModule):
         self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
         self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
+    def load(self, path):
+        path = Path(path)
+        assert path.exists()
+        self.load_state_dict(torch.load(str(path), map_location="cpu"))
+
     def configure_optimizers(self):
         args = self.args
         if args.layerwise_lr > 0:
...
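For reference, the new RWKV.load helper is the usual torch.load followed by load_state_dict. Below is a minimal, runnable sketch of that pattern with a hypothetical stand-in module (TinyModel and the checkpoint filename are illustrative, not part of this repo); note that map_location="cpu" is an argument of torch.load, so the checkpoint is deserialized on CPU before being copied into the module.

import torch
import torch.nn as nn
from pathlib import Path

class TinyModel(nn.Module):
    """Stand-in for the real RWKV module, just to exercise the load pattern."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)

    def load(self, path):
        path = Path(path)
        assert path.exists(), f"checkpoint not found: {path}"
        # deserialize on CPU, then copy the tensors into this module
        self.load_state_dict(torch.load(str(path), map_location="cpu"))

if __name__ == "__main__":
    src, dst = TinyModel(), TinyModel()
    torch.save(src.state_dict(), "tiny.pth")
    dst.load("tiny.pth")
    assert torch.equal(src.proj.weight, dst.proj.weight)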
@@ -45,19 +45,16 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
 class ActorCritic(nn.Module):
     def __init__(
         self,
-        rwkv: RWKV,
         args,
-        critic: Optional[RWKV] = None,
+        actor: RWKV,
+        critic: RWKV,
         pooled_values = False
     ):
         super().__init__()
 
-        self.actor = copy.deepcopy(rwkv)
+        self.actor = actor
         self.critic = critic
 
-        if not exists(self.critic):
-            self.critic = copy.deepcopy(rwkv)
-
         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
             nn.Linear(args.n_embd, 1),
@@ -242,20 +239,21 @@ class RLHF(pl.LightningModule):
     def __init__(
         self,
         args,
-        rwkv: RWKV,
+        actor: RWKV,
+        critic: RWKV,
         reward_model: RewardModel
     ):
         super().__init__()
 
         self.args = args
-        self.rwkv = rwkv
 
         # initialize the actor_critic from RWKV
         actor_critic = ActorCritic(
-            rwkv = self.rwkv,
-            args = self.args,
+            args=self.args,
+            actor=actor,
+            critic=critic,
             pooled_values = args.critic_pooled_values
-        ).to(self.rwkv.device)
+        ).to(actor.device)
 
         self.actor_critic = actor_critic
...
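Taken together, the two hunks above move model construction out of the wrapper classes: ActorCritic used to deep-copy a single rwkv into the actor (and, when no critic was passed, into the critic as well), and RLHF kept its own self.rwkv; now the caller builds and loads actor and critic and hands them in. A rough, self-contained sketch of the before/after wiring, with a plain nn.Linear standing in for RWKV (the wrapper names below are illustrative, not the repo's):

import copy
import torch.nn as nn

class OldStyleWrapper(nn.Module):
    # old wiring: one backbone is duplicated inside the wrapper
    def __init__(self, backbone, critic=None):
        super().__init__()
        self.actor = copy.deepcopy(backbone)
        self.critic = critic if critic is not None else copy.deepcopy(backbone)

class NewStyleWrapper(nn.Module):
    # new wiring: the caller constructs (and loads) actor and critic itself
    def __init__(self, actor, critic):
        super().__init__()
        self.actor = actor
        self.critic = critic

backbone = nn.Linear(8, 8)
old = OldStyleWrapper(backbone)                           # copies made inside
new = NewStyleWrapper(nn.Linear(8, 8), nn.Linear(8, 8))   # built outside

With the new signature the actor and critic can, if needed, be initialized from different checkpoints instead of always starting as identical copies of one model.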
@@ -261,33 +261,22 @@ if __name__ == "__main__":
     # load the training dataset
     prompts = load_prompt_data_4_ppo(args)
 
-    # load the RWKV model
-    rwkv = RWKV(args)
-
-    if len(args.load_sft_model) == 0:
-        rank_zero_info(f"SFT must load model, please input ")
-        exit(1)
-
-    rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
-    try:
-        load_dict = torch.load(args.load_sft_model, map_location="cpu")
-    except:
-        rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
-        exit(1)
-
-    if args.load_partial == 1:
-        load_keys = load_dict.keys()
-        for k in rwkv.state_dict():
-            if k not in load_keys:
-                load_dict[k] = rwkv.state_dict()[k]
-    rwkv.load_state_dict(load_dict)
+    # initialize the actor model from rwkv
+    actor = RWKV(args)
+    actor.load(args.load_sft_model)
+
+    # initialize the critic model from rwkv
+    critic = RWKV(args)
+    critic.load(args.load_sft_model)
 
     # load the reward_model
+    rwkv = RWKV(args)
+    rwkv.load(args.load_sft_model)
     reward_model = RewardModel(args, rwkv)
     reward_model.load(args.load_rm_model)
 
     # PPO model
-    rlhf_model = RLHF(args, rwkv, reward_model)
+    rlhf_model = RLHF(args, actor, critic, reward_model)
 
     # model training
     # trainer
...
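The rewritten entry point now builds three RWKV instances from the same SFT checkpoint: the PPO actor, the PPO critic, and the backbone of the reward model, whose reward weights are then loaded from args.load_rm_model. A rough sketch of that wiring with hypothetical stand-in classes (the real RWKV, RewardModel and RLHF in this repo take more arguments):

import torch
import torch.nn as nn

class Backbone(nn.Module):
    """Stand-in for RWKV."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def load(self, path):
        self.load_state_dict(torch.load(path, map_location="cpu"))

class Reward(nn.Module):
    """Stand-in for RewardModel: wraps a backbone; its own weights would be loaded on top."""
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

class PPOWrapper(nn.Module):
    """Stand-in for RLHF: receives actor and critic explicitly."""
    def __init__(self, actor, critic, reward_model):
        super().__init__()
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model

sft_ckpt = "sft.pth"            # hypothetical path; args.load_sft_model in the script
torch.save(Backbone().state_dict(), sft_ckpt)

actor = Backbone()
actor.load(sft_ckpt)            # policy optimized by PPO
critic = Backbone()
critic.load(sft_ckpt)           # value model, same SFT starting point
rm_base = Backbone()
rm_base.load(sft_ckpt)
reward_model = Reward(rm_base)  # reward-model weights would be loaded on top

ppo = PPOWrapper(actor, critic, reward_model)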