chatCSDN (CSDN 技术社区 / ai / chatCSDN)

Commit b7f231a9 — opt ppo model
Authored March 20, 2023 by u010280923
Parent: 0cca2efe

Showing 1 changed file with 0 additions and 623 deletions.

src/rlhf/ppo_old.py    deleted (100644 → 0)    +0 −623
import math
from pathlib import Path
import copy
from tqdm import tqdm
from functools import partial
from collections import deque, namedtuple
from random import randrange

from beartype import beartype
from beartype.typing import List, Optional, Callable, Deque

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from pytorch_lightning.utilities import rank_zero_info

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from src.model import RWKV
from src.rlhf.reward import RewardModel
from src.rlhf.optimizer import get_optimizer
from src.rlhf.utils import masked_mean, eval_decorator

from accelerate import Accelerator
# actor critic - PaLM with lora

PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
    'actions',
    'sequence',
    'mask',
    'prompt_mask',
    'action_logits',
    'values'
])

@beartype
class ActorCritic(nn.Module):
    def __init__(
        self,
        rwkv: RWKV,
        critic_palm: Optional[RWKV] = None,
        pooled_values = False,
        actor_lora = True,
        critic_lora = True,
        actor_lora_r = 8,
        critic_lora_r = 8,
        actor_lora_scope = 'actor',
        critic_lora_scope = 'critic',
        actor_dropout = 0.,
        critic_dropout = 0.
    ):
        super().__init__()
        self.actor_palm = rwkv

        self.critic_palm = critic_palm

        if not exists(self.critic_palm):
            self.critic_palm = copy.deepcopy(rwkv)

        self.actor_palm.set_dropout(actor_dropout)
        self.critic_palm.set_dropout(critic_dropout)

        self.actor_lora = actor_lora
        self.critic_lora = critic_lora

        self.actor_lora_scope = actor_lora_scope if actor_lora else None
        self.critic_lora_scope = critic_lora_scope if critic_lora else None

        if self.actor_lora:
            self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)

        if self.critic_lora:
            self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)

        self.pooled_values = pooled_values
        self.value_head = nn.Sequential(
            nn.Linear(rwkv.dim, 1),
            Rearrange('... 1 -> ...')
        )

        nn.init.zeros_(self.value_head[0].bias)
        nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2))
    def actor_parameters(self):
        if not self.actor_lora:
            return self.actor_palm.parameters()

        return [
            *self.actor_palm.finetune_parameters(self.actor_lora_scope)
        ]

    def critic_parameters(self):
        # if the critic is not using lora, train all critic parameters plus the value head
        if not self.critic_lora:
            return [*self.critic_palm.parameters(), *self.value_head.parameters()]

        return [
            *self.critic_palm.finetune_parameters(self.critic_lora_scope),
            *self.value_head.parameters()
        ]
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        state,
        max_seq_len,
        eos_token = None,
        return_values = False,
        **kwargs
    ):
        actions = self.actor_palm.generate(
            max_seq_len,
            prompt = state,
            eos_token = eos_token,
            finetune_scope = self.actor_lora_scope,
            use_tqdm = True,
            **kwargs
        )

        sequence = torch.cat((state, actions), dim = -1)
        action_len = actions.shape[-1]
        state_len = state.shape[-1]

        prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
        prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])

        action_mask = ~prompt_mask

        mask = None
        if exists(eos_token):
            mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
            mask = F.pad(mask, (1, -1), value = True) # include eos token
            action_mask &= mask

        action_logits, value = self.forward(
            sequence,
            mask = action_mask,
            return_values = return_values
        )

        return PPOActionCriticReturn(
            actions,
            sequence,
            mask,
            prompt_mask,
            action_logits,
            value
        )

    def forward(
        self,
        x,
        mask = None,
        return_values = True
    ):
        action_logits = self.actor_palm(
            x,
            finetune_scope = self.actor_lora_scope
        )

        if not return_values:
            return action_logits, None

        critic_embeds = self.critic_palm(
            x,
            return_only_embedding = True,
            finetune_scope = self.critic_lora_scope
        )

        if self.pooled_values:
            critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
            critic_embeds = masked_mean(critic_embeds, mask, dim = 1)

        values = self.value_head(critic_embeds)

        return action_logits, values
# data

Memory = namedtuple('Memory', [
    'sequence',
    'prompt_mask',
    'mask',
    'action_prob',
    'action_log_prob',
    'reward',
    'value'
])

@beartype
class ExperienceDataset(Dataset):
    def __init__(
        self,
        data: List[torch.Tensor],
        device = None
    ):
        super().__init__()
        self.data = data
        self.device = device

    def __len__(self):
        return self.data[0].shape[0]

    def __getitem__(self, ind):
        return tuple(map(lambda t: t[ind].to(self.device), self.data))

def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
    ds = ExperienceDataset(data, device = device)
    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
    dim = default(dim, tuple(range(t.ndim)))
    kwargs = dict(dim = dim, keepdim = True)

    mean = masked_mean(t, mask = mask, **kwargs)
    mean_centered = t - mean
    var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)

    return mean_centered * var.clamp(min = eps).rsqrt()

def pad_sequence_fixed(sequences, *args, **kwargs):
    first_el = sequences[0]
    has_no_dimension = first_el.ndim == 0

    # if no dimensions, add a single dimension
    if has_no_dimension:
        sequences = tuple(map(lambda t: t[None], sequences))

    out = pad_sequence(sequences, *args, **kwargs)

    if has_no_dimension:
        out = rearrange(out, '... 1 -> ...')

    return out

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def log_prob(prob, indices):
    assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
    return log(prob.gather(-1, indices[..., None])).squeeze(-1)

def shift(t, value = 0, shift = 1, dim = -1):
    zeros = (0, 0) * (-dim - 1)
    return F.pad(t, (*zeros, shift, -shift), value = value)

def masked_entropy(prob, dim = -1, mask = None):
    entropies = (prob * log(prob)).sum(dim = -1)
    return masked_mean(entropies, mask = mask).mean()

def masked_kl_div(prob1, prob2, mask = None):
    """
    need to account for variable sequence lengths, therefore not using the built-in functional version
    """
    kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1)

    if not exists(mask):
        return kl_divs.mean()

    return masked_mean(kl_divs, mask).mean()

def clipped_value_loss(values, rewards, old_values, clip):
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped.flatten() - rewards) ** 2
    value_loss_2 = (values.flatten() - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))
# rlhf trainer

@beartype
class RLHFTrainer(nn.Module):
    def __init__(
        self,
        args,
        accelerate_kwargs: dict = {}
    ):
        super().__init__()

        self.args = args
        self.accelerate = Accelerator(**accelerate_kwargs)

        # load the RWKV model
        rwkv = RWKV(args)

        if len(args.load_sft_model) == 0:
            rank_zero_info("An SFT model must be loaded; please set load_sft_model")
            exit(1)

        rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
        try:
            load_dict = torch.load(args.load_sft_model, map_location = "cpu")
        except:
            rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
            exit(1)

        if args.load_partial == 1:
            load_keys = load_dict.keys()
            for k in rwkv.state_dict():
                if k not in load_keys:
                    load_dict[k] = rwkv.state_dict()[k]
        rwkv.load_state_dict(load_dict)

        self.rwkv = rwkv

        # initialize the actor-critic from the RWKV model
        actor_critic = ActorCritic(
            rwkv = self.rwkv,
            actor_lora = args.actor_lora,
            critic_lora = args.critic_lora,
            actor_lora_r = args.actor_lora_r,
            critic_lora_r = args.critic_lora_r,
            pooled_values = args.critic_pooled_values,
            actor_dropout = args.actor_dropout,
            critic_dropout = args.critic_dropout
        ).to(self.rwkv.device)

        self.actor_critic = actor_critic

        # load the reward model and put it in evaluation mode
        reward_model = RewardModel(args)
        reward_model.load(args.load_rm_model)
        self.reward_model = reward_model.eval()

        # optimizers
        self.actor_optim = get_optimizer(
            actor_critic.actor_parameters(),
            lr = self.args.actor_lr,
            wd = self.args.actor_wd,
            betas = self.args.betas,
            eps = self.args.actor_adam_eps,
            use_lion = self.args.use_lion
        )
        self.critic_optim = get_optimizer(
            actor_critic.critic_parameters(),
            lr = self.args.critic_lr,
            wd = self.args.critic_wd,
            betas = self.args.betas,
            eps = self.args.critic_adam_eps,
            use_lion = self.args.use_lion
        )

        # prepare with accelerator
        (
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        ) = self.accelerate.prepare(
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        )
    def print(self, msg):
        return self.accelerate.print(msg)

    def save(self, filepath = './checkpoint.pt'):
        torch.save(self.actor_critic.state_dict(), filepath)

    def load(self, filepath = './checkpoint.pt'):
        state_dict = torch.load(filepath)
        self.actor_critic.load_state_dict(state_dict)

    @property
    def device(self):
        return self.accelerate.device

    @torch.no_grad()
    def generate(
        self,
        max_seq_len,
        *args,
        prompt,
        num_samples = 4,  # sample 4 per prompt and select the one with highest reward
        **kwargs
    ):
        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        prompt = repeat(prompt, 'n -> b n', b = num_samples)

        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
        reward_model = self.accelerate.unwrap_model(self.reward_model)

        actor_critic.eval()

        (
            actions,
            sequences,
            mask,
            prompt_mask,
            action_logits,
            _
        ) = actor_critic.generate(
            prompt,
            *args,
            max_seq_len = max_seq_len,
            return_values = False,
            **kwargs
        )

        rewards = reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )

        best_sequence_index = rewards.topk(1, dim = -1).indices

        best_sequence = sequences[best_sequence_index]
        best_sequence = rearrange(best_sequence, '1 ... -> ...')

        return best_sequence
    def learn(
        self,
        memories: Deque[Memory]
    ):
        # stack all data stored in the memories
        all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))

        # prepare dataloader for policy phase training
        dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)

        self.actor_critic.train()

        # PPO training
        for _ in range(self.epochs):
            for (
                sequences,
                prompt_masks,
                masks,
                old_action_probs,
                old_log_probs,
                rewards,
                old_values
            ) in dl:
                action_masks = ~prompt_masks & masks

                action_logits, values = self.actor_critic(
                    sequences,
                    mask = action_masks
                )

                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_len = old_log_probs.shape[-1]

                action_probs = action_logits.softmax(dim = -1)
                action_log_probs = log_prob(action_probs, sequences)
                action_log_probs = action_log_probs[:, -action_len:]

                # calculate entropies, taking into account which part of the sequence is actually an action
                entropies = masked_entropy(action_probs, mask = action_masks)

                # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
                kl_div_loss = 0.

                if self.args.kl_div_loss_weight > 0:
                    kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

                # handle non-pooled values
                normalize_kwargs = dict()

                if old_values.ndim == 2:
                    old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

                    old_values = old_values[:, -action_len:]
                    values = values[:, -action_len:]
                    rewards = rearrange(rewards, 'b -> b 1')
                    normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

                if values.ndim < rewards.ndim:
                    values = rearrange(values, '... -> ... 1')

                # calculate clipped surrogate objective, classic PPO loss
                ratios = (action_log_probs - old_log_probs).exp()
                advantages = masked_normalize(rewards - old_values, **normalize_kwargs)

                if advantages.ndim == 1:
                    advantages = rearrange(advantages, 'b -> b 1')

                surr1 = ratios * advantages
                surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

                # combine losses
                loss = policy_loss.mean() + kl_div_loss

                # update actor
                self.accelerate.backward(loss)

                self.print(f'policy_loss: {loss.item():.3f}')

                if exists(self.args.max_norm):
                    self.accelerate.clip_grad_norm_(self.actor_critic.actor_parameters(), self.args.max_norm)

                self.actor_optim.step()
                self.actor_optim.zero_grad()

                # calculate value loss and update value network separate from policy network
                value_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
                value_loss = value_loss.mean()

                self.print(f'critic_loss: {value_loss.item():.3f}')

                self.accelerate.backward(value_loss)

                if exists(self.args.max_norm):
                    self.accelerate.clip_grad_norm_(self.actor_critic.critic_parameters(), self.args.max_norm)

                self.critic_optim.step()
                self.critic_optim.zero_grad()
    def train(
        self,
        num_episodes = 50000,
        max_timesteps = 500,
        update_timesteps = 5000,
        max_batch_size = 16,
        eos_token = None,
        temperature = 1.
    ):
        device = self.device

        time = 0
        memories = deque([])

        for eps in tqdm(range(num_episodes), desc = 'episodes'):
            for timestep in range(max_timesteps):
                time += 1

                # select a bunch of random states (prompts)
                # and get the action (sampled sequence from palm as well as the action probs)
                # also calculate the reward using reward model and store

                rand_prompt_index = randrange(0, self.num_prompts)
                state = self.prompt_token_ids[rand_prompt_index]

                # remove padding from state
                state_mask = state != self.args.pad_value
                state = state[state_mask]

                # get predicted sequence
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len = self.args.ctx_len,
                    eos_token = eos_token,
                    temperature = temperature,
                    return_values = True
                )

                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_prob = action_logits.softmax(dim = -1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # get reward as given by supervised trained reward model
                sequence = torch.cat((state, actions), dim = 0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)

                reward = self.reward_model(
                    sequence,
                    prompt_mask = prompt_mask,
                    mask = mask,
                    sample = True
                )

                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')

                # store memory for learning
                memories.append(Memory(*map(detach_to_cpu_, (
                    sequence,
                    prompt_mask,
                    mask,
                    action_prob,
                    action_log_prob,
                    reward,
                    value
                ))))

                # learn from the stored memories
                if time % update_timesteps == 0:
                    self.learn(memories)
                    memories.clear()

        print('rlhf training complete')
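
For reference, a minimal usage sketch of the trainer deleted above. It assumes the replacement module src/rlhf/ppo.py keeps the same RLHFTrainer interface, that `args` already carries every field the trainer reads (load_sft_model, load_rm_model, load_partial, ctx_len, pad_value, actor_lr, critic_lr, ...), and that prompt tokens and the minibatch/epoch settings are attached to the trainer before training, as train() and learn() expect. All concrete names and values here are hypothetical placeholders, not part of this commit.

# minimal usage sketch (assumptions noted above)
import torch
from src.rlhf.ppo import RLHFTrainer   # assumed successor of the deleted ppo_old.py

def run_rlhf(args, prompt_token_ids: torch.Tensor):
    trainer = RLHFTrainer(args)

    # train() samples prompts from these attributes, and learn() reads the
    # minibatch/epoch settings from the trainer, so attach them up front
    trainer.prompt_token_ids = prompt_token_ids
    trainer.num_prompts = prompt_token_ids.shape[0]
    trainer.minibatch_size = 2
    trainer.epochs = 1

    trainer.train(
        num_episodes = 10,        # tiny values for a smoke test
        max_timesteps = 4,
        update_timesteps = 8,
        eos_token = 0,
        temperature = 1.0
    )

    trainer.save('./rlhf_checkpoint.pt')

The sketch only wires up the pieces the file itself references; in the real repository the prompt tokens and hyperparameters would come from the project's data pipeline and argument parsing.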