Commit a532f71e authored by u010280923

opt ppo model

Parent 26b8d465
@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
### Reward Model
```
-python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
+python train_rm.py --load_sft_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
......
@@ -20,6 +20,7 @@ from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.strategies import DeepSpeedStrategy
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
@@ -253,33 +254,13 @@ def clipped_value_loss(values, rewards, old_values, clip):
class RLHF(pl.LightningModule):
    def __init__(
        self,
-        args
+        args,
+        rwkv: RWKV,
+        reward_model: RewardModel
    ):
        super().__init__()
        self.args = args
-        # Load the RWKV model
-        rwkv = RWKV(args)
-        if len(args.load_sft_model) == 0:
-            rank_zero_info("Must load an SFT model, please set --load_sft_model")
-            exit(1)
-        rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
-        try:
-            load_dict = torch.load(args.load_sft_model, map_location="cpu")
-        except:
-            rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
-            exit(1)
-        if args.load_partial == 1:
-            load_keys = load_dict.keys()
-            for k in rwkv.state_dict():
-                if k not in load_keys:
-                    load_dict[k] = rwkv.state_dict()[k]
-        rwkv.load_state_dict(load_dict)
        self.rwkv = rwkv

        # Initialize the actor_critic from RWKV
@@ -291,9 +272,7 @@ class RLHF(pl.LightningModule):
        self.actor_critic = actor_critic

        # Load the reward_model and set it to evaluation mode
-        reward_model = RewardModel(args)
-        reward_model.load(args.load_rm_model)
        # Set the reward_model to evaluation mode
        self.reward_model = reward_model.eval()

    def save(self, filepath = './checkpoint.pt'):
......
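Taken together, the two ppo.py hunks above leave a constructor that only stores modules built by the caller. The sketch below is an editorial reconstruction of the resulting `__init__`, not code from the commit: the imports are assumed to mirror the ones this commit adds to the train script, and the actor_critic construction elided at the "......" marker is replaced by a placeholder.

```
# Editorial sketch of the post-commit constructor; "..." marks code elided in the diff above.
import pytorch_lightning as pl
from src.model import RWKV               # assumed: same imports the commit adds to the train script
from src.rlhf.reward import RewardModel

class RLHF(pl.LightningModule):
    def __init__(self, args, rwkv: RWKV, reward_model: RewardModel):
        super().__init__()
        self.args = args

        # the SFT-initialized RWKV is now loaded by the caller and injected here
        self.rwkv = rwkv

        # actor_critic construction is elided in the diff ("......"); placeholder only
        actor_critic = ...
        self.actor_critic = actor_critic

        # the reward model arrives pre-loaded; here it is only frozen in evaluation mode
        self.reward_model = reward_model.eval()
```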
@@ -252,6 +252,8 @@ if __name__ == "__main__":
    from src.dataset import PPODataset, load_prompt_data_4_ppo
    from src.rlhf.ppo import RLHF
    from src.trainer import rlhf_train_callback
+    from src.model import RWKV
+    from src.rlhf.reward import RewardModel

    # Data for PPO training; it is obtained by interacting with the environment
    memory = []
@@ -259,8 +261,33 @@
    # Load the training dataset
    prompts = load_prompt_data_4_ppo(args)

+    # Load the RWKV model
+    rwkv = RWKV(args)
+    if len(args.load_sft_model) == 0:
+        rank_zero_info("Must load an SFT model, please set --load_sft_model")
+        exit(1)
+    rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
+    try:
+        load_dict = torch.load(args.load_sft_model, map_location="cpu")
+    except:
+        rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
+        exit(1)
+    if args.load_partial == 1:
+        load_keys = load_dict.keys()
+        for k in rwkv.state_dict():
+            if k not in load_keys:
+                load_dict[k] = rwkv.state_dict()[k]
+    rwkv.load_state_dict(load_dict)
+
+    # Load the reward_model
+    reward_model = RewardModel(args)
+    reward_model.load(args.load_rm_model)

    # PPO model
-    rlhf_model = RLHF(args)
+    rlhf_model = RLHF(args, rwkv, reward_model)

    # Model training
    # trainer
......
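The trainer wiring itself is truncated at the "......" above. As a rough, hypothetical illustration only (the actual code in this commit may differ), the constructed `rlhf_model` would typically be handed to a PyTorch Lightning trainer together with a dataloader built from the collected PPO experiences; `train_loader` below is an assumed name, not something shown in the diff.

```
# Hypothetical sketch, not part of the commit: generic Lightning wiring for rlhf_model.
# Assumes `train_loader` is a torch DataLoader over the collected PPO experiences
# (e.g. built from the PPODataset imported above); its construction is not shown here.
import pytorch_lightning as pl

trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1)  # settings are placeholders
trainer.fit(rlhf_model, train_loader)
```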