diff --git a/official/nlp/keras_nlp/layers/transformer_encoder_block.py b/official/nlp/keras_nlp/layers/transformer_encoder_block.py
index a518349878896abbc096d20c638b813d4ed0d16a..672bf080701956b15d9439621db7c73a85f0606f 100644
--- a/official/nlp/keras_nlp/layers/transformer_encoder_block.py
+++ b/official/nlp/keras_nlp/layers/transformer_encoder_block.py
@@ -116,9 +116,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     self._attention_initializer = self._kernel_initializer
     self._attention_axes = attention_axes
 
-  def _maybe_build(self, inputs):
-    super()._maybe_build(inputs[:1])
-
   def build(self, input_shape):
     if isinstance(input_shape, tf.TensorShape):
       input_tensor_shape = input_shape
@@ -250,9 +247,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         [`query tensor`, `key value tensor`, `attention mask`] to have separate
           input streams for the query, and key/value to the multi-head
           attention.
-        [`query tensor`, `key value tensor`, `attention mask`, `pos_embed`] to
-          have an additional pos_embed that is added to the query and key of
-          every self-attention layer.
 
     Returns:
       An output tensor with the same dimensions as input/query tensor.
@@ -261,18 +255,13 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       if len(inputs) == 2:
         input_tensor, attention_mask = inputs
         key_value = None
-        pos_embed = None
       elif len(inputs) == 3:
         input_tensor, key_value, attention_mask = inputs
-        pos_embed = None
-      elif len(inputs) == 4:
-        input_tensor, key_value, attention_mask, pos_embed = inputs
       else:
         raise ValueError("Unexpected inputs to %s with length at %d" %
                          (self.__class__, len(inputs)))
     else:
-      input_tensor, key_value, attention_mask, pos_embed = (inputs, None, None,
-                                                            None)
+      input_tensor, key_value, attention_mask = (inputs, None, None)
 
     if self._output_range:
       if self._norm_first:
@@ -293,14 +282,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
 
     if key_value is None:
       key_value = input_tensor
-    if pos_embed is None:
-      query = target_tensor
-      key = key_value
-    else:
-      query = target_tensor + pos_embed
-      key = key_value + pos_embed
     attention_output = self._attention_layer(
-        query=query, key=key, value=key_value, attention_mask=attention_mask)
+        query=target_tensor, value=key_value, attention_mask=attention_mask)
     attention_output = self._attention_dropout(attention_output)
     if self._norm_first:
       attention_output = source_tensor + attention_output
diff --git a/official/nlp/modeling/layers/transformer.py b/official/nlp/modeling/layers/transformer.py
index bed79ac773db68c940d8ec0dc2d54995eb096a38..05d8b60fbec1d3429e5c7cdfedb43b91cdb5a5cf 100644
--- a/official/nlp/modeling/layers/transformer.py
+++ b/official/nlp/modeling/layers/transformer.py
@@ -232,9 +232,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
     else:
       self._cross_attention_cls = attention.MultiHeadAttention
 
-  def _maybe_build(self, inputs):
-    super()._maybe_build(inputs[:1])
-
   def build(self, input_shape):
     target_tensor_shape = tf.TensorShape(input_shape[0])
     if len(target_tensor_shape.as_list()) != 3:
@@ -373,57 +370,22 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         self.intermediate_dense, self.output_dense, self.output_layer_norm
     ]
 
-  def _parse_inputs(self, inputs, multi_channel_cross_attention):
-    if multi_channel_cross_attention:
-      if len(inputs) < 5:
+  def call(self, inputs, cache=None, decode_loop_step=None):
+    if self.multi_channel_cross_attention:
+      if len(inputs) != 5:
         raise ValueError(
-            "TransformerDecoderBlock must have at least 5 inputs, when it uses "
+            "TransformerDecoderBlock must have 5 inputs, when it uses "
             "multi_channel_cross_attention. But it got: %d" % len(inputs))
-      elif len(inputs) == 5:
-        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights = inputs
-        input_pos_embed = None
-        memory_pos_embed = None
-      elif len(inputs) == 6:
-        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed = inputs
-        memory_pos_embed = None
-      else:
-        input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = inputs[:
                                                                                                                                           7]
-    else:
-      context_attention_weights = None
-      if len(inputs) < 4:
-        raise ValueError(
-            "TransformerDecoderBlock must have at leaset 4 inputs, but it "
-            "got: %d" % len(inputs))
-      elif len(inputs) == 4:
-        input_tensor, memory, attention_mask, self_attention_mask = inputs
-        input_pos_embed = None
-        memory_pos_embed = None
-      elif len(inputs) == 5:
-        input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed = inputs
-        memory_pos_embed = None
-      else:
-        input_tensor, memory, attention_mask, self_attention_mask, input_pos_embed, memory_pos_embed = inputs[:
                                                                                                               6]
-
-    return input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed
-
-  def call(self, inputs, cache=None, decode_loop_step=None):
-    input_tensor, memory, attention_mask, self_attention_mask, context_attention_weights, input_pos_embed, memory_pos_embed = self._parse_inputs(
-        inputs, self.multi_channel_cross_attention)
-
+    elif len(inputs) != 4:
+      raise ValueError(
+          "TransformerDecoderBlock must have 4 inputs, but it got: %d" %
+          len(inputs))
+    input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
     source_tensor = input_tensor
     if self._norm_first:
       input_tensor = self.self_attention_layer_norm(input_tensor)
-    if input_pos_embed is None:
-      self_attn_query = input_tensor
-      self_attn_key = input_tensor
-    else:
-      self_attn_query = input_tensor + input_pos_embed
-      self_attn_key = input_tensor + input_pos_embed
     self_attention_output, cache = self.self_attention(
-        query=self_attn_query,
-        key=self_attn_key,
+        query=input_tensor,
         value=input_tensor,
         attention_mask=self_attention_mask,
         cache=cache,
@@ -438,22 +400,13 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
       source_self_attention_output = self_attention_output
       self_attention_output = self.encdec_attention_layer_norm(
           self_attention_output)
-    if input_pos_embed is None:
-      cross_attn_query = self_attention_output
-    else:
-      cross_attn_query = self_attention_output + input_pos_embed
-    if memory_pos_embed is None:
-      cross_attn_key = memory
-    else:
-      cross_attn_key = memory + memory_pos_embed
     cross_attn_inputs = dict(
-        query=cross_attn_query,
-        key=cross_attn_key,
+        query=self_attention_output,
         value=memory,
         attention_mask=attention_mask)
     if self.multi_channel_cross_attention:
       # Accesses the 5-th input tensor for the doc-attention probabilities.
-      cross_attn_inputs["context_attention_weights"] = context_attention_weights
+      cross_attn_inputs["context_attention_weights"] = inputs[-1]
     attention_output = self.encdec_attention(**cross_attn_inputs)
     attention_output = self.encdec_attention_dropout(attention_output)
     if self._norm_first:
diff --git a/official/nlp/modeling/models/seq2seq_transformer.py b/official/nlp/modeling/models/seq2seq_transformer.py
index d0ddaf5af85ba151f59b636aa3d0243fdc5ef68e..1099c603ac70fbb837c78bf87881415aa4f4a558 100644
--- a/official/nlp/modeling/models/seq2seq_transformer.py
+++ b/official/nlp/modeling/models/seq2seq_transformer.py
@@ -425,7 +425,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
     base_config = super(TransformerEncoder, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
+  def call(self, encoder_inputs, attention_mask=None):
     """Return the output of the encoder.
 
     Args:
@@ -433,17 +433,14 @@ class TransformerEncoder(tf.keras.layers.Layer):
         hidden_size)`.
       attention_mask: A mask for the encoder self-attention layer with shape
         `(batch_size, input_length, input_length)`.
-      pos_embed: A tensor or a float that is added to the query and key of every
-        self-attention layer. Defaults to None.
 
     Returns:
       Output of encoder which is a `float32` tensor with shape
         `(batch_size, input_length, hidden_size)`.
     """
-
     for layer_idx in range(self.num_layers):
       encoder_inputs = self.encoder_layers[layer_idx](
-          [encoder_inputs, encoder_inputs, attention_mask, pos_embed])
+          [encoder_inputs, attention_mask])
 
     output_tensor = encoder_inputs
     output_tensor = self.output_normalization(output_tensor)
@@ -522,7 +519,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
               attention_initializer=attention_initializer(input_shape[2]),
               name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
-        epsilon=self._norm_epsilon, dtype="float32")
+        epsilon=1e-6, dtype="float32")
     super(TransformerDecoder, self).build(input_shape)
 
   def get_config(self):
@@ -548,9 +545,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
            cross_attention_mask=None,
            cache=None,
            decode_loop_step=None,
-           return_all_decoder_outputs=False,
-           input_pos_embed=None,
-           memory_pos_embed=None):
+           return_all_decoder_outputs=False):
     """Return the output of the decoder layer stacks.
 
     Args:
@@ -570,10 +565,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
       return_all_decoder_outputs: Return all decoder layer outputs. Note that
         the outputs are layer normed. This is useful when introducing per layer
         auxiliary loss.
-      input_pos_embed: A tensor or float that is added to the target embedding
-        in every self-attention and cross-attention layer. Defaults to None.
-      memory_pos_embed: A tensor or float that is added to the memory embedding
-        in every cross-attention layer. Defaults to None.
 
     Returns:
       Output of decoder.
@@ -584,8 +575,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     decoder_outputs = []
     for layer_idx in range(self.num_layers):
       transformer_inputs = [
-          output_tensor, memory, cross_attention_mask, self_attention_mask,
-          input_pos_embed, memory_pos_embed
+          output_tensor, memory, cross_attention_mask, self_attention_mask
       ]
      # Gets the cache for decoding.
       if cache is None:
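
Note on the calling convention after this revert: TransformerEncoderBlock again accepts a single tensor, [input tensor, attention mask], or [query tensor, key value tensor, attention mask], and TransformerDecoderBlock again expects exactly four inputs ([input_tensor, memory, cross_attention_mask, self_attention_mask]), or five when multi_channel_cross_attention is enabled; positional information must be added to the embeddings before the block is called rather than passed as a separate pos_embed input. A minimal sketch of the encoder-block usage, assuming the keras_nlp constructor arguments (num_attention_heads, inner_dim, inner_activation) of this vintage; shapes are illustrative only:

    import tensorflow as tf

    from official.nlp.keras_nlp.layers import transformer_encoder_block

    # Illustrative shapes only.
    batch_size, seq_length, hidden_size = 2, 8, 16

    block = transformer_encoder_block.TransformerEncoderBlock(
        num_attention_heads=4, inner_dim=32, inner_activation="relu")

    embeddings = tf.random.uniform((batch_size, seq_length, hidden_size))
    attention_mask = tf.ones((batch_size, seq_length, seq_length))

    # Positional information is folded into the embeddings up front instead of
    # being passed as a separate `pos_embed` input (pos_encoding is a stand-in
    # for whatever positional encoding the surrounding model computes).
    pos_encoding = tf.random.uniform((1, seq_length, hidden_size))

    # Two-input form: [input tensor, attention mask].
    outputs = block([embeddings + pos_encoding, attention_mask])
    assert outputs.shape == embeddings.shape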