CSDN 技术社区 / ai / chatCSDN

Commit be4440da
Authored Mar 22, 2023 by u010280923

transfer ppo code to pytorch_lightning style
Parent: 2164e3e5

Showing 4 changed files with 221 additions and 166 deletions (+221 -166)
src/dataset.py     +11   -21
src/model.py       +3    -4
src/rlhf/ppo.py    +205  -122
train_ppo.py       +2    -19
src/dataset.py
@@ -12,6 +12,9 @@ from src.utils import TOKENIZER
 from .binidx import MMapIndexedDataset
 from .utils import MaybeIsPrime
+from typing import Iterable, Callable
+from torch.utils.data import IterableDataset

 class MyDataset(Dataset):
     def __init__(self, args):

@@ -326,25 +329,14 @@ class RMDataset(Dataset):
         return x_p, x_a, m_p, m_a

-class PPODataset(Dataset):
-    def __init__(self, memory):
-        self.data = memory
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, index):
-        # todo(luxin) is padding needed here ???
-        sequence, \
-        prompt_mask, \
-        mask, \
-        action_prob, \
-        action_log_prob, \
-        reward, \
-        value = self.data[index]
-
-        return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value
+class ExperienceDataset(IterableDataset):
+    def __init__(self, generate_batch: Callable):
+        super().__init__()
+        self.generate_batch = generate_batch
+
+    def __iter__(self) -> Iterable:
+        iterator = self.generate_batch()
+        return iterator

 def load_prompt_data_4_ppo(args):

@@ -356,14 +348,12 @@ def load_prompt_data_4_ppo(args):
     ]  # [vocab, vocab] for Pile model
     tokenizer = TOKENIZER(WORD_NAME)

-    ctx_len = args.ctx_len
-    req_len = ctx_len
-
     pf = pd.read_csv(args.data_file)
     for index, row in pf.iterrows():
         prompt = row["prompt"]
         prompt_idx = tokenizer.tokenizer.encode(prompt)
-        prompt_idx = prompt_idx[:req_len]
+        prompt_idx = prompt_idx[:args.ctx_len]
         prompt_token_ids.append(
             torch.tensor(prompt_idx, dtype=torch.long))
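Replacing the map-style PPODataset with an IterableDataset is what lets experience collection be streamed straight into a Lightning DataLoader: the dataset simply defers to a callable that returns a generator. A minimal sketch of that pattern, with a hypothetical fake_experience generator standing in for the real gen_experience_dataset defined later in src/rlhf/ppo.py:

from typing import Callable, Iterable

import torch
from torch.utils.data import DataLoader, IterableDataset

class ExperienceDataset(IterableDataset):
    # same shape as the class added above: wrap a callable that yields experience tuples
    def __init__(self, generate_batch: Callable):
        super().__init__()
        self.generate_batch = generate_batch

    def __iter__(self) -> Iterable:
        return self.generate_batch()

def fake_experience():
    # hypothetical stand-in for RLHF.gen_experience_dataset
    for _ in range(8):
        yield torch.randn(16), torch.randn(())

loader = DataLoader(ExperienceDataset(fake_experience), batch_size=4)
for sequence, reward in loader:
    print(sequence.shape, reward.shape)  # torch.Size([4, 16]) torch.Size([4])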
src/model.py
@@ -508,7 +508,6 @@ class RWKV(pl.LightningModule):
         filter_logits_fn = top_k,
         filter_thres = 0.9,
         pad_value = 0.,
-        eos_token = None,
         return_seq_without_prompt = True
     ):
         ''' Generate a response, used for training the PPO model

@@ -521,7 +520,7 @@ class RWKV(pl.LightningModule):
         sample_num_times = max(1, seq_len - prompt.shape[-1])

         for _ in tqdm(range(sample_num_times), desc="gen responses"):
-            pad_idx = torch.tensor([[eos_token] * (self.args.ctx_len - out.shape[-1])])
+            pad_idx = torch.tensor([[self.args.eos_token] * (self.args.ctx_len - out.shape[-1])])
             query_idx = torch.cat((out, pad_idx), dim=-1)
             logits, embeds = self.forward(query_idx, ppo_train=True)
             logits, embeds = logits[:, -1], embeds[:, -1]

@@ -532,8 +531,8 @@ class RWKV(pl.LightningModule):
             sample = gumbel_sample(logits, temperature = temperature, dim = -1)
             out, _ = pack([out, sample], 'b *')

-            if exists(eos_token):
-                is_eos_tokens = (out == eos_token)
+            if exists(self.args.eos_token):
+                is_eos_tokens = (out == self.args.eos_token)

                 if is_eos_tokens.any(dim = -1).all():
                     # mask out everything after the eos tokens
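The generation loop above keeps drawing the next token with gumbel_sample on the filtered logits. The repo's own helper is not part of this diff; what such a Gumbel-max sampler typically looks like is sketched below (an assumption about its shape, not the project's exact implementation):

import torch

def gumbel_sample(logits, temperature=1.0, dim=-1, eps=1e-10):
    # Gumbel-max trick: argmax(logits / T + Gumbel noise) is a sample from
    # Categorical(softmax(logits / T)), so no explicit softmax or multinomial is needed
    uniform = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)
    return ((logits / max(temperature, eps)) + gumbel_noise).argmax(dim=dim)

logits = torch.randn(2, 1000)        # batch of 2 over a toy vocabulary
next_token = gumbel_sample(logits, temperature=1.0)
print(next_token.shape)              # torch.Size([2])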
src/rlhf/ppo.py
@@ -29,6 +29,8 @@ from src.model import RWKV
 from src.rlhf.reward import RewardModel
 from src.rlhf.optimizer import get_optimizer
 from src.rlhf.utils import masked_mean, eval_decorator
+from src.dataset import load_prompt_data_4_ppo
+from src.dataset import ExperienceDataset

 # actor critic
@@ -52,12 +54,13 @@ class ActorCritic(nn.Module):
     ):
         super().__init__()

+        self.args = args
         self.actor = actor
         self.critic = critic

         self.pooled_values = pooled_values
         self.value_head = nn.Sequential(
-            nn.Linear(args.n_embd, 1),
+            nn.Linear(self.args.n_embd, 1),
             Rearrange('... 1 -> ...')
         )
@@ -70,14 +73,12 @@ class ActorCritic(nn.Module):
         self,
         state,
         max_seq_len,
-        eos_token = None,
         return_values = False
     ):
         # generate one response, which amounts to taking one action
         actions = self.actor.generate(
             max_seq_len,
-            prompt = state,
-            eos_token = eos_token
+            prompt = state
         )

         # concatenate the prompt (state) and the response (action)
@@ -93,8 +94,8 @@ class ActorCritic(nn.Module):
         # handle the eos token
         mask = None
-        if exists(eos_token):
-            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
+        if exists(self.args.eos_token):
+            mask = ((sequence == self.args.eos_token).cumsum(dim = -1) == 0)
             mask = F.pad(mask, (1, -1), value = True)  # include eos token
             action_mask &= mask
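The cumsum trick above is compact but easy to misread: (sequence == eos_token).cumsum(dim=-1) == 0 is True strictly before the first eos, and F.pad(mask, (1, -1), value=True) shifts that mask right by one position so the eos itself still counts as part of the action. A small self-contained check (eos_token = 0 is just an assumed value, matching the default flag added to train_ppo.py below):

import torch
import torch.nn.functional as F

eos_token = 0
sequence = torch.tensor([[5, 7, 0, 3, 4],
                         [2, 9, 8, 0, 1]])

mask = (sequence == eos_token).cumsum(dim=-1) == 0   # True before the first eos
mask = F.pad(mask, (1, -1), value=True)              # shift right: keep the eos position itself
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False]])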
@@ -143,27 +144,6 @@ class ActorCritic(nn.Module):
         return action_logits, values

-@beartype
-class ExperienceDataset(Dataset):
-    def __init__(
-        self,
-        data: List[torch.Tensor],
-        device = None
-    ):
-        super().__init__()
-        self.data = data
-        self.device = device
-
-    def __len__(self):
-        return self.data[0].shape[0]
-
-    def __getitem__(self, ind):
-        return tuple(map(lambda t: t[ind].to(self.device), self.data))
-
-def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
-    ds = ExperienceDataset(data, device = device)
-    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)

 # helper functions

 def exists(val):
@@ -230,7 +210,6 @@ def clipped_value_loss(values, rewards, old_values, clip):
     return torch.mean(torch.max(value_loss_1, value_loss_2))

 # rlhf
-
 @beartype
 class RLHF(pl.LightningModule):
     def __init__(
@@ -244,6 +223,18 @@ class RLHF(pl.LightningModule):

         self.args = args

+        # load the prompts data
+        self.prompts = load_prompt_data_4_ppo(args)
+
+        # buffers for the interaction data with the environment, used to train the actor_critic (agent)
+        self.sequence_batch = []
+        self.prompt_mask_batch = []
+        self.mask_batch = []
+        self.action_prob_batch = []
+        self.action_log_prob_batch = []
+        self.reward_batch = []
+        self.value_batch = []
+
         # initialize the actor_critic with RWKV
         actor_critic = ActorCritic(
             args = self.args,
@@ -266,50 +257,105 @@ class RLHF(pl.LightningModule):
     def configure_optimizers(self):
         args = self.args
+
+        optim_groups_actor = []
+        optim_groups_critic = []
+
         if args.layerwise_lr > 0:
-            lr_1x = set()
-            lr_2x = set()
-            lr_3x = set()
+            lr_1x_actor = set()
+            lr_2x_actor = set()
+            lr_3x_actor = set()
+            lr_1x_critic = set()
+            lr_2x_critic = set()
+            lr_3x_critic = set()
             for n, p in self.named_parameters():
                 if "time_mix" in n:
                     if args.my_pile_stage == 2:
-                        lr_2x.add(n)
+                        if "actor" in n:
+                            lr_2x_actor.add(n)
+                        elif "critic" in n:
+                            lr_2x_critic.add(n)
                     else:
-                        lr_1x.add(n)
+                        if "actor" in n:
+                            lr_1x_actor.add(n)
+                        elif "critic" in n:
+                            lr_1x_critic.add(n)
                 elif "time_decay" in n:
                     if args.my_pile_stage == 2:
-                        lr_3x.add(n)
+                        if "actor" in n:
+                            lr_3x_actor.add(n)
+                        elif "critic" in n:
+                            lr_3x_critic.add(n)
                     else:
-                        lr_2x.add(n)
+                        if "actor" in n:
+                            lr_2x_actor.add(n)
+                        elif "critic" in n:
+                            lr_2x_critic.add(n)
                 elif "time_first" in n:
-                    lr_3x.add(n)
+                    if "actor" in n:
+                        lr_3x_actor.add(n)
+                    elif "critic" in n:
+                        lr_3x_critic.add(n)
                 else:
-                    lr_1x.add(n)
-            lr_1x = sorted(list(lr_1x))
-            lr_2x = sorted(list(lr_2x))
-            lr_3x = sorted(list(lr_3x))
+                    if "actor" in n:
+                        lr_1x_actor.add(n)
+                    elif "critic" in n:
+                        lr_1x_critic.add(n)
+            lr_1x_actor = sorted(list(lr_1x_actor))
+            lr_2x_actor = sorted(list(lr_2x_actor))
+            lr_3x_actor = sorted(list(lr_3x_actor))
+            lr_1x_critic = sorted(list(lr_1x_critic))
+            lr_2x_critic = sorted(list(lr_2x_critic))
+            lr_3x_critic = sorted(list(lr_3x_critic))
             param_dict = {n: p for n, p in self.named_parameters()}
             if args.my_pile_stage == 2:
-                optim_groups = [
-                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
-                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
-                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
+                optim_groups_actor = [
+                    {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
+                    {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
+                ]
+                optim_groups_critic = [
+                    {"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
+                    {"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
                 ]
             else:
-                optim_groups = [
-                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
-                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
-                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
+                optim_groups_actor = [
+                    {"params": [param_dict[n] for n in lr_1x_actor], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x_actor], "weight_decay": 0.0, "my_lr_scale": 2.0},
+                    {"params": [param_dict[n] for n in lr_3x_actor], "weight_decay": 0.0, "my_lr_scale": 3.0},
+                ]
+                optim_groups_critic = [
+                    {"params": [param_dict[n] for n in lr_1x_critic], "weight_decay": 0.0, "my_lr_scale": 1.0},
+                    {"params": [param_dict[n] for n in lr_2x_critic], "weight_decay": 0.0, "my_lr_scale": 2.0},
+                    {"params": [param_dict[n] for n in lr_3x_critic], "weight_decay": 0.0, "my_lr_scale": 3.0},
                 ]
         else:
-            optim_groups = [
-                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
+            optim_groups_actor = [
+                {"params": [p for n, p in self.named_parameters() if "actor" in n], "weight_decay": 0.0},
+            ]
+            optim_groups_critic = [
+                {"params": [p for n, p in self.named_parameters() if "critic" in n], "weight_decay": 0.0},
             ]

         if self.deepspeed_offload:
-            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
-        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
-        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
+            actor_optimizer = DeepSpeedCPUAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
+            critic_optimizer = DeepSpeedCPUAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
+
+            return actor_optimizer, critic_optimizer
+
+        actor_optimizer = FusedAdam(optim_groups_actor, lr=self.args.actor_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+        critic_optimizer = FusedAdam(optim_groups_critic, lr=self.args.critic_lr, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
+
+        return actor_optimizer, critic_optimizer

     @property
     def deepspeed_offload(self) -> bool:
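configure_optimizers now returns one optimizer over the actor parameters and one over the critic parameters instead of a single FusedAdam over everything. In PyTorch Lightning 1.x, which is the API this code targets (it still uses Trainer.add_argparse_args), returning two optimizers makes the trainer call training_step once per optimizer and pass an optimizer_idx, which is exactly what the training_step change below relies on. A toy, self-contained sketch of that pattern (the tiny actor/critic modules and losses are purely illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class TwoOptimizerModule(pl.LightningModule):
    # minimal illustration of the Lightning 1.x multi-optimizer pattern adopted above
    def __init__(self):
        super().__init__()
        self.actor = nn.Linear(4, 2)
        self.critic = nn.Linear(4, 1)

    def configure_optimizers(self):
        # two optimizers -> training_step is called with optimizer_idx (0, then 1) for each batch
        actor_opt = torch.optim.Adam(self.actor.parameters(), lr=1e-4)
        critic_opt = torch.optim.Adam(self.critic.parameters(), lr=1e-4)
        return actor_opt, critic_opt

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, target, value_target = batch
        if optimizer_idx == 0:                                   # actor update
            return F.cross_entropy(self.actor(x), target)
        if optimizer_idx == 1:                                   # critic update
            return F.mse_loss(self.critic(x).squeeze(-1), value_target)

Under automatic optimization, Lightning then steps each optimizer on every batch; the RLHF module above uses the same mechanism, just with its layerwise-lr optim groups and separate actor_lr / critic_lr.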
@@ -360,7 +406,7 @@ class RLHF(pl.LightningModule):

         return best_sequence

-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, batch_idx, optimizer_idx):
         sequences, \
         prompt_masks, \
         masks, \
@@ -423,27 +469,34 @@ class RLHF(pl.LightningModule):
         policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

         # actor loss (also called the policy loss; this is the loss of the model that will ultimately be used)
+        if optimizer_idx == 0:
             actor_loss = policy_loss.mean() + kl_div_loss
+            return actor_loss

         # critic loss (also called the value loss)
         # update value network separate from policy network
+        if optimizer_idx == 1:
             critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
             critic_loss = critic_loss.mean()
+            return critic_loss

-        return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}

-    def make_experience(self, prompts, eos_token = None, temperature = 1):
+    def gen_experience_dataset(self):
         ''' Generate training data by interacting with the environment
         '''

         device = self.device

+        time_cnt = 0
+        for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'):
+            for timestep in range(self.args.max_timesteps):
+                time_cnt += 1
+
                 # select a bunch of random states (prompts)
                 # and get the action (sampled sequence from rwkv as well as the action probs)
                 # also calculate the reward using reward model and store

                 # randomly pick one prompt
-                rand_prompt_index = randrange(0, len(prompts))
-                state = prompts[rand_prompt_index]
+                rand_prompt_index = randrange(0, len(self.prompts))
+                state = self.prompts[rand_prompt_index]

                 # remove padding from state
                 state_mask = state != self.args.pad_value
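For reference, the two training_step branches compute the standard PPO objectives that the surrounding names (surr1/surr2, clipped_value_loss) refer to. A compact sketch of both losses under assumed names (advantages, eps_clip and value_clip are illustrative defaults; the entropy bonus and KL term present in the actual code are omitted here):

import torch

def ppo_losses(action_log_probs, old_log_probs, advantages,
               values, old_values, rewards,
               eps_clip=0.2, value_clip=0.4):
    # actor: clipped surrogate objective, policy_loss = -min(surr1, surr2)
    ratios = (action_log_probs - old_log_probs).exp()
    surr1 = ratios * advantages
    surr2 = ratios.clamp(1 - eps_clip, 1 + eps_clip) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    # critic: clipped value loss, mirroring clipped_value_loss(values, rewards, old_values, clip)
    value_clipped = old_values + (values - old_values).clamp(-value_clip, value_clip)
    value_loss_1 = (value_clipped - rewards) ** 2
    value_loss_2 = (values - rewards) ** 2
    critic_loss = torch.mean(torch.max(value_loss_1, value_loss_2))

    return policy_loss, critic_loss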
@@ -463,7 +516,6 @@ class RLHF(pl.LightningModule):
                 ) = self.actor_critic.generate(
                     rearrange(state, 'n -> 1 n'),
                     max_seq_len = self.args.ctx_len,
-                    eos_token = eos_token,
                     return_values = True
                 )
                 action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
@@ -493,12 +545,43 @@ class RLHF(pl.LightningModule):
                     sample = True
                 )

-        return (
-            sequence,
-            prompt_mask,
-            mask,
-            action_prob,
-            action_log_prob,
-            reward,
-            value
-        )
+                self.sequence_batch.append(sequence)
+                self.prompt_mask_batch.append(prompt_mask)
+                self.mask_batch.append(mask)
+                self.action_prob_batch.append(action_prob)
+                self.action_log_prob_batch.append(action_log_prob)
+                self.reward_batch.append(reward)
+                self.value_batch.append(value)
+
+                if time_cnt % self.args.update_timesteps == 0:
+                    train_data = zip(
+                        self.sequence_batch,
+                        self.prompt_mask_batch,
+                        self.mask_batch,
+                        self.action_prob_batch,
+                        self.action_log_prob_batch,
+                        self.reward_batch,
+                        self.value_batch
+                    )
+
+                    for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data:
+                        yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value
+
+                    self.sequence_batch.clear()
+                    self.prompt_mask_batch.clear()
+                    self.mask_batch.clear()
+                    self.action_prob_batch.clear()
+                    self.action_log_prob_batch.clear()
+                    self.reward_batch.clear()
+                    self.value_batch.clear()
+
+    def _dataloader(self) -> DataLoader:
+        ''' Initialize the Replay Buffer dataset used for retrieving experiences '''
+        dataset = ExperienceDataset(self.gen_experience_dataset)
+        dataloader = DataLoader(dataset = dataset, batch_size = self.args.micro_bsz)
+        return dataloader
+
+    def train_dataloader(self) -> DataLoader:
+        ''' Get train loader '''
+        return self._dataloader()
train_ppo.py
@@ -138,6 +138,7 @@ if __name__ == "__main__":
    parser.add_argument("--num_episodes", default=50000, type=int)
    parser.add_argument("--max_timesteps", default=500, type=int)
    parser.add_argument("--update_timesteps", default=5000, type=int)
+    parser.add_argument("--eos_token", default=0, type=int)

    parser = Trainer.add_argparse_args(parser)

@@ -249,7 +250,6 @@ if __name__ == "__main__":
    from collections import deque, namedtuple
    from einops import rearrange

-    from src.dataset import PPODataset, load_prompt_data_4_ppo
    from src.rlhf.ppo import RLHF
    from src.trainer import rlhf_train_callback
    from src.model import RWKV

@@ -258,9 +258,6 @@ if __name__ == "__main__":
    # data for PPO training, obtained by interacting with the environment
    memory = []

-    # load the training dataset
-    prompts = load_prompt_data_4_ppo(args)
-
    # initialize the actor model with rwkv
    actor = RWKV(args)
    actor.load(args.load_sft_model)

@@ -298,21 +295,7 @@ if __name__ == "__main__":
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

-    time_cnt = 0
-    for eps in tqdm(range(args.num_episodes), desc='episodes'):
-        for timestep in range(args.max_timesteps):
-            time_cnt += 1
-
-            # generate training data for the ppo model
-            experience_data = rlhf_model.make_experience(prompts, eos_token=0)
-            memory.append(experience_data)
-
-        # learn from the stored memories
-        if time_cnt % args.update_timesteps == 0:
-            train_data = PPODataset(memory)
-            data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
-            trainer.fit(rlhf_model, data_loader)
+    trainer.fit(rlhf_model)

    print('rlhf training complete')
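With the episode/timestep loop moved into RLHF.gen_experience_dataset, the script's training section reduces to a single trainer.fit call, and the eos token becomes a plain command-line argument. A condensed sketch of the resulting entry point (model construction elided, and the flag names taken directly from the diff above):

import argparse
from pytorch_lightning import Trainer

parser = argparse.ArgumentParser()
parser.add_argument("--num_episodes", default=50000, type=int)
parser.add_argument("--max_timesteps", default=500, type=int)
parser.add_argument("--update_timesteps", default=5000, type=int)
parser.add_argument("--eos_token", default=0, type=int)   # new flag introduced by this commit
parser = Trainer.add_argparse_args(parser)                 # Lightning 1.x argparse integration
args = parser.parse_args([])

# rlhf_model = RLHF(args, ...)              # actor/critic construction elided
# trainer = Trainer.from_argparse_args(args)
# trainer.fit(rlhf_model)                   # replaces the old manual episode loop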