Commit 09fc08b8 authored by u010280923

opt reward model

Parent 5f46304e
@@ -66,14 +66,9 @@ class RewardModel(pl.LightningModule):
dim = self.args.n_embd
# Used to distinguish the prompt part from the response part of the input; trained as model parameters, initialized to all zeros
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim)).to()
self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)
self.prompt_response_mask_embed = torch.stack([
self.prompt_embed,
self.response_embed,
self.padding_embed
])
# reward 得分计算
self.pred_reward = nn.Sequential(
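For context, a minimal sketch (toy shapes and illustrative values, not the model's real n_embd) of what the zero-initialized marker parameters do, assuming the marker embedding is added to the token embeddings as is typical for this design: the addition is a no-op at initialization and only becomes informative once training moves prompt_embed and response_embed apart. Note also that the stacked prompt_response_mask_embed is dropped from __init__ in this hunk; torch.stack over parameters yields a plain tensor that neither trains as a parameter nor follows the module across devices, which is presumably why the stack is rebuilt inside forward in the next hunk.

import torch
import torch.nn as nn

# toy sizes for illustration only
dim = 8
prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))    # broadcasts over (batch, seq)
response_embed = nn.Parameter(torch.zeros(1, 1, dim))

token_embeds = torch.randn(2, 5, dim)                  # stand-in (batch, seq, dim) activations
marked = token_embeds + prompt_embed                   # identical to token_embeds at init
assert torch.equal(marked, token_embeds)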
@@ -158,15 +153,20 @@ class RewardModel(pl.LightningModule):
# derive prompt mask from prompt lengths
if exists(prompt_lengths):
batch, seq_len = x.shape
arange = torch.arange(seq_len, device = x.device)
arange = torch.arange(seq_len, device=x.device)
prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')
# reward model should have an understanding of which section is prompt, and which section is response
# Select from prompt_embed or response_embed based on whether each token's entry in prompt_mask is True or False
# If True, take the prompt_embed; otherwise take the response_embed
prompt_response_mask_embed = torch.stack([
self.prompt_embed,
self.response_embed,
self.padding_embed
]).to(prompt_mask.device)
extra_embed = None
if exists(prompt_mask):
extra_embed = self.prompt_response_mask_embed[prompt_mask]
extra_embed = prompt_response_mask_embed[prompt_mask]
# get the embedding of the last token
last_token_embeds = self.rwkv(
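A minimal sketch (toy lengths, toy shapes, and an assumed integer section id per token) of the two tricks used in this hunk: deriving a per-token prompt mask from per-sample prompt lengths, and gathering one marker embedding per token by plain indexing into the stacked table, which forward now builds on prompt_mask.device so it always lives where the inputs are.

import torch
from einops import rearrange, repeat

batch, seq_len, dim = 2, 6, 4
prompt_lengths = torch.tensor([2, 4])          # assumed per-sample prompt lengths

# a position belongs to the prompt iff its index is smaller than that sample's prompt length
arange = torch.arange(seq_len)
prompt_mask = repeat(arange, 'n -> b n', b=batch) < rearrange(prompt_lengths, 'b -> b 1')
# [[ True,  True, False, False, False, False],
#  [ True,  True,  True,  True, False, False]]

# stand-ins for prompt_embed / response_embed / padding_embed, flattened to (3, dim)
mask_embed_table = torch.randn(3, dim)

# assumption: indexing expects one integer id per token; a boolean prompt mask can be
# converted so prompt tokens pick row 0 and response tokens pick row 1
section_ids = (~prompt_mask).long()
extra_embed = mask_embed_table[section_ids]    # (batch, seq_len, dim)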