提交 fba22c90 编写于 作者: U u010280923

bug fixed

上级 40ae4fb9
......@@ -283,7 +283,7 @@ class RMDataset(Dataset):
preferred_idx = self.tokenizer.tokenizer.encode(preferred)
alternate_idx = self.tokenizer.tokenizer.encode(alternate)
prompt_mask = [self.padding_mask_id] * len(prompt_idx)
prompt_mask = [self.prompt_mask_id] * len(prompt_idx)
preferred_mask = [self.response_mask_id] * len(preferred_idx)
alternate_mask = [self.response_mask_id] * len(alternate_idx)
......
......@@ -66,9 +66,9 @@ class RewardModel(pl.LightningModule):
dim = self.args.n_embd
# 用于区分输入中的 prompt 和 response,当作模型参数进行训练,初始化为全0
self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim)).to()
self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
self.padding_embed = nn.Parameter(torch.zeros(1, 1, dim), requires_grad=False)
self.prompt_embed = nn.Parameter(torch.zeros(dim))
self.response_embed = nn.Parameter(torch.zeros(dim))
self.padding_embed = nn.Parameter(torch.zeros(dim), requires_grad=False)
# reward 得分计算
self.pred_reward = nn.Sequential(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册