CSDN 技术社区 / ai / chatCSDN
Commit d19266ad
Authored March 01, 2023 by CSDN-Ada助手
Commit message: sft
Parent commit: e7d452e4
Showing 2 changed files with 308 additions and 0 deletions (+308 -0)
src/dataset.py  +38 -0
train_sft.py  +270 -0
src/dataset.py
@@ -4,9 +4,11 @@
import json, math, random, os, sys
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from src.utils import TOKENIZER
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime
@@ -216,3 +218,39 @@ class MyDataset(Dataset):
            return x, y, z
        return x, y


class S2SDataset(Dataset):
    def __init__(self, args):
        self.args = args
        self.vocab_size = args.vocab_size

        WORD_NAME = [
            "20B_tokenizer.json",
            "20B_tokenizer.json",
        ]  # [vocab, vocab] for Pile model
        self.tokenizer = TOKENIZER(WORD_NAME)

        pf = pd.read_csv(args.data_file)
        data_list = []
        for index, row in pf.iterrows():
            question = row["question"]
            answer = row["answer"]
            data_list.append((self.tokenizer.tokenizer.encode(question),
                              self.tokenizer.tokenizer.encode(answer)))
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        ctx_len = self.args.ctx_len
        req_len = ctx_len + 1
        question, answer = self.data[index]
        text = question + answer
        text = text[:req_len]
        x = torch.tensor(text, dtype=torch.long)
        y = torch.tensor(answer, dtype=torch.long)
        z = [1] * len(question) + [0] * (req_len - len(question))
        z = torch.tensor(z, dtype=torch.long)
        return x, y, z
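Usage note (a sketch, not part of the commit): S2SDataset expects args.data_file to point at a CSV with "question" and "answer" columns, tokenizes both with the repo's 20B_tokenizer.json vocabulary, and returns (x, y, z) per sample, where x is the concatenated question+answer tokens truncated to ctx_len + 1, y is the answer tokens, and z marks the question positions. A minimal sketch of feeding it follows; the file name sft_data.csv and the argument values are illustrative assumptions.

# Illustrative sketch (not part of this commit): build a tiny question/answer CSV
# and load it with S2SDataset. Assumes the repo layout and 20B_tokenizer.json are present.
from argparse import Namespace

import pandas as pd

from src.dataset import S2SDataset

# The two columns S2SDataset reads from args.data_file.
pd.DataFrame({
    "question": ["What is RWKV?"],
    "answer": ["An RNN-style language model with transformer-level performance."],
}).to_csv("sft_data.csv", index=False)

# Only the fields S2SDataset actually uses are set here (hypothetical values).
args = Namespace(data_file="sft_data.csv", vocab_size=50277, ctx_len=1024)
dataset = S2SDataset(args)
x, y, z = dataset[0]  # input tokens, answer tokens, question-position mask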
train_sft.py (new file, 0 → 100644)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2023/3/1 11:54
# @Author  : clong
# @File    : train_sft.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

if __name__ == "__main__":
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    rank_zero_info("########## work in progress ##########")

    ########################################################################################################
    #
    # example: train a simple L12-D768 RWKV on dummy data
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "" --data_type "dummy" --vocab_size 0 \
    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: train a simple L6-D512 RWKV from scratch on enwik8
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0

    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

    parser = ArgumentParser()

    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default="-1", type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)

    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"

    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)

    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default='x', type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default='x', type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)

    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_testing", default='', type=str)

    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    ########################################################################################################

    import os, warnings, math, datetime, sys, time
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    import deepspeed
    import pytorch_lightning as pl
    from pytorch_lightning import seed_everything

    if args.random_seed >= 0:
        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"

    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = -1  # continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
    os.environ["RWKV_MY_TESTING"] = args.my_testing
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        args.dim_ffn = args.n_embd * 4

    args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
    if not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
    if args.precision == "fp16":
        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

    os.environ["RWKV_JIT_ON"] = "1"
    if "deepspeed_stage_3" in args.strategy:
        os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from src.trainer import train_callback, generate_init_weight
    from src.dataset import S2SDataset

    train_data = S2SDataset(args)
    args.vocab_size = train_data.vocab_size

    from src.model import RWKV
    model = RWKV(args)

    if len(args.load_model) == 0:
        rank_zero_info(f"SFT must load model, please input ")
        exit(1)

    rank_zero_info(f"########## Loading {args.load_model}... ##########")
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
    except:
        rank_zero_info(f"Bad checkpoint {args.load_model}")
        exit(1)

    if args.load_partial == 1:
        load_keys = load_dict.keys()
        for k in model.state_dict():
            if k not in load_keys:
                load_dict[k] = model.state_dict()[k]
    model.load_state_dict(load_dict)

    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[train_callback(args)],
    )

    if trainer.global_rank == 0:
        for n in model.state_dict():
            shape = model.state_dict()[n].shape
            shape = [i for i in shape if i != 1]
            if len(shape) > 1:
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
            else:
                print(f"{str(shape[0]).ljust(5)}       {n}")

    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)

    trainer.fit(model, data_loader)
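Usage note (a sketch, not part of the commit): train_sft.py keeps the train.py flag set shown in its header comments, but it always requires --load_model to point at an existing RWKV checkpoint and reads its data through S2SDataset, so --data_file should be the question/answer CSV. An illustrative invocation follows; the paths, model size, and hyperparameters are assumptions, only the flags themselves come from the script above.

# example (illustrative values): SFT an existing RWKV checkpoint on a question/answer CSV
#
# python train_sft.py --load_model "out/rwkv-base.pth" --wandb "" --proj_dir "out_sft" \
# --data_file "sft_data.csv" --data_type "utf-8" --vocab_size 50277 \
# --ctx_len 1024 --epoch_steps 200 --epoch_count 20 --epoch_begin 0 --epoch_save 2 \
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 50 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 1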