提交 012d7e74 作者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 490119364
上级 24a39da3
......@@ -100,6 +100,7 @@ class Router(tf.keras.layers.Layer):
use_bias: bool = True,
kernel_initializer: _InitializerType = _DEFAULT_KERNEL_INITIALIZER,
bias_initializer: _InitializerType = _DEFAULT_BIAS_INITIALIZER,
router_z_loss_weight: float = 0.0,
name: str = "router",
dtype: Any = tf.float32,
**kwargs):
......@@ -112,6 +113,8 @@ class Router(tf.keras.layers.Layer):
weights.
kernel_initializer: Kernel initializer for router weights.
bias_initializer: Bias initializer for router weights.
router_z_loss_weight: Weight for router_z_loss. Use non-zero values if
running into training instability (esp. with dtype 'bfloat16' or other
lower-precision dtypes).
name: Layer name.
dtype: The dtype of the layer's computations and weights. tf.float32 is
recommended for stability.
......@@ -122,7 +125,7 @@ class Router(tf.keras.layers.Layer):
self.num_experts = num_experts # Used to check consistency with
# FeedForwardExperts.
self.jitter_noise = jitter_noise
self.router_z_loss_weight = router_z_loss_weight
self.router_weights = tf.keras.layers.Dense(
num_experts,
use_bias=use_bias,
......@@ -156,8 +159,10 @@ class Router(tf.keras.layers.Layer):
inputs, apply_jitter=training)
# router_probs <float32>[num_groups, tokens_per_group, num_experts]
# router_logits <float>[num_groups, tokens_per_group, num_experts]
router_z_loss = _router_z_loss(router_logits)
unscaled_router_z_loss = _router_z_loss(router_logits)
router_z_loss = self.router_z_loss_weight * unscaled_router_z_loss
self.add_loss(router_z_loss)
self.add_metric(unscaled_router_z_loss, name="unscaled_router_z_loss")
self.add_metric(router_z_loss, name="router_z_loss")
routing_instructions = self._compute_routing_instructions(
......
......@@ -224,7 +224,7 @@ class MoeTest(tf.test.TestCase):
metrics = [metric.name for metric in moe_layer.metrics]
self.assertSetEqual(
{
'router_z_loss', 'load_balancing_loss',
'router_z_loss', 'unscaled_router_z_loss', 'load_balancing_loss',
'fraction_tokens_left_behind', 'router_confidence', 'expert_usage'
}, set(metrics))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册