Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
CSDN 技术社区
ai
chatCSDN
提交
26b8d465
C
chatCSDN
项目概览
CSDN 技术社区
/
ai
/
chatCSDN
通知
107
Star
8
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
chatCSDN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
26b8d465
编写于
3月 20, 2023
作者:
U
u010280923
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
opt reward model
上级
82d6d979
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
28 addition
and
26 deletion
+28
-26
src/rlhf/ppo.py
src/rlhf/ppo.py
+1
-1
src/rlhf/reward.py
src/rlhf/reward.py
+2
-22
train_rm.py
train_rm.py
+25
-3
未找到文件。
src/rlhf/ppo.py
浏览文件 @
26b8d465
...
...
@@ -250,7 +250,7 @@ def clipped_value_loss(values, rewards, old_values, clip):
# rlhf
@
beartype
class
RLHF
(
nn
.
Module
):
class
RLHF
(
pl
.
Lightning
Module
):
def
__init__
(
self
,
args
...
...
src/rlhf/reward.py
浏览文件 @
26b8d465
...
...
@@ -21,6 +21,7 @@ from einops.layers.torch import Rearrange, Reduce
from
src.rlhf.utils
import
masked_mean
,
gumbel_sample
from
src.model
import
RWKV
# helper functions
def
exists
(
val
):
...
...
@@ -34,30 +35,9 @@ def loss_function(prefer_reward, alter_reward):
@
beartype
class
RewardModel
(
pl
.
LightningModule
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
,
rwkv
:
RWKV
):
super
().
__init__
()
# 加载 RWKV 模型
rwkv
=
RWKV
(
args
)
if
len
(
args
.
load_model
)
==
0
:
rank_zero_info
(
f
"SFT must load model, please input "
)
exit
(
1
)
rank_zero_info
(
f
"########## Loading
{
args
.
load_model
}
... ##########"
)
try
:
load_dict
=
torch
.
load
(
args
.
load_model
,
map_location
=
"cpu"
)
except
:
rank_zero_info
(
f
"Bad checkpoint
{
args
.
load_model
}
"
)
exit
(
1
)
if
args
.
load_partial
==
1
:
load_keys
=
load_dict
.
keys
()
for
k
in
rwkv
.
state_dict
():
if
k
not
in
load_keys
:
load_dict
[
k
]
=
rwkv
.
state_dict
()[
k
]
rwkv
.
load_state_dict
(
load_dict
)
# 用预训练模型初始化奖励模型
self
.
rwkv
=
rwkv
self
.
args
=
args
...
...
train_rm.py
浏览文件 @
26b8d465
...
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
parser
=
ArgumentParser
()
parser
.
add_argument
(
"--load_model"
,
default
=
""
,
type
=
str
)
# full path, with .pth
parser
.
add_argument
(
"--load_
sft_
model"
,
default
=
""
,
type
=
str
)
# full path, with .pth
parser
.
add_argument
(
"--wandb"
,
default
=
""
,
type
=
str
)
# wandb project name. if "" then don't use wandb
parser
.
add_argument
(
"--proj_dir"
,
default
=
"out"
,
type
=
str
)
parser
.
add_argument
(
"--random_seed"
,
default
=
"-1"
,
type
=
int
)
...
...
@@ -228,13 +228,35 @@ if __name__ == "__main__":
from
src.trainer
import
rm_train_callback
from
src.rlhf.reward
import
RewardModel
from
src.dataset
import
RMDataset
from
src.model
import
RWKV
# 读入训练数据
train_data
=
RMDataset
(
args
)
args
.
vocab_size
=
train_data
.
vocab_size
# RM 模型
rm_model
=
RewardModel
(
args
)
# 加载 RWKV 模型
rwkv
=
RWKV
(
args
)
if
len
(
args
.
load_sft_model
)
==
0
:
rank_zero_info
(
f
"SFT must load model, please input "
)
exit
(
1
)
rank_zero_info
(
f
"########## Loading
{
args
.
load_sft_model
}
... ##########"
)
try
:
load_dict
=
torch
.
load
(
args
.
load_sft_model
,
map_location
=
"cpu"
)
except
:
rank_zero_info
(
f
"Bad checkpoint
{
args
.
load_sft_model
}
"
)
exit
(
1
)
if
args
.
load_partial
==
1
:
load_keys
=
load_dict
.
keys
()
for
k
in
rwkv
.
state_dict
():
if
k
not
in
load_keys
:
load_dict
[
k
]
=
rwkv
.
state_dict
()[
k
]
rwkv
.
load_state_dict
(
load_dict
)
# 初始化 RM 模型
rm_model
=
RewardModel
(
args
,
rwkv
)
# 训练
trainer
=
Trainer
.
from_argparse_args
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录