提交 66f76664 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 485676095
上级 00351f9d
......@@ -51,10 +51,14 @@ class TpuBatchNormalization(tf.keras.layers.BatchNormalization):
 return tf1.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
 num_shards_per_group, t.dtype)
-def _moments(self, inputs: tf.Tensor, reduction_axes: int, keep_dims: int):
+def _moments(self,
+inputs: tf.Tensor,
+reduction_axes: int,
+keep_dims: int,
+mask: Optional[tf.Tensor] = None):
 """Compute the mean and variance: it overrides the original _moments."""
 shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
-inputs, reduction_axes, keep_dims=keep_dims)
+inputs, reduction_axes, keep_dims=keep_dims, mask=mask)
 num_shards = tpu_function.get_tpu_context().number_of_shards or 1
 if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册