Commit 02e8a2d9 (chatCSDN)
Authored on Mar 20, 2023 by u010280923
opt ppo model
Parent commit: 749886a8
Showing 3 changed files with 24 additions and 31 deletions (+24 −31):

src/model.py      +6   −0
src/rlhf/ppo.py   +9   −11
train_ppo.py      +9   −20
src/model.py

@@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from pathlib import Path
 from tqdm import tqdm
 from einops import pack
 ...

@@ -381,6 +382,11 @@ class RWKV(pl.LightningModule):
         self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
         self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

+    def load(self, path):
+        path = Path(path)
+        assert path.exists()
+        self.load_state_dict(torch.load(str(path)), map_location="cpu")
+
     def configure_optimizers(self):
         args = self.args
         if args.layerwise_lr > 0:
 ...
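One nit in the new `load` helper: `map_location` is passed to `load_state_dict`, but that keyword belongs to `torch.load`; `nn.Module.load_state_dict` does not accept it, so calling the method as written would raise a TypeError. Presumably the intent is to map the checkpoint to CPU while reading it. A minimal corrected sketch (the assert message is an addition, not from the commit):

```python
from pathlib import Path

import torch


def load(self, path):
    # Map the checkpoint onto CPU while reading it, then load it into the module.
    path = Path(path)
    assert path.exists(), f"checkpoint not found: {path}"
    state_dict = torch.load(str(path), map_location="cpu")
    self.load_state_dict(state_dict)
```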
src/rlhf/ppo.py

@@ -45,19 +45,16 @@ PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
 class ActorCritic(nn.Module):
     def __init__(
         self,
-        rwkv: RWKV,
         args,
-        critic: Optional[RWKV] = None,
+        actor: RWKV,
+        critic: RWKV,
         pooled_values = False
     ):
         super().__init__()

-        self.actor = copy.deepcopy(rwkv)
+        self.actor = actor
         self.critic = critic

-        if not exists(self.critic):
-            self.critic = copy.deepcopy(rwkv)
-
         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
             nn.Linear(args.n_embd, 1),
 ...
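Before this change, ActorCritic received a single `rwkv` model and deep-copied it into both roles (with an optional critic override); now the caller builds and loads the actor and critic explicitly, so the two models can in principle start from different checkpoints. A minimal construction sketch, assuming the repo's module layout and that `args` is the training script's usual argparse namespace:

```python
from src.model import RWKV
from src.rlhf.ppo import ActorCritic

# `args`: the training script's argparse namespace (assumed available).
# Both models start from the SFT checkpoint here, but they no longer have to:
# the caller controls how each one is initialized.
actor = RWKV(args)
actor.load(args.load_sft_model)

critic = RWKV(args)
critic.load(args.load_sft_model)

actor_critic = ActorCritic(
    args=args,
    actor=actor,
    critic=critic,
    pooled_values=args.critic_pooled_values,
)
```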
@@ -242,20 +239,21 @@ class RLHF(pl.LightningModule):
     def __init__(
         self,
         args,
-        rwkv: RWKV,
+        actor: RWKV,
+        critic: RWKV,
         reward_model: RewardModel
     ):
         super().__init__()

         self.args = args
-        self.rwkv = rwkv

         # initialize actor_critic from RWKV
         actor_critic = ActorCritic(
-            rwkv = self.rwkv,
             args = self.args,
+            actor = actor,
+            critic = critic,
             pooled_values = args.critic_pooled_values
-        ).to(self.rwkv.device)
+        ).to(actor.device)

         self.actor_critic = actor_critic
 ...
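Note that the RLHF constructor's signature changes here: it now takes `actor` and `critic` instead of a single `rwkv`, drops the `self.rwkv` attribute, and places the ActorCritic module on `actor.device`. Any other place the repository constructs RLHF would need the same four-argument call that train_ppo.py adopts below; a minimal call-site sketch, assuming the models are built as in that script:

```python
# Call-site sketch (assumes actor, critic and reward_model are already
# constructed and loaded, e.g. as train_ppo.py now does):
rlhf_model = RLHF(
    args,          # the script's argparse namespace
    actor,         # RWKV instance loaded from the SFT checkpoint
    critic,        # a second RWKV instance, also from the SFT checkpoint
    reward_model,  # RewardModel loaded from args.load_rm_model
)
```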
train_ppo.py

@@ -261,33 +261,22 @@ if __name__ == "__main__":
     # load the training dataset
     prompts = load_prompt_data_4_ppo(args)

-    # load the RWKV model
-    rwkv = RWKV(args)
-
-    if len(args.load_sft_model) == 0:
-        rank_zero_info(f"SFT must load model, please input ")
-        exit(1)
-    rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
-    try:
-        load_dict = torch.load(args.load_sft_model, map_location="cpu")
-    except:
-        rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
-        exit(1)
-
-    if args.load_partial == 1:
-        load_keys = load_dict.keys()
-        for k in rwkv.state_dict():
-            if k not in load_keys:
-                load_dict[k] = rwkv.state_dict()[k]
-    rwkv.load_state_dict(load_dict)
+    # initialize the actor model from rwkv
+    actor = RWKV(args)
+    actor.load(args.load_sft_model)
+
+    # initialize the critic model from rwkv
+    critic = RWKV(args)
+    critic.load(args.load_sft_model)

     # load the reward_model
+    rwkv = RWKV(args)
+    rwkv.load(args.load_sft_model)
     reward_model = RewardModel(args, rwkv)
     reward_model.load(args.load_rm_model)

     # PPO model
-    rlhf_model = RLHF(args, rwkv, reward_model)
+    rlhf_model = RLHF(args, actor, critic, reward_model)

     # model training
     # trainer
 ...
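The refactor also removes the old error handling: the empty-path check, the "Bad checkpoint" message, and the `args.load_partial` key-merging logic are gone, leaving only the bare assert inside `RWKV.load`. If the explicit messages are still wanted, the removed guard could be reinstated before the actor/critic/reward-model loads; a sketch reusing the deleted lines (`rank_zero_info` is already imported by the script, and whether dropping `load_partial` here is intentional is not stated in the commit):

```python
# Reinstate the pre-commit guard before building actor, critic and reward model.
if len(args.load_sft_model) == 0:
    rank_zero_info("SFT must load model, please input ")
    exit(1)
rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
```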