Commit 3c80d013 ("opt reward model")
Project: CSDN 技术社区 / ai / chatCSDN
Authored on March 10, 2023 by u010280923
Parent: e45f1cdf

Showing 7 changed files with 218 additions and 48 deletions (+218, -48)
README.md (+10, -0)
src/dataset.py (+42, -0)
src/model.py (+5, -1)
src/rlhf/reward.py (+65, -11)
src/rlhf/rwkv/model.py (+4, -5)
train_rm.py (+30, -31)
train_rm_demo.py (+62, -0)
README.md
@@ -63,6 +63,16 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"

### Reward Model
```
python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
--my_qa_mask 1
```

### Integrating RLHF (Reinforcement Learning with Human Feedback)
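Back to the `train_rm.py` command above: the `--data_file` it points at is a preference CSV. Judging from the `RMDataset` loader added in `src/dataset.py` below, each row needs `prompt`, `preferred` and `alternate` columns. A minimal sketch for writing a mock file of that shape; the example row is invented, only the column names and path come from the diff:

```python
import pandas as pd

# Column names match what RMDataset reads: row["prompt"], row["preferred"], row["alternate"].
rows = [
    {
        "prompt": "What is RWKV?",
        "preferred": "RWKV is a recurrent language model architecture that can be trained like a transformer.",
        "alternate": "No idea.",
    },
]

# Assumes the repo's data/ directory exists; the path matches the README command above.
pd.DataFrame(rows).to_csv("data/rm_mock_data.csv", index=False)
```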
src/dataset.py
@@ -257,3 +257,45 @@ class S2SDataset(Dataset):
         z = torch.tensor(z, dtype=torch.long)
 
         return x, y, z
+
+
+class RMDataset(Dataset):
+    def __init__(self, args):
+        self.args = args
+        self.vocab_size = args.vocab_size
+
+        WORD_NAME = [
+            "20B_tokenizer.json",
+            "20B_tokenizer.json",
+        ]  # [vocab, vocab] for Pile model
+        self.tokenizer = TOKENIZER(WORD_NAME)
+
+        pf = pd.read_csv(args.data_file)
+        data_list = []
+        for index, row in pf.iterrows():
+            prompt = row["prompt"]
+            preferred = row["preferred"]
+            alternate = row["alternate"]
+            preferred_sample = f"{prompt}\n{preferred}"
+            alternate_sample = f"{prompt}\n{alternate}"
+            data_list.append((
+                self.tokenizer.tokenizer.encode(preferred_sample),
+                self.tokenizer.tokenizer.encode(alternate_sample)))
+        self.data = data_list
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        ctx_len = self.args.ctx_len
+        req_len = ctx_len + 1
+
+        preferred_sample, alternate_sample = self.data[index]
+        preferred_sample = preferred_sample[:req_len]
+        alternate_sample = alternate_sample[:req_len]
+
+        preferred_sample = preferred_sample + [0] * (req_len - len(preferred_sample))
+        alternate_sample = alternate_sample + [0] * (req_len - len(alternate_sample))
+
+        x_p = torch.tensor(preferred_sample, dtype=torch.long)
+        x_a = torch.tensor(alternate_sample, dtype=torch.long)
+
+        return x_p, x_a
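Each `RMDataset` item is a fixed-length pair of token tensors (preferred and alternate, both truncated or zero-padded to `ctx_len + 1`), so it drops straight into a standard `DataLoader`. A minimal sketch, assuming the hard-coded `20B_tokenizer.json` file is present and a mock CSV like the one above exists; the `args` namespace here is only a stand-in for the training script's argparse namespace:

```python
from types import SimpleNamespace

from torch.utils.data import DataLoader

from src.dataset import RMDataset

# Hypothetical args: only the fields RMDataset actually reads are filled in.
args = SimpleNamespace(data_file="data/rm_mock_data.csv", ctx_len=2048, vocab_size=50277)

train_data = RMDataset(args)
loader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=True)

x_p, x_a = next(iter(loader))
print(x_p.shape, x_a.shape)  # both torch.Size([2, 2049]), i.e. (batch, ctx_len + 1)
```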
src/model.py
@@ -429,7 +429,7 @@ class RWKV(pl.LightningModule):
             return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False
 
-    def forward(self, idx, extra_embed=None):
+    def forward(self, idx, extra_embed=None, rm_train=False):
         args = self.args
         B, T = idx.size()
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
@@ -456,6 +456,10 @@ class RWKV(pl.LightningModule):
         x = self.ln_out(x)
 
+        # encoding used by the RM (reward model)
+        if rm_train is True:
+            return x
+
         if args.head_qk > 0:
             q = self.head_q(x)[:, :T, :]
             k = self.head_k(x)[:, :T, :]
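The new `rm_train` flag makes the training-side `RWKV.forward` return the post-`ln_out` token embeddings instead of vocabulary logits, so the reward model can reuse the backbone as an encoder. A self-contained toy module (not the repo's RWKV) just to illustrate the effect of such a switch:

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Toy stand-in for the pattern above: early-return embeddings when rm_train=True."""
    def __init__(self, vocab_size=100, n_embd=8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)
        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, rm_train=False):
        x = self.ln_out(self.emb(idx))
        if rm_train is True:
            return x             # (B, T, n_embd) embeddings for the reward head
        return self.head(x)      # (B, T, vocab_size) logits for language modelling

idx = torch.randint(0, 100, (1, 16))
model = TinyBackbone()
print(model(idx).shape)                 # torch.Size([1, 16, 100])
print(model(idx, rm_train=True).shape)  # torch.Size([1, 16, 8])
```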
src/rlhf/reward.py
@@ -8,33 +8,63 @@ from beartype.typing import Tuple, Optional
 import torch
 from torch import nn
 import torch.nn.functional as F
 
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_info
+
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 from src.rlhf.utils import masked_mean, gumbel_sample
-from src.rlhf.rwkv.model import RWKV
-# from src.model import RWKV
+from src.model import RWKV
 
 # helper functions
 
 def exists(val):
     return val is not None
 
+# loss function
+def loss_function(prefer_reward, alter_reward):
+    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
+
 # Reward Model - RWKV with a scalar head
 
 @beartype
-class RewardModel(nn.Module):
-    def __init__(self, rwkv: RWKV):
+class RewardModel(pl.LightningModule):
+    def __init__(self, args):
         super().__init__()
 
+        # load the RWKV model
+        rwkv = RWKV(args)
+
+        if len(args.load_model) == 0:
+            rank_zero_info(f"SFT must load model, please input ")
+            exit(1)
+
+        rank_zero_info(f"########## Loading {args.load_model}... ##########")
+        try:
+            load_dict = torch.load(args.load_model, map_location="cpu")
+        except:
+            rank_zero_info(f"Bad checkpoint {args.load_model}")
+            exit(1)
+
+        if args.load_partial == 1:
+            load_keys = load_dict.keys()
+            for k in rwkv.state_dict():
+                if k not in load_keys:
+                    load_dict[k] = rwkv.state_dict()[k]
+        rwkv.load_state_dict(load_dict)
+
+        # initialize the reward model from the pretrained model
         self.rwkv = rwkv
+        self.args = args
 
         # dimension of the output token embeddings
-        dim = rwkv.args.n_embd
+        dim = self.args.n_embd
 
        # distinguishes prompt from response tokens in the input; trained as a model parameter, initialized to all zeros
         self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
@@ -56,8 +86,18 @@ class RewardModel(nn.Module):
             *self.to_pred.parameters(),
             *self.rwkv.parameters()
         ]
 
-    def forward(
+    def configure_optimizers(self):
+        # hyperparameters from the paper: lr=1e-5, betas=(0.9, 0.95)
+        optimizer = torch.optim.Adam([
+            {"rwkv_params": self.rwkv.parameters()},
+            {"rm_params": self.parameters()}
+        ], lr=self.args.lr_init, betas=self.args.betas)
+
+        return optimizer
+
+    def single_forward(
         self,
         x,
         mask=None,
@@ -89,15 +129,29 @@ class RewardModel(nn.Module):
-        last_token_embeds = self.rwkv(x, state=None, extra_embed=extra_embed)
+        last_token_embeds = self.rwkv(x, state=None, extra_embed=extra_embed, rm_train=True)
 
         # average all token embeddings and feed the result to the scoring head
-        pooled = masked_mean(last_token_embeds, mask, dim=1)
+        try:
+            pooled = masked_mean(last_token_embeds, mask, dim=1)
+        except:
+            import ipdb
+            ipdb.set_trace()
+            pooled = masked_mean(last_token_embeds, mask, dim=1)
         reward = self.pred_reward(pooled)
 
         return reward
 
+    def forward(self, prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask):
+        prefer_reward = self.single_forward(prefer_x, prefer_x_prompt_mask)
+        alter_reward = self.single_forward(alter_x, alter_x_prompt_mask)
+
+        return prefer_reward, alter_reward
+
+    def training_step(self, batch, batch_idx):
+        prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask = batch
+        prefer_reward, alter_reward = self(prefer_x, alter_x, prefer_x_prompt_mask, alter_x_prompt_mask)
+        loss = loss_function(prefer_reward, alter_reward)
+
+        return loss
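`loss_function` takes one scalar reward per sequence for the preferred and alternate completions and averages a log-sigmoid of their difference over the batch. A quick self-contained evaluation on invented reward values, just to show the expected shapes (a 1-D reward tensor per side in, a single scalar loss out):

```python
import torch

# Copied from the diff above.
def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))

# Invented per-sequence rewards for a batch of three comparison pairs.
prefer_reward = torch.tensor([1.5, 0.2, -0.3])
alter_reward = torch.tensor([0.1, 0.4, -1.0])

print(loss_function(prefer_reward, alter_reward))  # a single scalar, roughly 1.11 for these numbers
```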
src/rlhf/rwkv/model.py
@@ -365,7 +365,7 @@ class RWKV(MyModule):
     ########################################################################################################
 
-    def forward(self, tokens, state, full_output=False, extra_embed=None):
+    def forward(self, tokens, state, full_output=False, extra_embed=None, rm_train=False):
         with torch.no_grad():
             w = self.w
             args = self.args
@@ -384,8 +384,6 @@
                 state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev)
 
             seq_mode = len(tokens) > 1
 
-            import ipdb
-            ipdb.set_trace()
 
             # input: look up each token's embedding by index
             x = w['emb.weight'][tokens if seq_mode else tokens[0]]
@@ -460,11 +458,12 @@
             # LayerNorm the token embeddings; dimensions unchanged
             x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])
 
-            token_embed = copy.deepcopy(x)
+            if rm_train is True:
+                return x
 
             if w['head.weight'].dtype != torch.uint8:
                 x = x @ w['head.weight']
             else:
                 x = x @ self.get_w('head.weight', dd.atype)
 
-            return x.float(), state, token_embed.float()
+            return x.float(), state
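On the inference-side RWKV (the variant loaded from a checkpoint path plus a strategy string), the commit drops the extra `token_embed` element from the return value and instead exposes the layer-normed embeddings through `rm_train=True`. A hedged call-site sketch, assuming the checkpoint used by `train_rm_demo.py` is available locally; the token ids are arbitrary examples:

```python
from src.rlhf.rwkv.model import RWKV

# Assumes this checkpoint exists locally (the same path train_rm_demo.py points at).
rwkv = RWKV("./model/RWKV-4-Pile-169M-20220807-8023.pth", "cpu fp32")

tokens = [510, 3158, 8516, 310]             # arbitrary example token ids
logits, state = rwkv(tokens, None)          # after this commit: a 2-tuple, no token_embed
embeds = rwkv(tokens, None, rm_train=True)  # layer-normed embeddings for the reward model
```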
train_rm.py
@@ -220,45 +220,44 @@ if __name__ == "__main__":
     ########################################################################################################
 
     # train the RM model
 
-    def loss_function(prefer_reward, alter_reward):
-        return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))
-
-    import torch
-    from src.rlhf.reward import RewardModel
-    from src.rlhf.rwkv.model import RWKV
-
-    model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
-    strategy = "cpu fp32"
-    rwkv_model = RWKV(model, strategy)
-    dim = rwkv_model.args.n_embd
-    reward_model = RewardModel(rwkv_model)
-
-    # mock data
-    prompt = torch.randint(0, dim, (1, 50))
-    prefer_response = torch.randint(0, dim, (1, 50))
-    alter_response = torch.randint(0, dim, (1, 50))
-
-    prefer_pair = torch.concat((prompt, prefer_response), dim=1)
-    alter_pair = torch.concat((prompt, alter_response), dim=1)
-
-    # which part of the sequence is prompt, which part is response
-    prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)
-
-    # labels = torch.randint(0, 5, (1,))
-
-    # train
-    # loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
-    # loss.backward()
-
-    # inference
-    prefer_reward = reward_model(prefer_pair, prompt_mask=prompt_mask)
-    alter_reward = reward_model(alter_pair, prompt_mask=prompt_mask)
-
-    print("Preferred response reward:", prefer_reward)
-    print("Alternate response reward:", alter_reward)
+    import torch
+    from tqdm import tqdm
+
+    from src.trainer import train_callback
+    from src.rlhf.reward import RewardModel
+    from src.model import RWKV
+    from src.dataset import RMDataset
+
+    # load the training data
+    train_data = RMDataset(args)
+    args.vocab_size = train_data.vocab_size
+
+    # the RM model
+    rm_model = RewardModel(args)
+
+    # training
+    trainer = Trainer.from_argparse_args(args, callbacks=[train_callback(args)])
+
+    if trainer.global_rank == 0:
+        for n in rm_model.state_dict():
+            shape = rm_model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1]
+            if len(shape) > 1:
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+            else:
+                print(f"{str(shape[0]).ljust(5)} {n}")
+
+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+
+    # must set shuffle=True, persistent_workers=False (because worker is in another thread)
+    data_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
+
+    trainer.fit(rm_model, data_loader)
train_rm_demo.py (new file, mode 100644)

'''
@File    :   train_rm_demo.py
@Time    :   2023/03/10 00:54:57
@Author  :   Lu Xin
@Contact :   luxin@csdn.net
'''

# here put the import lib
import torch
from tqdm import tqdm

from src.rlhf.reward import RewardModel
from src.rlhf.rwkv.model import RWKV


def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(alter_reward - prefer_reward)))


model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
strategy = "cpu fp32"
rwkv_model = RWKV(model, strategy)
reward_model = RewardModel(rwkv_model)

import ipdb
ipdb.set_trace()

# as used in the InstructGPT paper
optimizer = torch.optim.Adam(reward_model.parameters(), lr=1e-5, betas=(0.9, 0.95))

# mock data
dim = 20000
prompt = torch.randint(0, dim, (1, 50))
prefer_response = torch.randint(0, dim, (1, 50))
alter_response = torch.randint(0, dim, (1, 50))

prefer_pair = torch.concat((prompt, prefer_response), dim=1)
alter_pair = torch.concat((prompt, alter_response), dim=1)

prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)

for epoch in range(100):
    # compute rewards
    prefer_reward = reward_model(prefer_pair, prompt_mask=prompt_mask)
    alter_reward = reward_model(alter_pair, prompt_mask=prompt_mask)

    # print(f"prefer_reward: {prefer_reward}")
    # print(f"alter_reward: {alter_reward}")

    # train
    loss = loss_function(prefer_reward, alter_reward)
    print(f"loss: {loss}")

    # Backward pass
    loss.backward()
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()