Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
CSDN 技术社区
ai
chatCSDN
提交
a532f71e
C
chatCSDN
项目概览
CSDN 技术社区
/
ai
/
chatCSDN
通知
107
Star
8
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
chatCSDN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a532f71e
编写于
3月 20, 2023
作者:
U
u010280923
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
opt ppo model
上级
26b8d465
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
34 addition
and
28 deletion
+34
-28
README.md
README.md
+1
-1
src/rlhf/ppo.py
src/rlhf/ppo.py
+5
-26
train_ppo.py
train_ppo.py
+28
-1
未找到文件。
README.md
浏览文件 @
a532f71e
...
...
@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
### Reward Model
```
python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
python train_rm.py --load_
sft_
model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
...
...
src/rlhf/ppo.py
浏览文件 @
a532f71e
...
...
@@ -20,6 +20,7 @@ from torch.optim import Adam
from
torch.utils.data
import
Dataset
,
DataLoader
from
torch.nn.utils.rnn
import
pad_sequence
import
pytorch_lightning
as
pl
from
pytorch_lightning.utilities
import
rank_zero_info
from
pytorch_lightning.strategies
import
DeepSpeedStrategy
from
deepspeed.ops.adam
import
DeepSpeedCPUAdam
,
FusedAdam
...
...
@@ -253,33 +254,13 @@ def clipped_value_loss(values, rewards, old_values, clip):
class
RLHF
(
pl
.
LightningModule
):
def
__init__
(
self
,
args
args
,
rwkv
:
RWKV
,
reward_model
:
RewardModel
):
super
().
__init__
()
self
.
args
=
args
# 加载 RWKV 模型
rwkv
=
RWKV
(
args
)
if
len
(
args
.
load_sft_model
)
==
0
:
rank_zero_info
(
f
"SFT must load model, please input "
)
exit
(
1
)
rank_zero_info
(
f
"########## Loading
{
args
.
load_sft_model
}
... ##########"
)
try
:
load_dict
=
torch
.
load
(
args
.
load_sft_model
,
map_location
=
"cpu"
)
except
:
rank_zero_info
(
f
"Bad checkpoint
{
args
.
load_sft_model
}
"
)
exit
(
1
)
if
args
.
load_partial
==
1
:
load_keys
=
load_dict
.
keys
()
for
k
in
rwkv
.
state_dict
():
if
k
not
in
load_keys
:
load_dict
[
k
]
=
rwkv
.
state_dict
()[
k
]
rwkv
.
load_state_dict
(
load_dict
)
self
.
rwkv
=
rwkv
# 使用 RWKV 初始化 actor_critic
...
...
@@ -291,9 +272,7 @@ class RLHF(pl.LightningModule):
self
.
actor_critic
=
actor_critic
# 加载 reward_model,并将 reward_model 设置为 evaluation 模式
reward_model
=
RewardModel
(
args
)
reward_model
.
load
(
args
.
load_rm_model
)
# 将 reward_model 设置为 evaluation 模式
self
.
reward_model
=
reward_model
.
eval
()
def
save
(
self
,
filepath
=
'./checkpoint.pt'
):
...
...
train_ppo.py
浏览文件 @
a532f71e
...
...
@@ -252,6 +252,8 @@ if __name__ == "__main__":
from
src.dataset
import
PPODataset
,
load_prompt_data_4_ppo
from
src.rlhf.ppo
import
RLHF
from
src.trainer
import
rlhf_train_callback
from
src.model
import
RWKV
from
src.rlhf.reward
import
RewardModel
# 用于 PPO 训练的数据,需要与 environment 交互获得
memory
=
[]
...
...
@@ -259,8 +261,33 @@ if __name__ == "__main__":
# 读入训练数据集
prompts
=
load_prompt_data_4_ppo
(
args
)
# 加载 RWKV 模型
rwkv
=
RWKV
(
args
)
if
len
(
args
.
load_sft_model
)
==
0
:
rank_zero_info
(
f
"SFT must load model, please input "
)
exit
(
1
)
rank_zero_info
(
f
"########## Loading
{
args
.
load_sft_model
}
... ##########"
)
try
:
load_dict
=
torch
.
load
(
args
.
load_sft_model
,
map_location
=
"cpu"
)
except
:
rank_zero_info
(
f
"Bad checkpoint
{
args
.
load_sft_model
}
"
)
exit
(
1
)
if
args
.
load_partial
==
1
:
load_keys
=
load_dict
.
keys
()
for
k
in
rwkv
.
state_dict
():
if
k
not
in
load_keys
:
load_dict
[
k
]
=
rwkv
.
state_dict
()[
k
]
rwkv
.
load_state_dict
(
load_dict
)
# 加载 reward_model
reward_model
=
RewardModel
(
args
)
reward_model
.
load
(
args
.
load_rm_model
)
# PPO 模型
rlhf_model
=
RLHF
(
args
)
rlhf_model
=
RLHF
(
args
,
rwkv
,
reward_model
)
# 模型训练
# trainer
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录