Commit fa211938 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 396035361
Parent b0707104
......@@ -116,9 +116,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._attention_initializer = self._kernel_initializer
self._attention_axes = attention_axes
def _maybe_build(self, inputs):
super()._maybe_build(inputs[:1])
def build(self, input_shape):
if isinstance(input_shape, tf.TensorShape):
input_tensor_shape = input_shape
......@@ -250,9 +247,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
[`query tensor`, `key value tensor`, `attention mask`] to have separate
input streams for the query, and key/value to the multi-head
attention.
[`query tensor`, `key value tensor`, `attention mask`, `pos_embed`] to
have an additional pos_embed that is added to the query and key of
every self-attention layer.
Returns:
An output tensor with the same dimensions as input/query tensor.
......@@ -261,18 +255,13 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
pos_embed = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
pos_embed = None
elif len(inputs) == 4:
input_tensor, key_value, attention_mask, pos_embed = inputs
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask, pos_embed = (inputs, None, None,
None)
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
if self._norm_first:
......@@ -293,14 +282,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
if key_value is None:
key_value = input_tensor
if pos_embed is None:
query = target_tensor
key = key_value
else:
query = target_tensor + pos_embed
key = key_value + pos_embed
attention_output = self._attention_layer(
query=query, key=key, value=key_value, attention_mask=attention_mask)
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
......
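For reference, a minimal usage sketch of the TransformerEncoderBlock call forms that remain after this change: a single tensor, [query, attention mask], or [query, key/value, attention mask]; the four-element pos_embed form is removed. The import path and constructor arguments below are assumptions, not part of this diff.

import tensorflow as tf

# Assumed import path; the constructor arguments are illustrative only.
from official.nlp.modeling.layers import TransformerEncoderBlock

block = TransformerEncoderBlock(
    num_attention_heads=8, inner_dim=2048, inner_activation="relu")

x = tf.random.uniform((2, 16, 512))   # (batch, seq_len, hidden)
mask = tf.ones((2, 16, 16))           # (batch, query_len, key_len)

out = block(x)             # single tensor: self-attention, no mask
out = block([x, mask])     # [query tensor, attention mask]
out = block([x, x, mask])  # [query tensor, key value tensor, attention mask]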
......@@ -232,9 +232,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
else:
self._cross_attention_cls = attention.MultiHeadAttention
def _maybe_build(self, inputs):
super()._maybe_build(inputs[:1])
def build(self, input_shape):
target_tensor_shape = tf.TensorShape(input_shape[0])
if len(target_tensor_shape.as_list()) != 3:
......@@ -373,57 +370,22 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self.intermediate_dense, self.output_dense, self.output_layer_norm
]
def _parse_inputs(self, inputs, multi_channel_cross_attention):
if multi_channel_cross_attention:
if len(inputs) < 5:
def call(self, inputs, cache=None, decode_loop_step=None):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
"TransformerDecoderBlock must have at least 5 inputs, when it uses "
"TransformerDecoderBlock must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d" % len(inputs))
elif len(inputs) == 5:
input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights = inputs
input_pos_embed = None
memory_pos_embed = None
elif len(inputs) == 6:
input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed = inputs
memory_pos_embed = None
else:
input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = inputs[:
7]
else:
context_attention_weights = None
if len(inputs) < 4:
raise ValueError(
"TransformerDecoderBlock must have at leaset 4 inputs, but it "
"got: %d" % len(inputs))
elif len(inputs) == 4:
input_tensor, memory, attention_mask, self_attention_mask = inputs
input_pos_embed = None
memory_pos_embed = None
elif len(inputs) == 5:
input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed = inputs
memory_pos_embed = None
else:
input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed, memory_pos_embed = inputs[:
6]
return input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed
def call(self, inputs, cache=None, decode_loop_step=None):
input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = self._parse_inputs(
inputs, self.multi_channel_cross_attention)
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderBlock must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
source_tensor = input_tensor
if self._norm_first:
input_tensor = self.self_attention_layer_norm(input_tensor)
if input_pos_embed is None:
self_attn_query = input_tensor
self_attn_key = input_tensor
else:
self_attn_query = input_tensor + input_pos_embed
self_attn_key = input_tensor + input_pos_embed
self_attention_output, cache = self.self_attention(
query=self_attn_query,
key=self_attn_key,
query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask,
cache=cache,
......@@ -438,22 +400,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
source_self_attention_output = self_attention_output
self_attention_output = self.encdec_attention_layer_norm(
self_attention_output)
if input_pos_embed is None:
cross_attn_query = self_attention_output
else:
cross_attn_query = self_attention_output + input_pos_embed
if memory_pos_embed is None:
cross_attn_key = memory
else:
cross_attn_key = memory + memory_pos_embed
cross_attn_inputs = dict(
query=cross_attn_query,
key=cross_attn_key,
query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs["context_attention_weights"] = context_attention_weights
cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output)
if self._norm_first:
......
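A hedged sketch of feeding the updated TransformerDecoderBlock: exactly four inputs, or five when multi_channel_cross_attention is enabled. Only the input-list layout comes from this diff; the import path and constructor arguments are assumptions.

import tensorflow as tf

# Assumed import path and constructor arguments; only the call-input layout
# below is taken from the diff.
from official.nlp.modeling.layers import TransformerDecoderBlock

decoder_block = TransformerDecoderBlock(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation="relu")

batch, target_length, memory_length, hidden_size = 2, 12, 16, 512
targets = tf.random.uniform((batch, target_length, hidden_size))
memory = tf.random.uniform((batch, memory_length, hidden_size))
cross_attention_mask = tf.ones((batch, target_length, memory_length))
self_attention_mask = tf.ones((batch, target_length, target_length))

# Exactly four inputs now; the extra pos_embed elements from the previous
# revision are gone. (With multi_channel_cross_attention=True a fifth element,
# the doc-attention probabilities, is expected instead.)
outputs = decoder_block(
    [targets, memory, cross_attention_mask, self_attention_mask])
# Note: the block also threads a decoding cache through self-attention and
# returns it alongside the layer output.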
......@@ -425,7 +425,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
base_config = super(TransformerEncoder, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
def call(self, encoder_inputs, attention_mask=None):
"""Return the output of the encoder.
Args:
......@@ -433,17 +433,14 @@ class TransformerEncoder(tf.keras.layers.Layer):
hidden_size)`.
attention_mask: A mask for the encoder self-attention layer with shape
`(batch_size, input_length, input_length)`.
pos_embed: A tensor or a float that is added to the query and key of every
self-attention layer. Defaults to None.
Returns:
Output of encoder which is a `float32` tensor with shape
`(batch_size, input_length, hidden_size)`.
"""
for layer_idx in range(self.num_layers):
encoder_inputs = self.encoder_layers[layer_idx](
[encoder_inputs, encoder_inputs, attention_mask, pos_embed])
[encoder_inputs, attention_mask])
output_tensor = encoder_inputs
output_tensor = self.output_normalization(output_tensor)
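For reference, a minimal sketch of calling the encoder stack with the trimmed signature; only the call signature (encoder_inputs, attention_mask) comes from this diff, while the import path and default hyperparameters are assumptions.

import tensorflow as tf

# Assumed import path and default constructor hyperparameters.
from official.nlp.modeling.models import seq2seq_transformer

encoder = seq2seq_transformer.TransformerEncoder()

batch, input_length, hidden_size = 2, 16, 512
encoder_inputs = tf.random.uniform((batch, input_length, hidden_size))
attention_mask = tf.ones((batch, input_length, input_length))

# Each layer now receives [encoder_inputs, attention_mask]; pos_embed is gone.
encoder_outputs = encoder(encoder_inputs, attention_mask=attention_mask)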
......@@ -522,7 +519,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
attention_initializer=attention_initializer(input_shape[2]),
name=("layer_%d" % i)))
self.output_normalization = tf.keras.layers.LayerNormalization(
epsilon=self._norm_epsilon, dtype="float32")
epsilon=1e-6, dtype="float32")
super(TransformerDecoder, self).build(input_shape)
def get_config(self):
......@@ -548,9 +545,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
cross_attention_mask=None,
cache=None,
decode_loop_step=None,
return_all_decoder_outputs=False,
input_pos_embed=None,
memory_pos_embed=None):
return_all_decoder_outputs=False):
"""Return the output of the decoder layer stacks.
Args:
......@@ -570,10 +565,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
return_all_decoder_outputs: Return all decoder layer outputs.
Note that the outputs are layer normed.
This is useful when introducing per layer auxiliary loss.
input_pos_embed: A tensor or float that is added to the target embedding
in every self-attention and cross-attention layer. Defaults to None.
memory_pos_embed: A tensor or float that is added to the memory embedding
in every cross-attention layer. Defaults to None.
Returns:
Output of decoder.
......@@ -584,8 +575,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
decoder_outputs = []
for layer_idx in range(self.num_layers):
transformer_inputs = [
output_tensor, memory, cross_attention_mask, self_attention_mask,
input_pos_embed, memory_pos_embed
output_tensor, memory, cross_attention_mask, self_attention_mask
]
# Gets the cache for decoding.
if cache is None:
......
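And a matching sketch for the decoder stack, whose call no longer accepts input_pos_embed or memory_pos_embed. The import path and default hyperparameters are assumptions; the call keywords are taken from the diff.

import tensorflow as tf

# Assumed import path and default constructor hyperparameters; only the call
# keywords are taken from the diff (the pos_embed keywords no longer exist).
from official.nlp.modeling.models import seq2seq_transformer

decoder = seq2seq_transformer.TransformerDecoder()

batch, target_length, memory_length, hidden_size = 2, 12, 16, 512
target = tf.random.uniform((batch, target_length, hidden_size))
memory = tf.random.uniform((batch, memory_length, hidden_size))
self_attention_mask = tf.ones((batch, target_length, target_length))
cross_attention_mask = tf.ones((batch, target_length, memory_length))

decoder_outputs = decoder(
    target,
    memory,
    self_attention_mask=self_attention_mask,
    cross_attention_mask=cross_attention_mask)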