Commit 7dd9852e authored by Jinoo Baek, committed by A. Unique TensorFlower

Fix an indentation bug: divide the loss by num_replicas_in_sync only once.

PiperOrigin-RevId: 448508763
Parent f7201d1a
@@ -138,10 +138,10 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
         self.tasks[name].process_metrics(task_metrics[name], labels, outputs,
                                          **kwargs)
-        # Scales loss as the default gradients allreduce performs sum inside
-        # the optimizer.
-        scaled_loss = total_loss / tf.distribute.get_strategy(
-        ).num_replicas_in_sync
+      # Scales loss as the default gradients allreduce performs sum inside
+      # the optimizer.
+      scaled_loss = total_loss / tf.distribute.get_strategy(
+      ).num_replicas_in_sync
     tvars = multi_task_model.trainable_variables
     grads = tape.gradient(scaled_loss, tvars)
     optimizer.apply_gradients(list(zip(grads, tvars)))
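For context, the pattern the dedent restores is the standard custom-training-loop loss scaling under tf.distribute: the cross-replica gradient all-reduce performs a SUM, so each replica must contribute loss / num_replicas_in_sync, and that division should happen exactly once, after the per-task loop but still inside the GradientTape scope. Below is a minimal runnable sketch of this pattern; the toy Dense model, MSE loss, and two-pass loop are illustrative stand-ins, not the repo's MultiTask code.

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Build eagerly so variables are created under the strategy scope,
  # not inside the replica function.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.build((None, 4))
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(features, labels):
  with tf.GradientTape() as tape:
    total_loss = 0.0
    # Hypothetical stand-in for the per-task accumulation loop in
    # joint_train_step.
    for _ in range(2):
      outputs = model(features, training=True)
      total_loss += tf.reduce_mean(tf.square(outputs - labels))
    # Scale exactly once, after the loop but still inside the tape:
    # the all-reduce SUMs per-replica gradients, so each replica
    # contributes loss / num_replicas_in_sync.
    scaled_loss = total_loss / strategy.num_replicas_in_sync
  grads = tape.gradient(scaled_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

features = np.random.rand(8, 4).astype(np.float32)
labels = np.random.rand(8, 1).astype(np.float32)
strategy.run(train_step, args=(features, labels))
```

With the pre-fix indentation, the division ran on every loop iteration; the result fed to tape.gradient was still the scaled running total, but the scaling was recomputed per task instead of applied once to the final sum, which is what the dedent corrects.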