Commit aca137c1 authored by Ran Chen, committed by A. Unique TensorFlower

Work around a known issue with control dependencies on external tensors

This fixes the tensorflow_models/official/nlp/bert/run_pretraining failure with
explicit allreduce.

PiperOrigin-RevId: 338521478
Parent 2d2582fc
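For context, here is a minimal sketch of the pattern this change applies (TensorFlow 2.x assumed; `train_step` and `apply_update` are illustrative names, not from the repository). Because `grad` is captured from an enclosing function, a control dependency declared directly on it inside a nested function can be dropped during tracing (b/171088214); wrapping it in `tf.identity` creates an op inside the nested graph, so the dependency is anchored locally and actually enforced.

    import tensorflow as tf

    @tf.function
    def train_step(var, grad):

      def apply_update():
        # Workaround: depend on tf.identity(grad), not on grad itself.
        # grad is an external (captured) tensor inside this nested
        # function; tf.identity gives the dependency a local anchor.
        with tf.control_dependencies([tf.identity(grad)]):
          return var.assign_sub(0.1 * grad)

      return apply_update()

    v = tf.Variable(1.0)
    train_step(v, tf.constant(0.5))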
@@ -199,7 +199,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     # backward pass.
     # TODO(b/171088214): Remove it after the control dependency in
     # nested function is fixed.
-    with tf.control_dependencies([grad]):
+    with tf.control_dependencies([tf.identity(grad)]):
       lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
       decay = self._decay_weights_op(var, lr_t, apply_state)
       with tf.control_dependencies([decay]):
@@ -212,7 +212,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     # backward pass.
     # TODO(b/171088214): Remove it after the control dependency in
     # nested function is fixed.
-    with tf.control_dependencies([grad]):
+    with tf.control_dependencies([tf.identity(grad)]):
       lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
       decay = self._decay_weights_op(var, lr_t, apply_state)
       with tf.control_dependencies([decay]):