Commit 51560034
Project: CSDN 技术社区 / ai / chatCSDN
Author: u010280923
Date: March 09, 2023
Parent: 46f70cba

opt reward model
Showing 6 changed files with 395 additions and 62 deletions (+395 −62):

    forward_demo.py          +3    -3
    src/model.py             +5    -1
    src/rlhf/reward.py       +25   -45
    src/rlhf/rwkv/model.py   +15   -2
    src/rlhf/rwkv/utils.py   +104  -0
    train_rm.py              +243  -11
forward_demo.py

@@ -30,7 +30,7 @@ os.environ['RWKV_JIT_ON'] = '1'
 os.environ["RWKV_CUDA_ON"] = '0' # if '1' then compile CUDA kernel for seq mode (much faster)
 # from rwkv.model import RWKV # pip install rwkv
-from src.rlhf.rwkv import RWKV
+from src.rlhf.rwkv.model import RWKV
 # model = RWKV(model='./model/rwkv-190.pth', strategy='cpu fp32')
 model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32')
@@ -46,7 +46,7 @@ model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32')
 # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019', strategy='cuda fp16 *0+ -> cpu fp32 *1')
 # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096', strategy='cuda:0 fp16 *25 -> cuda:1 fp16')
-out, state = model.forward([187, 510, 1563, 310, 247], None)
+out, state, token_embed = model.forward([187, 510, 1563, 310, 247], None)
 print(out.detach().cpu().numpy()) # get logits
 # out, state = model.forward([187, 510], None)
 # out, state = model.forward([1563], state) # RNN has state (use deepcopy to clone states)
@@ -58,7 +58,7 @@ ipdb.set_trace()
 # print('\n')
-# from src.utils import PIPELINE, PIPELINE_ARGS
+# from src.rlhf.rwkv.utils import PIPELINE, PIPELINE_ARGS
 # pipeline = PIPELINE(model, "20B_tokenizer.json")
 # ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
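Building on the commented state-reuse lines above, here is a minimal sketch of greedy decoding against the new three-value return; `model` is the instance constructed in the demo, and the argmax pick is an illustrative choice, not part of the commit:

import torch

out, state, token_embed = model.forward([187, 510, 1563, 310, 247], None)
next_token = int(torch.argmax(out))                           # greedy pick from the vocab logits
out, state, token_embed = model.forward([next_token], state)  # feed it back together with the RNN state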
src/model.py

@@ -429,7 +429,7 @@ class RWKV(pl.LightningModule):
             return cfg.get("offload_optimizer") or cfg.get("offload_param")
         return False

-    def forward(self, idx):
+    def forward(self, idx, extra_embed=None):
         args = self.args
         B, T = idx.size()
         assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
@@ -437,6 +437,10 @@ class RWKV(pl.LightningModule):
         x = self.emb(idx)
         x_emb = x

+        # add an extra embedding onto x, e.g. to distinguish prompt from response when training the RM
+        if extra_embed is not None:
+            x_emb = x_emb + extra_embed
+
         if args.tiny_att_dim > 0:
             for block in self.blocks:
                 if args.grad_cp == 1:
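The new extra_embed argument lets a caller add a marker embedding on top of the token embeddings. A minimal sketch of how such a tensor can be built from the learned prompt/response markers — the names prompt_embed, response_embed, and prompt_mask follow src/rlhf/reward.py below, while the width of 768 is an assumed n_embd:

import torch
import torch.nn as nn
from einops import rearrange

dim = 768                                              # assumed args.n_embd
prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))    # marks prompt positions
response_embed = nn.Parameter(torch.zeros(1, 1, dim))  # marks response positions

# (batch, seq_len) bool mask: True where the token belongs to the prompt
prompt_mask = torch.cat((torch.ones(1, 50), torch.zeros(1, 50)), dim=1).bool()

# pick the matching marker per position; broadcasts to (batch, seq_len, dim)
extra_embed = torch.where(rearrange(prompt_mask, 'b n -> b n 1'), prompt_embed, response_embed)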
src/rlhf/reward.py

@@ -13,7 +13,7 @@ from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 from src.rlhf.utils import masked_mean, gumbel_sample
-from src.model import RWKV
+from src.rlhf.rwkv.model import RWKV

 # helper functions
@@ -26,35 +26,25 @@ def exists(val):
 class RewardModel(nn.Module):
     def __init__(
         self,
-        rwkv: RWKV,
-        dropout = 0.1,
-        num_binned_output = 0.
+        rwkv: RWKV
     ):
         super().__init__()

         # initialize the reward model from the pretrained model
-        self.rwkv = copy.deepcopy(rwkv)
+        self.rwkv = rwkv

         # dimension of the output token embeddings
-        dim = rwkv.dim # todo(luxin)
+        dim = rwkv.args.n_embd

-        # number of score levels; if 5, scores fall into [0, 1, 2, 3, 4], five levels in total
-        self.binned_output = num_binned_output > 1

         # todo(luxin): prompt_embed and response_embed are both initialized to all zeros -- shouldn't they differ?
         # used to tell prompt from response in the input; trained as model parameters, initialized to zeros
         self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
         self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
         # self.response_embed = nn.Parameter(torch.ones(1, 1, dim))

-        if self.binned_output:
-            # more than one score class, so a multi-class classification problem
-            self.to_pred = nn.Linear(dim, num_binned_output)
-        else:
-            # otherwise simply a binary classification problem
-            self.to_pred = nn.Sequential(
-                nn.Linear(dim, 1, bias = False),
-                Rearrange('... 1 -> ...')  # drop the trailing dim
-            )
+        # reward score head
+        self.pred_reward = nn.Sequential(
+            nn.Linear(dim, 1),
+            Rearrange('... 1 -> ...')  # drop the trailing dim
+        )

     def load(self, path):
         path = Path(path)
@@ -72,13 +62,10 @@ class RewardModel(nn.Module):
         x,
         mask = None,
         prompt_mask = None,
-        prompt_lengths = None,
-        labels = None,
-        sample = False,
-        sample_temperature = 1.
+        prompt_lengths = None
     ):

-        # only one of prompt_mask and prompt_lengths can be given
+        # prompt_mask and prompt_lengths are mutually exclusive
         assert not (exists(prompt_mask) and exists(prompt_lengths))

         # derive prompt mask from prompt lengths
@@ -98,26 +85,19 @@ class RewardModel(nn.Module):
             self.response_embed
         )

-        # todo(luxin) get embeddings from rwkv
-        embeds = self.rwkv(
+        # get the embedding of the last token
+        last_token_embeds = self.rwkv(
             x,
-            extra_embed = extra_embed,
-            return_only_embedding = True
+            state = None,
+            extra_embed = extra_embed
         )

-        # average all token embeddings and feed them to the scoring head
-        pooled = masked_mean(embeds, mask, dim = 1)
-        pred = self.to_pred(pooled)
-
-        if sample and self.binned_output:
-            assert not exists(labels)
-            pred = gumbel_sample(pred, temperature = sample_temperature, dim = -1)
-
-        if not exists(labels):
-            return pred
-
-        # todo(luxin): the author does not use the paper's two-sample (pairwise) loss, only a per-sample loss
-        if not self.binned_output:
-            return F.mse_loss(pred, labels)
-
-        return F.cross_entropy(pred, labels)
+        try:
+            pooled = masked_mean(last_token_embeds, mask, dim = 1)
+        except:
+            import ipdb
+            ipdb.set_trace()
+        reward = self.pred_reward(pooled)

+        return reward
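The removed todo records that the pairwise (two-sample) loss from the paper is not used here. For reference, a minimal sketch of that InstructGPT-style ranking loss over the scalar rewards this model now returns; the function name is illustrative, and it would pair with the prefer_reward/alter_reward values computed in train_rm.py later in this commit:

import torch.nn.functional as F

def pairwise_rank_loss(prefer_reward, alter_reward):
    # -log sigmoid(r_preferred - r_rejected), averaged over the batch;
    # minimizing it pushes the preferred response's reward above the alternative's
    return -F.logsigmoid(prefer_reward - alter_reward).mean()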
src/rlhf/rwkv.py → src/rlhf/rwkv/model.py (renamed)

@@ -4,6 +4,7 @@
 import types, gc, os, time
 import torch
+import copy
 from torch.nn import functional as F
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
@@ -364,12 +365,13 @@ class RWKV(MyModule):
 ########################################################################################################

-    def forward(self, tokens, state, full_output=False):
+    def forward(self, tokens, state, full_output=False, extra_embed=None):
         with torch.no_grad():
             w = self.w
             args = self.args

             if state == None:
+                # initialize state
                 state = [None] * args.n_layer * 5
                 for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
                     dd = self.strategy[i]
@@ -382,10 +384,15 @@ class RWKV(MyModule):
                     state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev)

             seq_mode = len(tokens) > 1

+            import ipdb
+            ipdb.set_trace()
+            # input: look up each token's embedding by index
             x = w['emb.weight'][tokens if seq_mode else tokens[0]]
+            if extra_embed is not None:
+                x = x + extra_embed

+            # inference: N layers of Blocks (Attention, Feed Forward)
             for i in range(args.n_layer):
                 bbb = f'blocks.{i}.'
                 att = f'blocks.{i}.att.'
@@ -404,8 +411,10 @@ class RWKV(MyModule):
                     ATT = self.att_one
                     FFN = self.ffn_one

+                # convert tensor dtype and/or device
                 x = x.to(dtype=atype, device=dev)

+                # attention layer
                 kw = self.get_w(f'{att}key.weight', atype)
                 vw = self.get_w(f'{att}value.weight', atype)
                 rw = self.get_w(f'{att}receptance.weight', atype)
@@ -424,6 +433,7 @@ class RWKV(MyModule):
                 if wtype == torch.uint8 or dd.stream:
                     del kw, vw, rw, ow

+                # feed-forward layer
                 kw = self.get_w(f'{ffn}key.weight', atype)
                 vw = self.get_w(f'{ffn}value.weight', atype)
                 rw = self.get_w(f'{ffn}receptance.weight', atype)
@@ -443,15 +453,18 @@ class RWKV(MyModule):
                 if (i+1) % self.RESCALE_LAYER == 0:
                     x = x / 2

+            # keep all tokens' embeddings, or only the last token's
             dd = self.strategy[args.n_layer]
             x = x[-1,:] if (seq_mode and (not full_output)) else x
             x = x.to(dtype=dd.atype, device=dd.device)

+            # apply LayerNorm to the token embedding; shape unchanged
             x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])
+            token_embed = copy.deepcopy(x)

             if w['head.weight'].dtype != torch.uint8:
                 x = x @ w['head.weight']
             else:
                 x = x @ self.get_w('head.weight', dd.atype)

-            return x.float(), state
+            return x.float(), state, token_embed.float()
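As the inline comment records, state is a flat list holding five tensors per layer (0=att_xx, 1=att_aa, 2=att_bb, 3=att_pp, 4=ffn_xx), which is why slots are addressed as state[i*5+4]. A small sketch of that indexing convention; the helper is hypothetical:

# five state slots per RWKV layer, matching the comment in forward()
ATT_XX, ATT_AA, ATT_BB, ATT_PP, FFN_XX = range(5)

def state_index(layer, slot):
    # e.g. state_index(i, FFN_XX) == i*5 + 4, layer i's ffn_xx tensor
    return layer * 5 + slot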
src/rlhf/rwkv/utils.py (new file, 0 → 100644)

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F
from tokenizers import Tokenizer

class PIPELINE_ARGS():
    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, token_ban=[], token_stop=[]):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
        self.token_ban = token_ban # ban the generation of some tokens
        self.token_stop = token_stop # stop generation whenever you see any token here

class PIPELINE():
    def __init__(self, model, WORD_NAME):
        self.model = model
        self.tokenizer = Tokenizer.from_file(WORD_NAME)

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'
        return context

    def encode(self, x):
        return self.tokenizer.encode(x).ids

    def decode(self, x):
        return self.tokenizer.decode(x)

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if probs.device == torch.device('cpu'):
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)
            sorted_probs = probs[sorted_ids][::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            sorted_ids = torch.argsort(probs)
            sorted_probs = probs[sorted_ids]
            sorted_probs = torch.flip(sorted_probs, dims=(0,))
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)

    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        for i in range(token_count):
            # forward & adjust prob.
            out, state = self.model.forward(self.encode(ctx) if i == 0 else [token], state)
            for n in args.token_ban:
                out[n] = -float('inf')
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

            # sampler
            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
            if token in args.token_stop:
                break
            all_tokens += [token]
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            # output
            tmp = self.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # is valid utf-8 string?
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = i + 1
        return out_str
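For reference, a minimal sketch of driving this pipeline; the "20B_tokenizer.json" name follows the commented demo in forward_demo.py, and `model` is assumed to be loaded as there. Note that generate unpacks model.forward into two values, while RWKV.forward in src/rlhf/rwkv/model.py now returns three, so this sketch assumes a model whose forward returns (logits, state):

from src.rlhf.rwkv.utils import PIPELINE, PIPELINE_ARGS

pipeline = PIPELINE(model, "20B_tokenizer.json")
args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, top_k=0)
out_str = pipeline.generate("\nIn a shocking finding,", token_count=50, args=args,
                            callback=lambda s: print(s, end='', flush=True))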
train_rm.py

@@ -6,27 +6,259 @@
 '''
 # here put the import lib

+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+
+    rank_zero_info("########## work in progress ##########")
+
+    ########################################################################################################
+    #
+    # example: train a simple L12-D768 RWKV on dummy data
+    #
+    # python train.py --load_model "" --wandb "" --proj_dir "out" \
+    # --data_file "" --data_type "dummy" --vocab_size 0 \
+    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
+    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
+
+    # example: train a simple L6-D512 RWKV from scratch on enwik8
+    #
+    # python train.py --load_model "" --wandb "" --proj_dir "out" \
+    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
+    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
+    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
+
+    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
+    #
+    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
+    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
+    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
+    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
+
+    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
+    #
+    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
+    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
+    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
+    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
+    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
+    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
+
+    parser = ArgumentParser()
+    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
+    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
+    parser.add_argument("--proj_dir", default="out", type=str)
+    parser.add_argument("--random_seed", default="-1", type=int)
+    parser.add_argument("--data_file", default="", type=str)
+    parser.add_argument("--data_type", default="utf-8", type=str)
+    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
+    parser.add_argument("--ctx_len", default=1024, type=int)
+    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
+    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
+    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
+    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"
+    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
+    parser.add_argument("--n_layer", default=6, type=int)
+    parser.add_argument("--n_embd", default=512, type=int)
+    parser.add_argument("--dim_att", default=0, type=int)
+    parser.add_argument("--dim_ffn", default=0, type=int)
+    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
+    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
+    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
+    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer
+    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
+    parser.add_argument("--lr_final", default=1e-5, type=float)
+    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
+    parser.add_argument("--beta1", default=0.9, type=float)
+    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
+    parser.add_argument("--adam_eps", default=1e-8, type=float)
+    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
+    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
+    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
+    parser.add_argument("--my_pile_edecay", default=0, type=int)
+    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
+    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
+    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
+
+    parser.add_argument("--my_img_version", default=0, type=str)
+    parser.add_argument("--my_img_size", default=0, type=int)
+    parser.add_argument("--my_img_bit", default=0, type=int)
+    parser.add_argument("--my_img_clip", default='x', type=str)
+    parser.add_argument("--my_img_clip_scale", default=1, type=float)
+    parser.add_argument("--my_img_l1_scale", default=0, type=float)
+    parser.add_argument("--my_img_encoder", default='x', type=str)
+    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
+
+    parser.add_argument("--my_sample_len", default=0, type=int)
+    parser.add_argument("--my_ffn_shift", default=1, type=int)
+    parser.add_argument("--my_att_shift", default=1, type=int)
+    parser.add_argument("--my_pos_emb", default=0, type=int)
+    parser.add_argument("--load_partial", default=0, type=int)
+    parser.add_argument("--magic_prime", default=0, type=int)
+    parser.add_argument("--my_qa_mask", default=0, type=int)
+    parser.add_argument("--my_testing", default='', type=str)
+
+    parser = Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    ########################################################################################################
+
+    import os, warnings, math, datetime, sys, time
+    import numpy as np
+    import torch
+    from torch.utils.data import DataLoader
+    import deepspeed
+    import pytorch_lightning as pl
+    from pytorch_lightning import seed_everything
+
+    if args.random_seed >= 0:
+        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
+        seed_everything(args.random_seed)
+
+    np.set_printoptions(precision=4, suppress=True, linewidth=200)
+    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
+    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
+    # os.environ["WDS_SHOW_SEED"] = "1"
+
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
+    args.enable_checkpointing = False
+    args.replace_sampler_ddp = False
+    args.logger = False
+    args.gradient_clip_val = 1.0
+    args.num_sanity_val_steps = 0
+    args.check_val_every_n_epoch = int(1e20)
+    args.log_every_n_steps = int(1e20)
+    args.max_epochs = -1  # continue forever
+    args.betas = (args.beta1, args.beta2)
+    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
+    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
+    os.environ["RWKV_MY_TESTING"] = args.my_testing
+
+    if args.dim_att <= 0:
+        args.dim_att = args.n_embd
+    if args.dim_ffn <= 0:
+        args.dim_ffn = args.n_embd * 4
+
+    args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
+    if not os.path.exists(args.proj_dir):
+        os.makedirs(args.proj_dir)
+
+    samples_per_epoch = args.epoch_steps * args.real_bsz
+    tokens_per_epoch = samples_per_epoch * args.ctx_len
+    rank_zero_info(
+        f"""
+############################################################################
+#
+# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
+#
+# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
+#
+# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
+#
+# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
+#
+# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
+#
+# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
+#
+# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
+# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
+# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
+#
+############################################################################
+"""
+    )
+    rank_zero_info(str(vars(args)) + "\n")
+
+    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
+
+    if args.lr_final == 0 or args.lr_init == 0:
+        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
+
+    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
+    os.environ["RWKV_FLOAT_MODE"] = args.precision
+    if args.precision == "fp32":
+        rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
+    if args.precision == "fp16":
+        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
+
+    os.environ["RWKV_JIT_ON"] = "1"
+    if "deepspeed_stage_3" in args.strategy:
+        os.environ["RWKV_JIT_ON"] = "0"
+
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.enabled = True
+    if args.precision == "fp32":
+        torch.backends.cudnn.allow_tf32 = False
+        torch.backends.cuda.matmul.allow_tf32 = False
+    else:
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if "32" in args.precision:
+        args.precision = 32
+    elif args.precision == "fp16":
+        args.precision = 16
+    else:
+        args.precision = "bf16"
+
+########################################################################################################

 import torch
 from src.rlhf.reward import RewardModel
-from src.model import RWKV
+from src.rlhf.rwkv.model import RWKV

-rwkv_model = RWKV()
+model = "./model/RWKV-4-Pile-169M-20220807-8023.pth"
+strategy = "cpu fp32"
+rwkv_model = RWKV(model, strategy)
+dim = rwkv_model.args.n_embd

 reward_model = RewardModel(
-    rwkv_model,
-    num_binned_output = 5  # number of score levels; if 5, scores fall into [0, 1, 2, 3, 4], five levels in total
+    rwkv_model
 )

 # mock data
 seq = torch.randint(0, 20000, (1, 100))
 # prompt_mask = torch.zeros(1, 100).bool() # which part of the sequence is prompt, which part is response
+prompt = torch.randint(0, dim, (1, 50))
+prefer_response = torch.randint(0, dim, (1, 50))
+alter_response = torch.randint(0, dim, (1, 50))
+prefer_pair = torch.concat((prompt, prefer_response), dim=1)
+alter_pair = torch.concat((prompt, alter_response), dim=1)
+# which part of the sequence is prompt, which part is response
+prompt_mask = torch.cat((torch.ones(1, 50).bool(), torch.zeros(1, 50).bool()), dim=1)

-labels = torch.randint(0, 5, (1,))
+# labels = torch.randint(0, 5, (1,))

 # train
-loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
-loss.backward()
+# loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
+# loss.backward()

 # inference
+prefer_reward = reward_model(prefer_pair, prompt_mask = prompt_mask)
+alter_reward = reward_model(alter_pair, prompt_mask = prompt_mask)
-# after much training
-reward = reward_model(seq, prompt_mask = prompt_mask)
\ No newline at end of file
+print("Preferred response reward:", prefer_reward)
+print("Alternate response reward:", alter_reward)
\ No newline at end of file
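The reward model's forward also accepts prompt_lengths instead of prompt_mask ("derive prompt mask from prompt lengths" in src/rlhf/reward.py). A minimal sketch of that derivation, matching the mock shapes above; the exact computation is an assumption:

import torch
from einops import rearrange

seq_len = 100
prompt_lengths = torch.tensor([50])  # one 50-token prompt, as in the mock data
positions = torch.arange(seq_len)    # (seq_len,)
prompt_mask = rearrange(positions, 'n -> 1 n') < rearrange(prompt_lengths, 'b -> b 1')
# -> (batch, seq_len) bool: True over the prompt, False over the response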