Commit 08d3c799 authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339901099
Parent 499a003d
@@ -36,6 +36,9 @@ class MaskedLMConfig(cfg.TaskConfig):
           bert.ClsHeadConfig(
               inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
       ])
+  # TODO(b/154564893): Mathematically, scale_loss should be True.
+  # However, it works better with scale_loss being False.
+  scale_loss: bool = False
   train_data: cfg.DataConfig = cfg.DataConfig()
   validation_data: cfg.DataConfig = cfg.DataConfig()
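
The TODO above says scale_loss should mathematically be True. A minimal sketch of the arithmetic behind that claim, using plain NumPy to stand in for tf.distribute (everything below is illustrative, not part of the commit): when each replica computes the mean loss over its local shard and the cross-replica allreduce sums the resulting gradients, the sum is num_replicas times the gradient of the global mean loss, and dividing each replica's loss by the replica count restores the correct magnitude.

    import numpy as np

    num_replicas = 4                       # stands in for num_replicas_in_sync
    w = 2.0                                # a single scalar weight
    x = np.arange(8.0)                     # global batch, sharded evenly
    shards = np.split(x, num_replicas)

    def grad(shard, scale=1.0):
      # Gradient w.r.t. w of the per-replica mean loss mean((w * x)**2),
      # optionally pre-scaled the way scale_loss scales the loss.
      return np.mean(2.0 * w * shard**2) * scale

    true_grad = np.mean(2.0 * w * x**2)    # gradient of the global mean loss
    summed = sum(grad(s) for s in shards)  # default sum-allreduce, no scaling
    scaled = sum(grad(s, 1.0 / num_replicas) for s in shards)

    assert np.isclose(summed, num_replicas * true_grad)  # N times too large
    assert np.isclose(scaled, true_grad)                 # scaling fixes it
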
@@ -161,12 +164,15 @@ class MaskedLMTask(base_task.Task):
           model_outputs=outputs,
           metrics=metrics,
           aux_losses=model.losses)
-      # Scales loss as the default gradients allreduce performs sum inside the
-      # optimizer.
-      # TODO(b/154564893): enable loss scaling.
-      # scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
+      if self.task_config.scale_loss:
+        # Scales loss as the default gradients allreduce performs sum inside the
+        # optimizer.
+        scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
     tvars = model.trainable_variables
-    grads = tape.gradient(loss, tvars)
+    if self.task_config.scale_loss:
+      grads = tape.gradient(scaled_loss, tvars)
+    else:
+      grads = tape.gradient(loss, tvars)
     optimizer.apply_gradients(list(zip(grads, tvars)))
     self.process_metrics(metrics, inputs, outputs)
     return {self.loss: loss}
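
For readers outside the Model Garden codebase, here is a self-contained sketch of the control flow this hunk introduces. The tiny Keras model, MSE loss, and plain bool argument are stand-ins of mine, not the task's actual interface:

    import tensorflow as tf

    def train_step(model, optimizer, inputs, labels, scale_loss=False):
      with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        loss = tf.reduce_mean(tf.keras.losses.mse(labels, outputs))
        if scale_loss:
          # The default cross-replica allreduce sums per-replica gradients,
          # so divide the loss by the replica count first.
          scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
      tvars = model.trainable_variables
      if scale_loss:
        grads = tape.gradient(scaled_loss, tvars)
      else:
        grads = tape.gradient(loss, tvars)
      optimizer.apply_gradients(list(zip(grads, tvars)))
      return loss

    # Outside a strategy scope, num_replicas_in_sync is 1, so both paths agree:
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    opt = tf.keras.optimizers.SGD(0.1)
    x, y = tf.random.normal([8, 4]), tf.random.normal([8, 1])
    print(train_step(model, opt, x, y, scale_loss=True).numpy())
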
@@ -28,6 +28,7 @@ class MLMTaskTest(tf.test.TestCase):
   def test_task(self):
     config = masked_lm.MaskedLMConfig(
         init_checkpoint=self.get_temp_dir(),
+        scale_loss=True,
         model=bert.PretrainerConfig(
             encoder=encoders.EncoderConfig(
                 bert=encoders.BertEncoderConfig(vocab_size=30522,
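
The test opts into the non-default setting so the scaled-gradient branch is exercised. A hedged usage sketch (module path as in the test; that MaskedLMConfig is constructible with all other fields left at their defaults is my assumption):

    from official.nlp.tasks import masked_lm

    # Assumes the dataclass defaults suffice for the remaining fields.
    config = masked_lm.MaskedLMConfig(scale_loss=True)
    assert config.scale_loss  # the default stays False, per the TODO above
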