Commit e7dc79af — "opt ppo model"

Project: CSDN 技术社区 / ai / chatCSDN
Authored on Mar 15, 2023 by u010280923
Parent commit: 65604ada

9 changed files with 1350 additions and 362 deletions (+1350, -362):
- README.md           +12   -1
- src/dataset.py       +44   -1
- src/model.py         +73   -9
- src/rlhf/ppo.py      +168  -315
- src/rlhf/ppo_old.py  +623  -0
- src/rlhf/utils.py    +3    -0
- src/trainer.py       +136  -0
- train_ppo.py         +290  -35
- train_rm.py          +1    -1
README.md
@@ -74,7 +74,18 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \

The section heading `### 接入RLHF(Reinforcement Learning with Human Feedback)` is replaced by `### PPO Model (Reinforcement learning from Human Feedback)`, and a launch command for the PPO stage is added:

```
python train_rm.py --load_sft_model "rwkv-190.pth" --load_rm_model "rm-6.pth" --wandb "" \
--proj_dir "out_rlhf" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
--my_qa_mask 1
```
src/dataset.py
@@ -323,4 +323,47 @@ class RMDataset(Dataset):

```diff
         m_p = torch.tensor(prompt_prefer_mask, dtype=torch.long)
         m_a = torch.tensor(prompt_alter_mask, dtype=torch.long)

-        return x_p, x_a, m_p, m_a
\ No newline at end of file
+        return x_p, x_a, m_p, m_a
+
+
+class PPODataset(Dataset):
+    def __init__(self, memory):
+        self.data = memory
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        # todo(luxin): is padding needed here?
+        sequence, \
+            prompt_mask, \
+            mask, \
+            action_prob, \
+            action_log_prob, \
+            reward, \
+            value = self.data[index]
+
+        return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value
+
+
+def load_prompt_data_4_ppo(args):
+    prompt_token_ids = []
+
+    WORD_NAME = [
+        "20B_tokenizer.json",
+        "20B_tokenizer.json",
+    ]  # [vocab, vocab] for Pile model
+    tokenizer = TOKENIZER(WORD_NAME)
+
+    pf = pd.read_csv(args.data_file)
+    for index, row in pf.iterrows():
+        prompt = row["prompt"]
+        prompt_token_ids.append(tokenizer.tokenizer.encode(prompt))
+    prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long)
+
+    return prompt_token_ids
```
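For orientation, here is a condensed sketch of how train_ppo.py (later in this commit) wires these two additions together. The SimpleNamespace stands in for the real argparse namespace, and the CSV path and batch size are the illustrative values from the README command, not requirements:

```python
from types import SimpleNamespace

from torch.utils.data import DataLoader

from src.dataset import PPODataset, load_prompt_data_4_ppo

# Illustrative args: only the fields used below; the CSV needs a "prompt" column.
args = SimpleNamespace(data_file="data/rm_mock_data.csv", micro_bsz=2)

# Tokenize every prompt in the CSV into token ids for the PPO rollouts.
prompts = load_prompt_data_4_ppo(args)

# `memory` is normally filled with experience tuples collected during rollouts, in the
# (sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value) order.
memory = []

loader = DataLoader(PPODataset(memory), batch_size=args.micro_bsz, shuffle=False)
```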
src/model.py
@@ -12,6 +12,15 @@ from pytorch_lightning.strategies import DeepSpeedStrategy

```diff
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from tqdm import tqdm
+from einops import pack
+from einops import unpack
+
+from src.rlhf.utils import exists
+from src.rlhf.utils import gumbel_sample
+from src.rlhf.utils import top_k
+from src.rlhf.utils import identity
+
 # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam

 try:
```

@@ -429,7 +438,7 @@ class RWKV(pl.LightningModule):

```diff
             return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False

-    def forward(self, idx, extra_embed=None, rm_train=False):
+    def forward(self, idx, extra_embed=None, rm_train=False, ppo_train=False):
         args = self.args
         B, T = idx.size()
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
```

@@ -454,15 +463,15 @@ class RWKV(pl.LightningModule):

```diff
             else:
                 x = block(x)

-        x = self.ln_out(x)
+        embeds = self.ln_out(x)

         # encoding used for the RM (reward) model
         if rm_train is True:
-            return x
+            return embeds

         if args.head_qk > 0:
-            q = self.head_q(x)[:, :T, :]
-            k = self.head_k(x)[:, :T, :]
+            q = self.head_q(embeds)[:, :T, :]
+            k = self.head_k(embeds)[:, :T, :]
             c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
```

@@ -473,11 +482,66 @@ class RWKV(pl.LightningModule):

```diff
             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                 c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

-            x = self.head(x) + c
+            logits = self.head(embeds) + c
         else:
-            x = self.head(x)
-
-        return x
+            logits = self.head(embeds)
+
+        # used by the PPO model
+        if ppo_train is True:
+            return logits, embeds
+
+        return logits
+
+    @torch.no_grad()
+    def generate(
+        self,
+        seq_len,
+        prompt = None,
+        temperature = 1.,
+        filter_logits_fn = top_k,
+        filter_thres = 0.9,
+        pad_value = 0.,
+        eos_token = None,
+        return_seq_without_prompt = True,
+        use_tqdm = False,
+        **kwargs
+    ):
+        '''
+        '''
+        prompt, leading_dims = pack([prompt], '* n')
+
+        n, out = prompt.shape[-1], prompt.clone()
+
+        wrapper_fn = identity if not use_tqdm else tqdm
+        sample_num_times = max(1, seq_len - prompt.shape[-1])
+
+        for _ in wrapper_fn(range(sample_num_times)):
+            logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
+            logits, embeds = logits[:, -1], embeds[:, -1]
+
+            if exists(filter_logits_fn):
+                logits = filter_logits_fn(logits, thres = filter_thres)
+
+            sample = gumbel_sample(logits, temperature = temperature, dim = -1)
+            out, _ = pack([out, sample], 'b *')
+
+            if exists(eos_token):
+                is_eos_tokens = (out == eos_token)
+
+                if is_eos_tokens.any(dim = -1).all():
+                    # mask out everything after the eos tokens
+                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+                    out = out.masked_fill(mask, pad_value)
+                    break
+
+        out, = unpack(out, leading_dims, '* n')
+
+        if not return_seq_without_prompt:
+            return out
+
+        return out[..., n:]

     def training_step(self, batch, batch_idx):
         args = self.args
```
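The eos handling inside the new generate() relies on a shift-and-cumsum trick to blank out everything after the first eos token. A self-contained illustration with a dummy sequence (the token ids, eos token and pad value below are made up):

```python
import torch
import torch.nn.functional as F

eos_token, pad_value = 0, 0
out = torch.tensor([[5, 7, 0, 9, 3]])                  # dummy sampled sequence; 0 acts as eos

is_eos_tokens = (out == eos_token)
shifted = F.pad(is_eos_tokens, (1, -1), value=False)   # shift right so the eos itself is kept
mask = shifted.float().cumsum(dim=-1) >= 1             # True strictly after the first eos
print(out.masked_fill(mask, pad_value))                # tensor([[5, 7, 0, 0, 0]])
```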
src/rlhf/ppo.py
(This diff is collapsed in the page view and its contents are not shown here. Net change: +168, -315.)
src/rlhf/ppo_old.py
New file (mode 100644):

```python
import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange

from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque

import torch
from torch import nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from pytorch_lightning.utilities import rank_zero_info

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator

from accelerate import Accelerator

# actor critic - PaLM with lora

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
    'sequence',
    'mask',
    'prompt_mask',
    'action_logits',
    'values'
])

@beartype
class ActorCritic(nn.Module):
    def __init__(
        self,
        rwkv: RWKV,
        critic_palm: Optional[RWKV] = None,
        pooled_values = False,
        actor_lora = True,
        critic_lora = True,
        actor_lora_r = 8,
        critic_lora_r = 8,
        actor_lora_scope = 'actor',
        critic_lora_scope = 'critic',
        actor_dropout = 0.,
        critic_dropout = 0.
    ):
        super().__init__()
        self.actor_palm = rwkv

        self.critic_palm = critic_palm

        if not exists(self.critic_palm):
            self.critic_palm = copy.deepcopy(rwkv)

        self.actor_palm.set_dropout(actor_dropout)
        self.critic_palm.set_dropout(critic_dropout)

        self.actor_lora = actor_lora
        self.critic_lora = critic_lora

        self.actor_lora_scope = actor_lora_scope if actor_lora else None
        self.critic_lora_scope = critic_lora_scope if critic_lora else None

        if self.actor_lora:
            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

        if self.critic_lora:
            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(
            nn.Linear(rwkv.dim, 1),
            Rearrange('... 1 -> ...')
        )

        nn.init.zeros_(self.value_head[0].bias)
        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))

    def actor_parameters(self):
        if not self.actor_lora:
            return self.actor_palm.parameters()

        return [
            *self.actor_palm.finetune_parameters(self.actor_lora_scope)
        ]

    def critic_parameters(self):
        if not self.actor_lora:
            return [*self.critic_palm.parameters(), *self.value_head.parameters()]

        return [
            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
            *self.value_head.parameters()
        ]

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        state,
        max_seq_len,
        eos_token = None,
        return_values = False,
        **kwargs
    ):
        actions = self.actor_palm.generate(
            max_seq_len,
            prompt = state,
            eos_token = eos_token,
            finetune_scope = self.actor_lora_scope,
            use_tqdm = True,
            **kwargs
        )

        sequence = torch.cat((state, actions), dim = -1)
        action_len = actions.shape[-1]
        state_len = state.shape[-1]

        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

        action_mask = ~prompt_mask

        mask = None
        if exists(eos_token):
            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
            mask = F.pad(mask, (1, -1), value = True)  # include eos token
            action_mask &= mask

        action_logits, value = self.forward(
            sequence,
            mask = action_mask,
            return_values = return_values
        )

        return PPOActionCriticReturn(
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        )

    def forward(
        self,
        x,
        mask = None,
        return_values = True
    ):
        action_logits = self.actor_palm(
            x,
            finetune_scope = self.actor_lora_scope
        )

        if not return_values:
            return action_logits, None

        critic_embeds = self.critic_palm(
            x,
            return_only_embedding = True,
            finetune_scope = self.critic_lora_scope
        )

        if self.pooled_values:
            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)

        values = self.value_head(critic_embeds)

        return action_logits, values

# data

Memory = namedtuple('Memory', [
    'sequence',
    'prompt_mask',
    'mask',
    'action_prob',
    'action_log_prob',
    'reward',
    'value'
])

@beartype
class ExperienceDataset(Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],
        device = None
    ):
        super().__init__()
        self.data = data
        self.device = device

    def __len__(self):
        return self.data[0].shape[0]

    def __getitem__(self, ind):
        return tuple(map(lambda t: t[ind].to(self.device), self.data))

def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
    ds = ExperienceDataset(data, device = device)
    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
    dim = default(dim, tuple(range(t.ndim)))
    kwargs = dict(dim = dim, keepdim = True)

    mean = masked_mean(t, mask = mask, **kwargs)
    mean_centered = t - mean
    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)

    return mean_centered * var.clamp(min = eps).rsqrt()

def pad_sequence_fixed(sequences, *args, **kwargs):
    first_el = sequences[0]
    has_no_dimension = first_el.ndim == 0

    # if no dimensions, add a single dimension
    if has_no_dimension:
        sequences = tuple(map(lambda t: t[None], sequences))

    out = pad_sequence(sequences, *args, **kwargs)

    if has_no_dimension:
        out = rearrange(out, '... 1 -> ...')

    return out

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def log_prob(prob, indices):
    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
    return log(prob.gather(-1, indices[..., None])).squeeze(-1)

def shift(t, value = 0, shift = 1, dim = -1):
    zeros = (0, 0) * (-dim - 1)
    return F.pad(t, (*zeros, shift, -shift), value = value)

def masked_entropy(prob, dim = -1, mask = None):
    entropies = (prob * log(prob)).sum(dim = -1)
    return masked_mean(entropies, mask = mask).mean()

def masked_kl_div(prob1, prob2, mask = None):
    """
    need to account for variable sequence lengths, therefore not using the built-in functional version
    """
    kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)

    if not exists(mask):
        return kl_divs.mean()

    return masked_mean(kl_divs, mask).mean()

def clipped_value_loss(values, rewards, old_values, clip):
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))

# rlhf trainer

@beartype
class RLHFTrainer(nn.Module):
    def __init__(
        self,
        args,
        accelerate_kwargs: dict = {}
    ):
        super().__init__()

        self.args = args

        self.accelerate = Accelerator(**accelerate_kwargs)

        # load the RWKV model
        rwkv = RWKV(args)

        if len(args.load_sft_model) == 0:
            rank_zero_info(f"SFT must load model, please input ")
            exit(1)

        rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
        try:
            load_dict = torch.load(args.load_sft_model, map_location="cpu")
        except:
            rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
            exit(1)

        if args.load_partial == 1:
            load_keys = load_dict.keys()
            for k in rwkv.state_dict():
                if k not in load_keys:
                    load_dict[k] = rwkv.state_dict()[k]
        rwkv.load_state_dict(load_dict)

        self.rwkv = rwkv

        # initialize the actor_critic from RWKV
        actor_critic = ActorCritic(
            rwkv = self.rwkv,
            actor_lora = args.actor_lora,
            critic_lora = args.critic_lora,
            actor_lora_r = args.actor_lora_r,
            critic_lora_r = args.critic_lora_r,
            pooled_values = args.critic_pooled_values,
            actor_dropout = args.actor_dropout,
            critic_dropout = args.critic_dropout
        ).to(self.rwkv.device)

        self.actor_critic = actor_critic

        # load the reward model and put it in evaluation mode
        reward_model = RewardModel(args)
        reward_model.load(args.load_rm_model)
        self.reward_model = reward_model.eval()

        # optimizers
        self.actor_optim = get_optimizer(
            actor_critic.actor_parameters(),
            lr = self.args.actor_lr,
            wd = self.args.actor_wd,
            betas = self.args.betas,
            eps = self.args.actor_adam_eps,
            use_lion = self.args.use_lion
        )

        self.critic_optim = get_optimizer(
            actor_critic.critic_parameters(),
            lr = self.args.critic_lr,
            wd = self.args.critic_wd,
            betas = self.args.betas,
            eps = self.args.critic_adam_eps,
            use_lion = self.args.use_lion
        )

        # prepare with accelerator
        (
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        ) = self.accelerate.prepare(
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        )

    def print(self, msg):
        return self.accelerate.print(msg)

    def save(self, filepath = './checkpoint.pt'):
        torch.save(self.actor_critic.state_dict(), filepath)

    def load(self, filepath = './checkpoint.pt'):
        state_dict = torch.load(filepath)
        self.actor_critic.load_state_dict(state_dict)

    @property
    def device(self):
        return self.accelerate.device

    @torch.no_grad()
    def generate(
        self,
        max_seq_len,
        *args,
        prompt,
        num_samples = 4,  # sample 4 per prompt and select the one with highest reward
        **kwargs
    ):
        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        prompt = repeat(prompt, 'n -> b n', b = num_samples)

        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
        reward_model = self.accelerate.unwrap_model(self.reward_model)

        actor_critic.eval()

        (
            actions,
            sequences,
            mask,
            prompt_mask,
            action_logits,
            _
        ) = actor_critic.generate(
            prompt,
            *args,
            max_seq_len = max_seq_len,
            return_values = False,
            **kwargs
        )

        rewards = reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        best_sequence_index = rewards.topk(1, dim = -1).indices

        best_sequence = sequences[best_sequence_index]
        best_sequence = rearrange(best_sequence, '1 ... -> ...')

        return best_sequence

    def learn(
        self,
        memories: Deque[Memory]
    ):
        # stack all data stored in the memories
        all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))

        # prepare dataloader for policy phase training
        dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)

        self.actor_critic.train()

        # PPO training
        for _ in range(self.epochs):
            for (
                sequences,
                prompt_masks,
                masks,
                old_action_probs,
                old_log_probs,
                rewards,
                old_values
            ) in dl:
                action_masks = ~prompt_masks & masks

                action_logits, values = self.actor_critic(
                    sequences,
                    mask = action_masks
                )

                action_logits = shift(action_logits, shift = 1, dim = -2)  # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_len = old_log_probs.shape[-1]

                action_probs = action_logits.softmax(dim = -1)
                action_log_probs = log_prob(action_probs, sequences)
                action_log_probs = action_log_probs[:, -action_len:]

                # calculate entropies, taking into account which part of the sequence is actually an action
                entropies = masked_entropy(action_probs, mask = action_masks)

                # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
                kl_div_loss = 0.

                if self.args.kl_div_loss_weight > 0:
                    kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

                # handle non-pooled values
                normalize_kwargs = dict()

                if old_values.ndim == 2:
                    old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

                    old_values = old_values[:, -action_len:]
                    values = values[:, -action_len:]
                    rewards = rearrange(rewards, 'b -> b 1')
                    normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

                if values.ndim < rewards.ndim:
                    values = rearrange(values, '... -> ... 1')

                # calculate clipped surrogate objective, classic PPO loss
                ratios = (action_log_probs - old_log_probs).exp()
                advantages = masked_normalize(rewards - old_values, **normalize_kwargs)

                if advantages.ndim == 1:
                    advantages = rearrange(advantages, 'b -> b 1')

                surr1 = ratios * advantages
                surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

                # combine losses
                loss = policy_loss.mean() + kl_div_loss

                # update actor
                self.accelerate.backward(loss)

                self.print(f'policy_loss: {loss.item():.3f}')

                if exists(self.args.max_norm):
                    self.accelerator.clip_grad_norm_(self.actor_critic.actor_parameters(), self.args.max_norm)

                self.actor_optim.step()
                self.actor_optim.zero_grad()

                # calculate value loss and update value network separate from policy network
                value_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
                value_loss = value_loss.mean()

                self.print(f'critic_loss: {value_loss.item():.3f}')

                self.accelerate.backward(value_loss)

                if exists(self.args.max_norm):
                    self.accelerator.clip_grad_norm_(self.actor_critic.critic_parameters(), self.args.max_norm)

                self.critic_optim.step()
                self.critic_optim.zero_grad()

    def train(
        self,
        num_episodes = 50000,
        max_timesteps = 500,
        update_timesteps = 5000,
        max_batch_size = 16,
        eos_token = None,
        temperature = 1.
    ):
        device = self.device

        time = 0
        memories = deque([])

        for eps in tqdm(range(num_episodes), desc = 'episodes'):
            for timestep in range(max_timesteps):
                time += 1

                # select a bunch of random states (prompts)
                # and get the action (sampled sequence from palm as well as the action probs)
                # also calculate the reward using reward model and store

                rand_prompt_index = randrange(0, self.num_prompts)

                state = self.prompt_token_ids[rand_prompt_index]

                # remove padding from state
                state_mask = state != self.args.pad_value
                state = state[state_mask]

                # get predicted sequence
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len = self.args.ctx_len,
                    eos_token = eos_token,
                    temperature = temperature,
                    return_values = True
                )

                action_logits = shift(action_logits, shift = 1, dim = -2)  # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_prob = action_logits.softmax(dim = -1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # get reward as given by supervised trained reward model
                sequence = torch.cat((state, actions), dim = 0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)

                reward = self.reward_model(
                    sequence,
                    prompt_mask = prompt_mask,
                    mask = mask,
                    sample = True
                )

                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')

                # store memory for learning
                memories.append(Memory(*map(detach_to_cpu_, (
                    sequence,
                    prompt_mask,
                    mask,
                    action_prob,
                    action_log_prob,
                    reward,
                    value
                ))))

                # learn from the stored memories
                if time % update_timesteps == 0:
                    self.learn(memories)
                    memories.clear()

        print('rlhf training complete')
```
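The policy update in learn() above is the standard PPO clipped surrogate objective. A self-contained numeric sketch of those few lines, using dummy tensors and the eps_clip / beta_s defaults this commit adds to train_ppo.py:

```python
import torch

eps_clip, beta_s = 0.2, 0.01                                   # defaults from train_ppo.py

action_log_probs = torch.tensor([-1.0, -0.5, -2.0])            # new policy log-probs (dummy)
old_log_probs    = torch.tensor([-1.2, -0.4, -2.0])            # rollout policy log-probs (dummy)
advantages       = torch.tensor([ 0.5, -1.0,  0.3])            # normalized advantages (dummy)
entropy          = torch.tensor(1.5)                           # masked entropy bonus (dummy)

ratios = (action_log_probs - old_log_probs).exp()              # importance-sampling ratios
surr1  = ratios * advantages
surr2  = ratios.clamp(1 - eps_clip, 1 + eps_clip) * advantages
policy_loss = -torch.min(surr1, surr2) - beta_s * entropy      # clipped surrogate + entropy term

print(policy_loss.mean())
```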
src/rlhf/utils.py
@@ -8,6 +8,9 @@ from einops import rearrange

```diff
 def exists(val):
     return val is not None

+def identity(t, *args, **kwargs):
+    return t
+
 # decorators

 def eval_decorator(fn):
```
src/trainer.py
@@ -287,6 +287,142 @@ class rm_train_callback(pl.Callback):

A new rlhf_train_callback is added after the existing rm_train_callback (the generate_init_weight helper that follows it is unchanged):

```python
class rlhf_train_callback(pl.Callback):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        args = self.args
        # if args.cuda_cleanup > 0:
        #     torch.cuda.empty_cache()
        real_step = trainer.global_step + args.epoch_begin * args.epoch_steps

        # LR schedule
        w_step = args.warmup_steps
        if args.lr_final == args.lr_init or args.epoch_count == 0:
            lr = args.lr_init
        else:
            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
            progress = (decay_step - w_step + 1) / (decay_total - w_step)
            progress = min(1, max(0, progress))

            if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                lr = args.lr_init + (args.lr_final - args.lr_init) * progress
            else:  # exp decay
                lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))

        if trainer.global_step < w_step:
            lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
        # if trainer.is_global_zero:
        #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)

        for param_group in trainer.optimizers[0].param_groups:
            if args.layerwise_lr > 0:
                param_group["lr"] = lr * param_group["my_lr_scale"]
                # print(param_group["lr"], param_group["my_lr_scale"])
            else:
                param_group["lr"] = lr

        trainer.my_lr = lr
        # rank_zero_info(f"{real_step} {lr}")

        if trainer.global_step == 0:
            if trainer.is_global_zero:  # logging
                trainer.my_loss_sum = 0
                trainer.my_loss_count = 0
                trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
                trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
                try:
                    print(f"\n{trainer.strategy.config}\n")
                    trainer.my_log.write(f"{trainer.strategy.config}\n")
                except:
                    pass
                trainer.my_log.flush()
                if len(args.wandb) > 0:
                    print("Login to wandb...")
                    import wandb
                    wandb.init(
                        project=args.wandb,
                        name=args.run_name + " " + args.my_timestamp,
                        config=args,
                        save_code=False,
                    )
                    trainer.my_wandb = wandb

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        args = self.args
        if trainer.is_global_zero:  # logging
            t_now = time.time_ns()
            token_per_step = args.ctx_len * args.real_bsz
            real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
            kt_s = 0
            try:
                t_cost = (t_now - trainer.my_time_ns) / 1e9
                kt_s = token_per_step / t_cost / 1000
                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
                self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
            except:
                pass
            trainer.my_time_ns = t_now
            trainer.my_loss = trainer.my_loss_all.float().mean().item()
            trainer.my_loss_sum += trainer.my_loss
            trainer.my_loss_count += 1
            trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
            self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
            self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
            # self.log("s", real_step, prog_bar=True, on_step=True)

            if len(args.wandb) > 0:
                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
                if kt_s > 0:
                    lll["kt/s"] = kt_s
                trainer.my_wandb.log(lll, step=int(real_step))
            if args.magic_prime > 0:
                if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
                    to_save_dict = pl_module.state_dict()
                    my_save(
                        to_save_dict,
                        f"{args.proj_dir}/rlhf-final.pth",
                    )

    def on_train_epoch_start(self, trainer, pl_module):
        args = self.args
        dataset = trainer.train_dataloader.dataset.datasets
        assert "RMDataset" in str(dataset)
        dataset.global_rank = trainer.global_rank
        dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
        dataset.world_size = trainer.world_size
        # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')

    def on_train_epoch_end(self, trainer, pl_module):
        args = self.args
        if trainer.is_global_zero:  # logging & save state_dict
            if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
                if args.data_type == 'wds_img':
                    raw_dict = pl_module.state_dict()
                    to_save_dict = {}
                    for k in raw_dict:
                        if k.startswith('encoder.') or k.startswith('decoder.'):
                            to_save_dict[k] = raw_dict[k]
                else:
                    to_save_dict = pl_module.state_dict()
                try:
                    my_save(
                        to_save_dict,
                        f"{args.proj_dir}/rlhf-{args.epoch_begin + trainer.current_epoch}.pth",
                    )
                except Exception as e:
                    print('Error\n\n', e, '\n\n')
            trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
            trainer.my_log.flush()

            trainer.my_loss_sum = 0
            trainer.my_loss_count = 0
```
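The learning-rate logic in on_train_batch_start is an exponential (or linear) interpolation from lr_init to lr_final with a short linear warmup. A standalone sketch of the same arithmetic; the parameter values here are illustrative and not taken from the commit:

```python
import math

def rlhf_lr(step, lr_init=1e-5, lr_final=1e-6, warmup_steps=50,
            total_steps=200_000, edecay_steps=0):
    """Mirrors the decay logic in rlhf_train_callback.on_train_batch_start."""
    if lr_final == lr_init or total_steps == 0:
        lr = lr_init
    else:
        progress = (step - edecay_steps - warmup_steps + 1) / (total_steps - edecay_steps - warmup_steps)
        progress = min(1, max(0, progress))
        if lr_final == 0 or lr_init == 0:
            lr = lr_init + (lr_final - lr_init) * progress                    # linear decay
        else:
            lr = lr_init * math.exp(math.log(lr_final / lr_init) * progress)  # exponential decay
    if step < warmup_steps:
        lr = lr * (0.2 + 0.8 * step / warmup_steps)                           # linear warmup
    return lr

print(rlhf_lr(0), rlhf_lr(10_000), rlhf_lr(200_000))
```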
train_ppo.py
@@ -6,38 +6,293 @@

The previous placeholder script (removed):

```python
# here put the import lib
import torch

from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.ppo import RLHFTrainer

# load your pretrained RWKV
# todo(luxin): load the SFT-pretrained model
rwkv_model = RWKV()
# palm.load('./path/to/pretrained/palm.pt')

# load your pretrained reward model
# todo(luxin): load the trained reward model
reward_model = RewardModel(rwkv_model, num_binned_output = 5)
# reward_model.load('./path/to/pretrained/reward_model.pt')

# ready your list of prompts for reinforcement learning
# todo(luxin): load the prompt dataset (these prompts should differ from the SFT / RM prompts)
prompts = torch.randint(0, 256, (50000, 512))  # 50k prompts

# pass it all to the trainer and train
# train the PPO model
trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

trainer.train(num_episodes = 100)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one
answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10)  # (<= 2048,)
print(answer)
```

is replaced with a full PPO training entry point:

```python
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

if __name__ == "__main__":
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    rank_zero_info("########## work in progress ##########")

    ########################################################################################################
    #
    # example: train a simple L12-D768 RWKV on dummy data
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "" --data_type "dummy" --vocab_size 0 \
    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: train a simple L6-D512 RWKV from scratch on enwik8
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0

    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

    parser = ArgumentParser()

    parser.add_argument("--load_sft_model", default="", type=str)  # full path, with .pth
    parser.add_argument("--load_rm_model", default="", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default="-1", type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)

    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"

    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)

    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default='x', type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default='x', type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_testing", default='', type=str)

    # PPO model parameters
    parser.add_argument("--critic_pooled_values", default=True, type=bool)
    parser.add_argument("--max_norm", default=None, type=float)
    parser.add_argument("--kl_div_loss_weight", default=0.1, type=float)  # between old action probs and new action probs - not sure what the right value is
    parser.add_argument("--eps_clip", default=0.2, type=float)
    parser.add_argument("--value_clip", default=0.4, type=float)
    parser.add_argument("--beta_s", default=0.01, type=float)
    parser.add_argument("--actor_lr", default=1e-4, type=float)
    parser.add_argument("--critic_lr", default=1e-4, type=float)
    parser.add_argument("--actor_wd", default=0., type=float)
    parser.add_argument("--critic_wd", default=0., type=float)
    parser.add_argument("--actor_adam_eps", default=1e-7, type=float)
    parser.add_argument("--critic_adam_eps", default=1e-7, type=float)
    parser.add_argument("--pad_value", default=1, type=float)  # token pad value
    parser.add_argument("--use_lion", default=False, type=bool)
    parser.add_argument("--num_episodes", default=50000, type=int)
    parser.add_argument("--max_timesteps", default=500, type=int)
    parser.add_argument("--update_timesteps", default=5000, type=int)

    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    ########################################################################################################

    import os, warnings, math, datetime, sys, time
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    import deepspeed
    import pytorch_lightning as pl
    from pytorch_lightning import seed_everything

    if args.random_seed >= 0:
        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = -1  # continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
    os.environ["RWKV_MY_TESTING"] = args.my_testing

    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        args.dim_ffn = args.n_embd * 4

    args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
    if not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
    if args.precision == "fp16":
        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

    os.environ["RWKV_JIT_ON"] = "1"
    if "deepspeed_stage_3" in args.strategy:
        os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from tqdm import tqdm
    from collections import deque, namedtuple
    from einops import rearrange

    from src.dataset import PPODataset, load_prompt_data_4_ppo
    from src.rlhf.ppo import RLHF
    from src.trainer import rlhf_train_callback

    # data for PPO training, collected by interacting with the environment
    memory = []

    # load the prompt dataset
    prompts = load_prompt_data_4_ppo(args)

    # PPO model
    rlhf_model = RLHF(args)

    # model training
    # trainer
    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[rlhf_train_callback(args)],
    )

    time_cnt = 0
    for eps in tqdm(range(args.num_episodes), desc='episodes'):
        for timestep in range(args.max_timesteps):
            time_cnt += 1

            # generate training data for the PPO model
            experience_data = rlhf_model.make_experience(prompts, eos_token=0)
            memory.append(experience_data)

            # learn from the stored memories
            if time_cnt % args.update_timesteps == 0:
                if trainer.global_rank == 0:
                    for n in rlhf_model.state_dict():
                        shape = rlhf_model.state_dict()[n].shape
                        shape = [i for i in shape if i != 1]
                        if len(shape) > 1:
                            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
                        else:
                            print(f"{str(shape[0]).ljust(5)} {n}")

                train_data = PPODataset(memory)
                data_loader = DataLoader(
                    train_data,
                    shuffle=False,
                    pin_memory=True,
                    batch_size=args.micro_bsz,
                    num_workers=1,
                    persistent_workers=False,
                    drop_last=True
                )

                trainer.fit(rlhf_model, data_loader)

    print('rlhf training complete')
```
train_rm.py
@@ -224,8 +224,8 @@ if __name__ == "__main__":

A one-line change to the import block; the imports after the change read:

```python
    import torch
    from tqdm import tqdm

    from src.trainer import rm_train_callback
    from src.rlhf.reward import RewardModel
    from src.dataset import RMDataset
```