提交 790ed226 编写于 作者: A Aston Zhang

fix bert

上级 e8473558
......@@ -110,7 +110,7 @@ def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
masked_token = tokens[mlm_pred_position]
# 10%的时间:用随机词替换该词
else:
masked_token = random.randint(0, len(vocab) - 1)
masked_token = random.choice(vocab.idx_to_token)
mlm_input_tokens[mlm_pred_position] = masked_token
pred_positions_and_labels.append(
(mlm_pred_position, tokens[mlm_pred_position]))
......
......@@ -108,7 +108,7 @@ def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
```{.python .input}
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
trainer = gluon.Trainer(net.collect_params(), 'adam',
{'learning_rate': 1e-3})
{'learning_rate': 0.01})
step, timer = 0, d2l.Timer()
animator = d2l.Animator(xlabel='step', ylabel='loss',
xlim=[1, num_steps], legend=['mlm', 'nsp'])
......@@ -151,7 +151,7 @@ def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
#@tab pytorch
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
net = nn.DataParallel(net, device_ids=devices).to(devices[0])
trainer = torch.optim.Adam(net.parameters(), lr=1e-3)
trainer = torch.optim.Adam(net.parameters(), lr=0.01)
step, timer = 0, d2l.Timer()
animator = d2l.Animator(xlabel='step', ylabel='loss',
xlim=[1, num_steps], legend=['mlm', 'nsp'])
......
......@@ -1729,7 +1729,7 @@ def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
masked_token = tokens[mlm_pred_position]
# 10%的时间:用随机词替换该词
else:
masked_token = random.randint(0, len(vocab) - 1)
masked_token = random.choice(vocab.idx_to_token)
mlm_input_tokens[mlm_pred_position] = masked_token
pred_positions_and_labels.append(
(mlm_pred_position, tokens[mlm_pred_position]))
......
......@@ -1828,7 +1828,7 @@ def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
masked_token = tokens[mlm_pred_position]
# 10%的时间:用随机词替换该词
else:
masked_token = random.randint(0, len(vocab) - 1)
masked_token = random.choice(vocab.idx_to_token)
mlm_input_tokens[mlm_pred_position] = masked_token
pred_positions_and_labels.append(
(mlm_pred_position, tokens[mlm_pred_position]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册