Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
CSDN 技术社区
ai
chatCSDN
提交
a1fe3755
C
chatCSDN
项目概览
CSDN 技术社区
/
ai
/
chatCSDN
通知
107
Star
8
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
chatCSDN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a1fe3755
编写于
3月 10, 2023
作者:
U
u010280923
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
opt reward model
上级
dfeee746
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
43 addition
and
17 deletion
+43
-17
src/dataset.py
src/dataset.py
+37
-12
src/rlhf/reward.py
src/rlhf/reward.py
+6
-5
未找到文件。
src/dataset.py
浏览文件 @
a1fe3755
...
...
@@ -267,6 +267,9 @@ class RMDataset(Dataset):
"20B_tokenizer.json"
,
"20B_tokenizer.json"
,
]
# [vocab, vocab] for Pile model
self
.
prompt_mask_id
=
0
self
.
response_mask_id
=
1
self
.
padding_mask_id
=
2
self
.
tokenizer
=
TOKENIZER
(
WORD_NAME
)
pf
=
pd
.
read_csv
(
args
.
data_file
)
...
...
@@ -275,10 +278,26 @@ class RMDataset(Dataset):
prompt
=
row
[
"prompt"
]
preferred
=
row
[
"preferred"
]
alternate
=
row
[
"alternate"
]
preferred_sample
=
f
"
{
prompt
}
\n
{
preferred
}
"
alternate_sample
=
f
"
{
prompt
}
\n
{
alternate
}
"
data_list
.
append
((
self
.
tokenizer
.
tokenizer
.
encode
(
preferred_sample
),
self
.
tokenizer
.
tokenizer
.
encode
(
alternate_sample
)))
prompt_idx
=
self
.
tokenizer
.
tokenizer
.
encode
(
prompt
)
preferred_idx
=
self
.
tokenizer
.
tokenizer
.
encode
(
preferred
)
alternate_idx
=
self
.
tokenizer
.
tokenizer
.
encode
(
alternate
)
prompt_mask
=
[
self
.
padding_mask_id
]
*
len
(
prompt_idx
)
preferred_mask
=
[
self
.
response_mask_id
]
*
len
(
preferred_idx
)
alternate_mask
=
[
self
.
response_mask_id
]
*
len
(
alternate_idx
)
prompt_prefer_idx
=
prompt_idx
+
preferred_idx
prompt_alter_idx
=
prompt_idx
+
alternate_idx
prompt_prefer_mask
=
prompt_mask
+
preferred_mask
prompt_alter_mask
=
prompt_mask
+
alternate_mask
data_list
.
append
((
prompt_prefer_idx
,
prompt_alter_idx
,
prompt_prefer_mask
,
prompt_alter_mask
))
self
.
data
=
data_list
def
__len__
(
self
):
...
...
@@ -287,15 +306,21 @@ class RMDataset(Dataset):
def
__getitem__
(
self
,
index
):
ctx_len
=
self
.
args
.
ctx_len
req_len
=
ctx_len
+
1
pr
eferred_sample
,
alternate_sample
=
self
.
data
[
index
]
pr
ompt_prefer_idx
,
prompt_alter_idx
,
prompt_prefer_mask
,
prompt_alter_mask
=
self
.
data
[
index
]
preferred_sample
=
preferred_sample
[:
req_len
]
alternate_sample
=
alternate_sample
[:
req_len
]
prompt_prefer_idx
=
prompt_prefer_idx
[:
req_len
]
prompt_alter_idx
=
prompt_alter_idx
[:
req_len
]
prompt_prefer_mask
=
prompt_prefer_mask
[:
req_len
]
prompt_alter_mask
=
prompt_alter_mask
[:
req_len
]
preferred_sample
=
preferred_sample
+
[
0
]
*
(
req_len
-
len
(
preferred_sample
))
alternate_sample
=
alternate_sample
+
[
0
]
*
(
req_len
-
len
(
alternate_sample
))
prompt_prefer_idx
=
prompt_prefer_idx
+
[
1
]
*
(
req_len
-
len
(
prompt_prefer_idx
))
prompt_alter_idx
=
prompt_alter_idx
+
[
1
]
*
(
req_len
-
len
(
prompt_alter_idx
))
prompt_prefer_mask
=
prompt_prefer_mask
+
[
self
.
padding_mask_id
]
*
(
req_len
-
len
(
prompt_prefer_mask
))
prompt_alter_mask
=
prompt_alter_mask
+
[
self
.
padding_mask_id
]
*
(
req_len
-
len
(
prompt_alter_mask
))
x_p
=
torch
.
tensor
(
preferred_sample
,
dtype
=
torch
.
long
)
x_a
=
torch
.
tensor
(
alternate_sample
,
dtype
=
torch
.
long
)
x_p
=
torch
.
tensor
(
prompt_prefer_idx
,
dtype
=
torch
.
long
)
x_a
=
torch
.
tensor
(
prompt_alter_idx
,
dtype
=
torch
.
long
)
m_p
=
torch
.
tensor
(
prompt_prefer_mask
,
dtype
=
torch
.
long
)
m_a
=
torch
.
tensor
(
prompt_alter_mask
,
dtype
=
torch
.
long
)
return
x_p
,
x_a
\ No newline at end of file
return
x_p
,
x_a
,
m_p
,
m_a
\ No newline at end of file
src/rlhf/reward.py
浏览文件 @
a1fe3755
...
...
@@ -65,6 +65,7 @@ class RewardModel(pl.LightningModule):
# 用于区分输入中的 prompt 和 response,当作模型参数进行训练,初始化为全0
self
.
prompt_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
dim
))
self
.
response_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
dim
))
self
.
padding_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
dim
),
requires_grad
=
False
)
# reward 得分计算
self
.
pred_reward
=
nn
.
Sequential
(
...
...
@@ -135,16 +136,16 @@ class RewardModel(pl.LightningModule):
return
reward
def
forward
(
self
,
prefer_x
,
alter_x
,
prefer_x_prompt_mask
,
alter_x_prompt_mask
):
prefer_reward
=
self
.
single_forward
(
prefer_x
,
prefer_x_prompt_mask
)
alter_reward
=
self
.
single_forward
(
alter_x
,
alter_x_prompt_mask
)
def
forward
(
self
,
x_p
,
x_a
,
m_p
,
m_a
):
prefer_reward
=
self
.
single_forward
(
x_p
,
prompt_mask
=
m_p
)
alter_reward
=
self
.
single_forward
(
x_a
,
prompt_mask
=
m_a
)
return
prefer_reward
,
alter_reward
def
training_step
(
self
,
batch
,
batch_idx
):
prefer_x
,
alter_x
,
prefer_x_prompt_mask
,
alter_x_prompt_mask
=
batch
x_p
,
x_a
,
m_p
,
m_a
=
batch
prefer_reward
,
alter_reward
=
self
(
prefer_x
,
alter_x
,
prefer_x_prompt_mask
,
alter_x_prompt_mask
)
x_p
,
x_a
,
m_p
,
m_a
)
loss
=
loss_function
(
prefer_reward
,
alter_reward
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录