From 08d3c799486515d3ef6ded7e588a70e92e8dda61 Mon Sep 17 00:00:00 2001
From: Le Hou
Date: Fri, 30 Oct 2020 10:54:47 -0700
Subject: [PATCH] Internal change

PiperOrigin-RevId: 339901099
---
 official/nlp/tasks/masked_lm.py      | 16 +++++++++++-----
 official/nlp/tasks/masked_lm_test.py |  1 +
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/official/nlp/tasks/masked_lm.py b/official/nlp/tasks/masked_lm.py
index 81b2be7f8..c854634cc 100644
--- a/official/nlp/tasks/masked_lm.py
+++ b/official/nlp/tasks/masked_lm.py
@@ -36,6 +36,9 @@ class MaskedLMConfig(cfg.TaskConfig):
       bert.ClsHeadConfig(
           inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
   ])
+  # TODO(b/154564893): Mathematically, scale_loss should be True.
+  # However, it works better with scale_loss being False.
+  scale_loss: bool = False
   train_data: cfg.DataConfig = cfg.DataConfig()
   validation_data: cfg.DataConfig = cfg.DataConfig()
 
@@ -161,12 +164,15 @@ class MaskedLMTask(base_task.Task):
           model_outputs=outputs,
           metrics=metrics,
           aux_losses=model.losses)
-      # Scales loss as the default gradients allreduce performs sum inside the
-      # optimizer.
-      # TODO(b/154564893): enable loss scaling.
-      # scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
+      if self.task_config.scale_loss:
+        # Scales loss as the default gradients allreduce performs sum inside the
+        # optimizer.
+        scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
     tvars = model.trainable_variables
-    grads = tape.gradient(loss, tvars)
+    if self.task_config.scale_loss:
+      grads = tape.gradient(scaled_loss, tvars)
+    else:
+      grads = tape.gradient(loss, tvars)
     optimizer.apply_gradients(list(zip(grads, tvars)))
     self.process_metrics(metrics, inputs, outputs)
     return {self.loss: loss}
diff --git a/official/nlp/tasks/masked_lm_test.py b/official/nlp/tasks/masked_lm_test.py
index 89edf4ffa..937862da2 100644
--- a/official/nlp/tasks/masked_lm_test.py
+++ b/official/nlp/tasks/masked_lm_test.py
@@ -28,6 +28,7 @@ class MLMTaskTest(tf.test.TestCase):
   def test_task(self):
     config = masked_lm.MaskedLMConfig(
         init_checkpoint=self.get_temp_dir(),
+        scale_loss=True,
         model=bert.PretrainerConfig(
             encoder=encoders.EncoderConfig(
                 bert=encoders.BertEncoderConfig(vocab_size=30522,
--
GitLab
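
The core of this patch is the scale_loss gate in the first hunk: under synchronous tf.distribute training, the optimizer all-reduces gradients with a SUM across replicas, so dividing each replica's loss by num_replicas_in_sync keeps the applied update equal to the mean over the global batch. Below is a minimal, self-contained sketch of that pattern outside the MaskedLMTask code; the MirroredStrategy, the toy Dense model, and the train_step/step_fn names are illustrative assumptions, not code from the patch.

# Standalone sketch (assumed setup: MirroredStrategy and a toy Dense model,
# not part of the patch). It mirrors the patch's loss scaling: each replica
# divides its local loss by the replica count so the summed gradient
# all-reduce yields the mean gradient over the global batch.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(0.1)


@tf.function
def train_step(x, y, scale_loss=True):
  def step_fn(x, y):
    with tf.GradientTape() as tape:
      # Per-replica mean loss over the local batch.
      loss = tf.reduce_mean(tf.square(model(x) - y))
      if scale_loss:
        # As in the patch: divide by the replica count because the default
        # gradient aggregation inside apply_gradients is a cross-replica sum.
        loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  return strategy.run(step_fn, args=(x, y))


# Example call: a batch of 8 samples with 4 features on each replica.
per_replica_loss = train_step(tf.ones([8, 4]), tf.ones([8, 1]))

With scale_loss=False (the default the patch chooses), the summed all-reduce leaves the gradient roughly num_replicas_in_sync times larger, which effectively scales the learning rate by the replica count; the TODO in the patch notes that this unscaled variant worked better in practice despite being the mathematically "wrong" choice.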