diff --git a/modeling.py b/modeling.py
index 8b5da0003ac4b324568ac34996670fd522e21576..ea575220a3a36448e8b714db47815f3f30a85612 100644
--- a/modeling.py
+++ b/modeling.py
@@ -740,12 +740,12 @@ def attention_layer(from_tensor,
   context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
 
   if do_return_2d_tensor:
-    # `context_layer` = [B*F, N*V]
+    # `context_layer` = [B*F, N*H]
     context_layer = tf.reshape(
         context_layer,
         [batch_size * from_seq_length, num_attention_heads * size_per_head])
   else:
-    # `context_layer` = [B, F, N*V]
+    # `context_layer` = [B, F, N*H]
     context_layer = tf.reshape(
         context_layer,
         [batch_size, from_seq_length, num_attention_heads * size_per_head])
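
Context for the comment fix: both reshape targets multiply `num_attention_heads` by `size_per_head`, which the shape comments in this function abbreviate as `N` and `H`; `V` is not a symbol defined in that legend, so `N*V` was a typo. A minimal standalone sketch, not part of the diff and using assumed toy dimensions, verifying the shapes the corrected comments describe:

# Sketch only: B, F, N, H values below are illustrative assumptions,
# standing in for batch_size, from_seq_length, num_attention_heads,
# and size_per_head.
import tensorflow as tf

B, F, N, H = 2, 5, 4, 8

# `context_layer` = [B, N, F, H] coming out of the attention weighting
context_layer = tf.zeros([B, N, F, H])

# `context_layer` = [B, F, N, H] after the transpose above
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

# 2-D branch: `context_layer` = [B*F, N*H]
flat_2d = tf.reshape(context_layer, [B * F, N * H])
assert flat_2d.shape.as_list() == [10, 32]  # B*F = 10, N*H = 32

# 3-D branch: `context_layer` = [B, F, N*H]
flat_3d = tf.reshape(context_layer, [B, F, N * H])
assert flat_3d.shape.as_list() == [2, 5, 32]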