From 790ed2262656b3e6c5b1db7dc66548ab53ca4125 Mon Sep 17 00:00:00 2001
From: Aston Zhang
Date: Tue, 7 Dec 2021 20:30:43 +0000
Subject: [PATCH] fix bert

---
 .../bert-dataset.md     | 2 +-
 .../bert-pretraining.md | 4 ++--
 d2l/mxnet.py            | 2 +-
 d2l/torch.py            | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/chapter_natural-language-processing-pretraining/bert-dataset.md b/chapter_natural-language-processing-pretraining/bert-dataset.md
index eb21efdc..d9655c1a 100644
--- a/chapter_natural-language-processing-pretraining/bert-dataset.md
+++ b/chapter_natural-language-processing-pretraining/bert-dataset.md
@@ -110,7 +110,7 @@ def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                 masked_token = tokens[mlm_pred_position]
             # 10% of the time: replace the word with a random word
             else:
-                masked_token = random.randint(0, len(vocab) - 1)
+                masked_token = random.choice(vocab.idx_to_token)
         mlm_input_tokens[mlm_pred_position] = masked_token
         pred_positions_and_labels.append(
             (mlm_pred_position, tokens[mlm_pred_position]))
diff --git a/chapter_natural-language-processing-pretraining/bert-pretraining.md b/chapter_natural-language-processing-pretraining/bert-pretraining.md
index 0e47c6c2..94e190c1 100644
--- a/chapter_natural-language-processing-pretraining/bert-pretraining.md
+++ b/chapter_natural-language-processing-pretraining/bert-pretraining.md
@@ -108,7 +108,7 @@ def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
 ```{.python .input}
 def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
     trainer = gluon.Trainer(net.collect_params(), 'adam',
-                            {'learning_rate': 1e-3})
+                            {'learning_rate': 0.01})
     step, timer = 0, d2l.Timer()
     animator = d2l.Animator(xlabel='step', ylabel='loss',
                             xlim=[1, num_steps], legend=['mlm', 'nsp'])
@@ -151,7 +151,7 @@ def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
 #@tab pytorch
 def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
     net = nn.DataParallel(net, device_ids=devices).to(devices[0])
-    trainer = torch.optim.Adam(net.parameters(), lr=1e-3)
+    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
     step, timer = 0, d2l.Timer()
     animator = d2l.Animator(xlabel='step', ylabel='loss',
                             xlim=[1, num_steps], legend=['mlm', 'nsp'])
diff --git a/d2l/mxnet.py b/d2l/mxnet.py
index acac2995..fee984ad 100644
--- a/d2l/mxnet.py
+++ b/d2l/mxnet.py
@@ -1729,7 +1729,7 @@ def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                 masked_token = tokens[mlm_pred_position]
             # 10% of the time: replace the word with a random word
             else:
-                masked_token = random.randint(0, len(vocab) - 1)
+                masked_token = random.choice(vocab.idx_to_token)
         mlm_input_tokens[mlm_pred_position] = masked_token
         pred_positions_and_labels.append(
             (mlm_pred_position, tokens[mlm_pred_position]))
diff --git a/d2l/torch.py b/d2l/torch.py
index d4f8bae2..44542a22 100644
--- a/d2l/torch.py
+++ b/d2l/torch.py
@@ -1828,7 +1828,7 @@ def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                 masked_token = tokens[mlm_pred_position]
             # 10% of the time: replace the word with a random word
             else:
-                masked_token = random.randint(0, len(vocab) - 1)
+                masked_token = random.choice(vocab.idx_to_token)
         mlm_input_tokens[mlm_pred_position] = masked_token
         pred_positions_and_labels.append(
             (mlm_pred_position, tokens[mlm_pred_position]))
--
GitLab
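
Why the `_replace_mlm_tokens` change matters: `random.randint(0, len(vocab) - 1)` returns an integer index, while every other branch of BERT's 80/10/10 masking scheme assigns a token *string*. The integer lands in `mlm_input_tokens`, and when the corrupted sequence is later mapped back through the vocabulary, an integer key is not found in the token-to-index mapping and falls back to the unknown token. In effect, the "replace with a random word" branch was inserting `<unk>` rather than a random word. `random.choice(vocab.idx_to_token)` draws an actual token string, which is what the masked language modeling objective expects. Below is a minimal sketch of the failure mode; `ToyVocab` is a hypothetical stand-in for the book's `Vocab` class, assuming only that `idx_to_token` is a list of token strings and that lookups of unrecognized keys fall back to `<unk>`:

```python
import random

class ToyVocab:
    """Hypothetical stand-in for d2l's Vocab (an assumption, not the real class)."""
    def __init__(self, tokens):
        self.idx_to_token = ['<unk>', '<mask>'] + tokens
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

    def __getitem__(self, token):
        # Any key absent from the mapping, including an int, maps to '<unk>' (id 0)
        return self.token_to_idx.get(token, 0)

random.seed(0)
vocab = ToyVocab(['the', 'cat', 'sat'])

# Before the fix: the "random word" is an integer index ...
buggy = random.randint(0, len(vocab.idx_to_token) - 1)
print(type(buggy).__name__, vocab[buggy])   # int 0 -> always mapped to '<unk>'

# After the fix: the "random word" is an actual token string
fixed = random.choice(vocab.idx_to_token)
print(type(fixed).__name__, vocab[fixed])   # str and a real token id
```

The same one-line fix appears three times because the function is duplicated across `bert-dataset.md`, `d2l/mxnet.py`, and `d2l/torch.py`.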
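
The learning-rate change in `bert-pretraining.md` is independent of the masking fix: both the Gluon and the PyTorch trainers move from `1e-3` to `0.01`, a tenfold increase. The commit message gives no rationale, so presumably it retunes the short demo pretraining run; that reading is an assumption. A minimal sketch of the updated PyTorch setup (the `nn.Linear` net is a placeholder; the patch applies this to the assembled BERT model):

```python
import torch
from torch import nn

net = nn.Linear(128, 2)  # placeholder model, just to show the optimizer setup
trainer = torch.optim.Adam(net.parameters(), lr=0.01)  # was lr=1e-3
print(trainer.param_groups[0]['lr'])  # 0.01
```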