Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
CSDN 技术社区
ai
chatCSDN
提交
0e61d27f
C
chatCSDN
项目概览
CSDN 技术社区
/
ai
/
chatCSDN
通知
107
Star
8
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
chatCSDN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0e61d27f
编写于
3月 13, 2023
作者:
U
u010280923
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
reward model finished
上级
efd856e2
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
139 addition
and
3 deletion
+139
-3
README.md
README.md
+1
-1
src/trainer.py
src/trainer.py
+136
-0
train_rm.py
train_rm.py
+2
-2
未找到文件。
README.md
浏览文件 @
0e61d27f
...
...
@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
### Reward Model
```
python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_
sft
" \
python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_
rm
" \
--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
--ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
...
...
src/trainer.py
浏览文件 @
0e61d27f
...
...
@@ -151,6 +151,142 @@ class train_callback(pl.Callback):
trainer
.
my_loss_count
=
0
class
rm_train_callback
(
pl
.
Callback
):
def
__init__
(
self
,
args
):
super
().
__init__
()
self
.
args
=
args
def
on_train_batch_start
(
self
,
trainer
,
pl_module
,
batch
,
batch_idx
):
args
=
self
.
args
# if args.cuda_cleanup > 0:
# torch.cuda.empty_cache()
real_step
=
trainer
.
global_step
+
args
.
epoch_begin
*
args
.
epoch_steps
# LR schedule
w_step
=
args
.
warmup_steps
if
args
.
lr_final
==
args
.
lr_init
or
args
.
epoch_count
==
0
:
lr
=
args
.
lr_init
else
:
decay_step
=
real_step
-
args
.
my_pile_edecay
*
args
.
epoch_steps
decay_total
=
(
args
.
epoch_count
-
args
.
my_pile_edecay
)
*
args
.
epoch_steps
progress
=
(
decay_step
-
w_step
+
1
)
/
(
decay_total
-
w_step
)
progress
=
min
(
1
,
max
(
0
,
progress
))
if
args
.
lr_final
==
0
or
args
.
lr_init
==
0
:
# linear decay
lr
=
args
.
lr_init
+
(
args
.
lr_final
-
args
.
lr_init
)
*
progress
else
:
# exp decay
lr
=
args
.
lr_init
*
math
.
exp
(
math
.
log
(
args
.
lr_final
/
args
.
lr_init
)
*
pow
(
progress
,
1
))
if
trainer
.
global_step
<
w_step
:
lr
=
lr
*
(
0.2
+
0.8
*
trainer
.
global_step
/
w_step
)
# if trainer.is_global_zero:
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
for
param_group
in
trainer
.
optimizers
[
0
].
param_groups
:
if
args
.
layerwise_lr
>
0
:
param_group
[
"lr"
]
=
lr
*
param_group
[
"my_lr_scale"
]
# print(param_group["lr"], param_group["my_lr_scale"])
else
:
param_group
[
"lr"
]
=
lr
trainer
.
my_lr
=
lr
# rank_zero_info(f"{real_step} {lr}")
if
trainer
.
global_step
==
0
:
if
trainer
.
is_global_zero
:
# logging
trainer
.
my_loss_sum
=
0
trainer
.
my_loss_count
=
0
trainer
.
my_log
=
open
(
args
.
proj_dir
+
"/train_log.txt"
,
"a"
)
trainer
.
my_log
.
write
(
f
"NEW RUN
{
args
.
my_timestamp
}
\n
{
vars
(
self
.
args
)
}
\n
"
)
try
:
print
(
f
"
\n
{
trainer
.
strategy
.
config
}
\n
"
)
trainer
.
my_log
.
write
(
f
"
{
trainer
.
strategy
.
config
}
\n
"
)
except
:
pass
trainer
.
my_log
.
flush
()
if
len
(
args
.
wandb
)
>
0
:
print
(
"Login to wandb..."
)
import
wandb
wandb
.
init
(
project
=
args
.
wandb
,
name
=
args
.
run_name
+
" "
+
args
.
my_timestamp
,
config
=
args
,
save_code
=
False
,
)
trainer
.
my_wandb
=
wandb
def
on_train_batch_end
(
self
,
trainer
,
pl_module
,
outputs
,
batch
,
batch_idx
):
args
=
self
.
args
if
trainer
.
is_global_zero
:
# logging
t_now
=
time
.
time_ns
()
token_per_step
=
args
.
ctx_len
*
args
.
real_bsz
real_step
=
trainer
.
global_step
+
args
.
epoch_begin
*
args
.
epoch_steps
kt_s
=
0
try
:
t_cost
=
(
t_now
-
trainer
.
my_time_ns
)
/
1e9
kt_s
=
token_per_step
/
t_cost
/
1000
self
.
log
(
"REAL it/s"
,
1.0
/
t_cost
,
prog_bar
=
True
,
on_step
=
True
)
self
.
log
(
"Kt/s"
,
kt_s
,
prog_bar
=
True
,
on_step
=
True
)
except
:
pass
trainer
.
my_time_ns
=
t_now
trainer
.
my_loss
=
trainer
.
my_loss_all
.
float
().
mean
().
item
()
trainer
.
my_loss_sum
+=
trainer
.
my_loss
trainer
.
my_loss_count
+=
1
trainer
.
my_epoch_loss
=
trainer
.
my_loss_sum
/
trainer
.
my_loss_count
self
.
log
(
"lr"
,
trainer
.
my_lr
,
prog_bar
=
True
,
on_step
=
True
)
self
.
log
(
"loss"
,
trainer
.
my_epoch_loss
,
prog_bar
=
True
,
on_step
=
True
)
# self.log("s", real_step, prog_bar=True, on_step=True)
if
len
(
args
.
wandb
)
>
0
:
lll
=
{
"loss"
:
trainer
.
my_loss
,
"lr"
:
trainer
.
my_lr
,
"Gtokens"
:
real_step
*
token_per_step
/
1e9
}
if
kt_s
>
0
:
lll
[
"kt/s"
]
=
kt_s
trainer
.
my_wandb
.
log
(
lll
,
step
=
int
(
real_step
))
if
args
.
magic_prime
>
0
:
if
int
(
real_step
)
==
int
(
args
.
magic_prime
*
(
1
+
args
.
my_qa_mask
)
//
args
.
real_bsz
)
-
1
:
to_save_dict
=
pl_module
.
state_dict
()
my_save
(
to_save_dict
,
f
"
{
args
.
proj_dir
}
/rwkv-final.pth"
,
)
def
on_train_epoch_start
(
self
,
trainer
,
pl_module
):
args
=
self
.
args
dataset
=
trainer
.
train_dataloader
.
dataset
.
datasets
assert
"RMDataset"
in
str
(
dataset
)
dataset
.
global_rank
=
trainer
.
global_rank
dataset
.
real_epoch
=
int
(
args
.
epoch_begin
+
trainer
.
current_epoch
)
dataset
.
world_size
=
trainer
.
world_size
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
def
on_train_epoch_end
(
self
,
trainer
,
pl_module
):
args
=
self
.
args
if
trainer
.
is_global_zero
:
# logging & save state_dict
if
(
args
.
epoch_save
>
0
and
trainer
.
current_epoch
%
args
.
epoch_save
==
0
)
or
trainer
.
current_epoch
==
args
.
epoch_count
-
1
:
if
args
.
data_type
==
'wds_img'
:
raw_dict
=
pl_module
.
state_dict
()
to_save_dict
=
{}
for
k
in
raw_dict
:
if
k
.
startswith
(
'encoder.'
)
or
k
.
startswith
(
'decoder.'
):
to_save_dict
[
k
]
=
raw_dict
[
k
]
else
:
to_save_dict
=
pl_module
.
state_dict
()
try
:
my_save
(
to_save_dict
,
f
"
{
args
.
proj_dir
}
/rwkv-
{
args
.
epoch_begin
+
trainer
.
current_epoch
}
.pth"
,
)
except
Exception
as
e
:
print
(
'Error
\n\n
'
,
e
,
'
\n\n
'
)
trainer
.
my_log
.
write
(
f
"
{
args
.
epoch_begin
+
trainer
.
current_epoch
}
{
trainer
.
my_epoch_loss
:.
6
f
}
{
math
.
exp
(
trainer
.
my_epoch_loss
):.
4
f
}
{
trainer
.
my_lr
:.
8
f
}
{
datetime
.
datetime
.
now
()
}
{
trainer
.
current_epoch
}
\n
"
)
trainer
.
my_log
.
flush
()
trainer
.
my_loss_sum
=
0
trainer
.
my_loss_count
=
0
@
rank_zero_only
def
generate_init_weight
(
model
,
init_weight_name
):
mm
=
model
.
generate_init_weight
()
...
...
train_rm.py
浏览文件 @
0e61d27f
...
...
@@ -224,7 +224,7 @@ if __name__ == "__main__":
import
torch
from
tqdm
import
tqdm
from
src.trainer
import
train_callback
from
src.trainer
import
rm_
train_callback
from
src.rlhf.reward
import
RewardModel
from
src.dataset
import
RMDataset
...
...
@@ -239,7 +239,7 @@ if __name__ == "__main__":
# 训练
trainer
=
Trainer
.
from_argparse_args
(
args
,
callbacks
=
[
train_callback
(
args
)],
callbacks
=
[
rm_
train_callback
(
args
)],
)
if
trainer
.
global_rank
==
0
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录