Commit 8c6df641 authored by Allen Wang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 341642152
Parent: f858af81
......@@ -695,40 +695,34 @@ def postprocess_output(all_examples,
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
start_indexes_and_logits = _get_best_indexes_and_logits(
result=result,
n_best_size=n_best_size,
start=True,
xlnet_format=xlnet_format)
end_indexes_and_logits = _get_best_indexes_and_logits(
result=result,
n_best_size=n_best_size,
start=False,
xlnet_format=xlnet_format)
doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
for start_index, start_logit in start_indexes_and_logits:
for end_index, end_logit in end_indexes_and_logits:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
continue
if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index - doc_offset,
end_index=end_index - doc_offset,
start_logit=start_logit,
end_logit=end_logit))
for (start_index, start_logit,
end_index, end_logit) in _get_best_indexes_and_logits(
result=result,
n_best_size=n_best_size,
xlnet_format=xlnet_format):
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
continue
if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index - doc_offset,
end_index=end_index - doc_offset,
start_logit=start_logit,
end_logit=end_logit))
if version_2_with_negative and not xlnet_format:
prelim_predictions.append(
......@@ -752,7 +746,7 @@ def postprocess_output(all_examples,
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index >= 0: # this is a non-null prediction
if pred.start_index >= 0 or xlnet_format: # this is a non-null prediction
tok_start_to_orig_index = feature.tok_start_to_orig_index
tok_end_to_orig_index = feature.tok_end_to_orig_index
start_orig_pos = tok_start_to_orig_index[pred.start_index]
......@@ -774,7 +768,7 @@ def postprocess_output(all_examples,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it
# if we didn't include the empty option in the n-best, include it
if version_2_with_negative and not xlnet_format:
if "" not in seen_predictions:
nbest.append(
......@@ -814,14 +808,19 @@ def postprocess_output(all_examples,
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
assert best_non_null_entry is not None
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
if xlnet_format:
score_diff = score_null
scores_diff_json[example.qas_id] = score_diff
all_predictions[example.qas_id] = best_non_null_entry.text
else:
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
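A minimal standalone sketch of the answer/no-answer decision changed above. The helper and its toy inputs are illustrative only and are not part of the repository code.

```python
# Sketch of the SQuAD 2.0 no-answer decision; all values are made up.
def pick_answer(score_null, best_text, best_start_logit, best_end_logit,
                null_score_diff_threshold, xlnet_format):
  if xlnet_format:
    # XLNet exports the null score directly and keeps the text of the best
    # non-null span; answerability is judged downstream from that score.
    return best_text, score_null
  # BERT-style: predict "" iff the null score minus the best non-null span
  # score exceeds the threshold.
  score_diff = score_null - best_start_logit - best_end_logit
  prediction = '' if score_diff > null_score_diff_threshold else best_text
  return prediction, score_diff


print(pick_answer(1.0, 'a span', 0.4, 0.3, 0.0, xlnet_format=False))
# ('', ~0.3): the null score wins under a 0.0 threshold.
```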
......@@ -835,28 +834,27 @@ def write_to_json_files(json_records, json_file):
def _get_best_indexes_and_logits(result,
n_best_size,
start=False,
xlnet_format=False):
"""Generates the n-best indexes and logits from a list."""
if xlnet_format:
for i in range(n_best_size):
for j in range(n_best_size):
j_index = i * n_best_size + j
if start:
yield result.start_indexes[i], result.start_logits[i]
else:
yield result.end_indexes[j_index], result.end_logits[j_index]
yield (result.start_indexes[i], result.start_logits[i],
result.end_indexes[j_index], result.end_logits[j_index])
else:
if start:
logits = result.start_logits
else:
logits = result.end_logits
index_and_score = sorted(enumerate(logits),
key=lambda x: x[1], reverse=True)
for i in range(len(index_and_score)):
start_index_and_score = sorted(enumerate(result.start_logits),
key=lambda x: x[1], reverse=True)
end_index_and_score = sorted(enumerate(result.end_logits),
key=lambda x: x[1], reverse=True)
for i in range(len(start_index_and_score)):
if i >= n_best_size:
break
yield index_and_score[i]
for j in range(len(end_index_and_score)):
if j >= n_best_size:
break
yield (start_index_and_score[i][0], start_index_and_score[i][1],
end_index_and_score[j][0], end_index_and_score[j][1])
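For readers skimming the diff, here is a self-contained sketch of the refactored generator's behaviour under both formats. The `Result` container and its field names are assumptions for illustration only, not the module's types.

```python
# Minimal sketch of the joint (start, end) candidate generator above.
import collections

Result = collections.namedtuple(
    'Result', ['start_indexes', 'start_logits', 'end_indexes', 'end_logits'])


def best_indexes_and_logits(result, n_best_size, xlnet_format=False):
  """Yields (start_index, start_logit, end_index, end_logit) tuples."""
  if xlnet_format:
    # XLNet results already carry the top-k indexes; end candidates are
    # stored per start candidate, flattened into n_best_size * n_best_size.
    for i in range(n_best_size):
      for j in range(n_best_size):
        j_index = i * n_best_size + j
        yield (result.start_indexes[i], result.start_logits[i],
               result.end_indexes[j_index], result.end_logits[j_index])
  else:
    # BERT-style results carry per-token logits; sort them and pair the
    # top-n starts with the top-n ends.
    starts = sorted(enumerate(result.start_logits),
                    key=lambda x: x[1], reverse=True)[:n_best_size]
    ends = sorted(enumerate(result.end_logits),
                  key=lambda x: x[1], reverse=True)[:n_best_size]
    for start_index, start_logit in starts:
      for end_index, end_logit in ends:
        yield start_index, start_logit, end_index, end_logit


# Example: toy logits with n_best_size=2 in the XLNet layout.
toy = Result(start_indexes=[1, 0], start_logits=[2.0, 1.0],
             end_indexes=[2, 1, 2, 1], end_logits=[1.5, 0.5, 1.2, 0.3])
print(list(best_indexes_and_logits(toy, n_best_size=2, xlnet_format=True)))
```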
def _compute_softmax(scores):
......@@ -885,13 +883,12 @@ def _compute_softmax(scores):
class FeatureWriter(object):
"""Writes InputFeature to TF example file."""
def __init__(self, filename, is_training, xlnet_format=False):
def __init__(self, filename, is_training):
self.filename = filename
self.is_training = is_training
self.num_features = 0
tf.io.gfile.makedirs(os.path.dirname(filename))
self._writer = tf.io.TFRecordWriter(filename)
self._xlnet_format = xlnet_format
def process_feature(self, feature):
"""Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
......@@ -907,8 +904,9 @@ class FeatureWriter(object):
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
if self._xlnet_format:
if feature.paragraph_mask:
features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
if feature.class_index:
features["class_index"] = create_int_feature([feature.class_index])
if self.is_training:
......@@ -943,7 +941,7 @@ def generate_tf_record_from_json_file(input_file_path,
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=sp_model_file)
train_writer = FeatureWriter(
filename=output_path, is_training=True, xlnet_format=xlnet_format)
filename=output_path, is_training=True)
number_of_examples = convert_examples_to_features(
examples=train_examples,
tokenizer=tokenizer,
......
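Before the modeling changes, a rough sketch of the serialization path touched in `FeatureWriter` above: `paragraph_mask` and `class_index` are only written when the feature carries them, so one writer now serves both BERT-style and XLNet-style records. The dict-based `feature`, its values, and the local copy of the `create_int_feature` helper are invented for illustration.

```python
# Illustrative stand-in for the conditional feature serialization above.
import tensorflow as tf


def create_int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))


def to_example(feature):
  features = {
      'input_ids': create_int_feature(feature['input_ids']),
      'input_mask': create_int_feature(feature['input_mask']),
      'segment_ids': create_int_feature(feature['segment_ids']),
  }
  # paragraph_mask / class_index are only written when present.
  if feature.get('paragraph_mask'):
    features['paragraph_mask'] = create_int_feature(feature['paragraph_mask'])
  if feature.get('class_index'):
    features['class_index'] = create_int_feature([feature['class_index']])
  return tf.train.Example(features=tf.train.Features(feature=features))


example = to_example({'input_ids': [1, 2], 'input_mask': [1, 1],
                      'segment_ids': [0, 0], 'paragraph_mask': [0, 1],
                      'class_index': 1})
print(example.features.feature.keys())
```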
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet cls-token classifier."""
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Mapping, Union
......@@ -127,7 +127,7 @@ class XLNetSpanLabeler(tf.keras.Model):
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
dropout_rate: The dropout rate for the span labeling layer.
span_labeling_activation
span_labeling_activation: The activation for the span labeling head.
initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer.
"""
......@@ -135,9 +135,9 @@ class XLNetSpanLabeler(tf.keras.Model):
def __init__(
self,
network: Union[tf.keras.layers.Layer, tf.keras.Model],
start_n_top: int,
end_n_top: int,
dropout_rate: float,
start_n_top: int = 5,
end_n_top: int = 5,
dropout_rate: float = 0.1,
span_labeling_activation: tf.keras.initializers.Initializer = 'tanh',
initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
**kwargs):
......@@ -165,24 +165,27 @@ class XLNetSpanLabeler(tf.keras.Model):
initializer=self._initializer)
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_ids']
segment_ids = inputs['segment_ids']
input_ids = inputs['input_word_ids']
segment_ids = inputs['input_type_ids']
input_mask = inputs['input_mask']
class_index = inputs['class_index']
paragraph_mask = inputs['paragraph_mask']
start_positions = inputs.get('start_positions', None)
class_index = tf.reshape(inputs['class_index'], [-1])
position_mask = inputs['position_mask']
start_positions = inputs['start_positions']
attention_output, new_states = self._network(
attention_output, _ = self._network(
input_ids=input_ids,
segment_ids=segment_ids,
input_mask=input_mask)
outputs = self.span_labeling(
sequence_data=attention_output,
class_index=class_index,
position_mask=position_mask,
paragraph_mask=paragraph_mask,
start_positions=start_positions)
return outputs, new_states
return outputs
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
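A small sketch of the feature dictionary the renamed inputs above expect (`input_word_ids` / `input_type_ids` instead of `input_ids` / `segment_ids`). Shapes follow the test below; the values are placeholders.

```python
# Placeholder feature dictionary for XLNetSpanLabeler after the key renames.
import numpy as np

batch_size, seq_length = 2, 4
inputs = {
    'input_word_ids': np.random.randint(10, size=(batch_size, seq_length)),
    'input_type_ids': np.random.randint(2, size=(batch_size, seq_length)),
    'input_mask': np.ones((batch_size, seq_length), dtype='float32'),
    'paragraph_mask': np.ones((batch_size, seq_length), dtype='float32'),
    'class_index': np.zeros((batch_size,), dtype='int32'),
    # 'start_positions' is optional: when it is absent (or at inference time)
    # the span-labeling head switches to beam search over start candidates.
    'start_positions': np.zeros((batch_size,), dtype='int32'),
}
print({key: value.shape for key, value in inputs.items()})
```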
......
......@@ -137,9 +137,9 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes
class XLNetSpanLabelerTest(keras_parameterized.TestCase):
@parameterized.parameters(1, 2)
def test_xlnet_trainer(self, top_n):
def test_xlnet_trainer(self):
"""Validate that the Keras object can be created."""
top_n = 2
seq_length = 4
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base = _get_xlnet_base()
......@@ -153,46 +153,50 @@ class XLNetSpanLabelerTest(keras_parameterized.TestCase):
span_labeling_activation='tanh',
dropout_rate=0.1)
inputs = dict(
input_ids=tf.keras.layers.Input(
input_word_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
segment_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
input_type_ids=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
input_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='input_mask'),
position_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='position_mask'),
paragraph_mask=tf.keras.layers.Input(
shape=(seq_length,), dtype=tf.float32, name='paragraph_mask'),
class_index=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='class_index'),
start_positions=tf.keras.layers.Input(
shape=(), dtype=tf.int32, name='start_positions'))
outputs, _ = xlnet_trainer_model(inputs)
outputs = xlnet_trainer_model(inputs)
self.assertIsInstance(outputs, dict)
# Test tensor value calls for the created model.
batch_size = 2
sequence_shape = (batch_size, seq_length)
inputs = dict(
input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_word_ids=np.random.randint(
10, size=sequence_shape, dtype='int32'),
input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
position_mask=np.random.randint(
paragraph_mask=np.random.randint(
1, size=(sequence_shape)).astype('float32'),
class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
start_positions=tf.random.uniform(
shape=(batch_size,), maxval=5, dtype=tf.int32))
outputs, _ = xlnet_trainer_model(inputs)
expected_inference_keys = {
'start_top_log_probs', 'end_top_log_probs', 'class_logits',
'start_top_index', 'end_top_index',
common_keys = {
'start_logits', 'end_logits', 'start_predictions', 'end_predictions',
'class_logits',
}
self.assertSetEqual(expected_inference_keys, set(outputs.keys()))
inference_keys = {
'start_top_predictions', 'end_top_predictions', 'start_top_index',
'end_top_index',
}
outputs = xlnet_trainer_model(inputs)
self.assertSetEqual(common_keys | inference_keys, set(outputs.keys()))
outputs, _ = xlnet_trainer_model(inputs, training=True)
outputs = xlnet_trainer_model(inputs, training=True)
self.assertIsInstance(outputs, dict)
expected_train_keys = {
'start_log_probs', 'end_log_probs', 'class_logits'
}
self.assertSetEqual(expected_train_keys, set(outputs.keys()))
self.assertSetEqual(common_keys, set(outputs.keys()))
self.assertIsInstance(outputs, dict)
def test_serialize_deserialize(self):
......
......@@ -18,11 +18,9 @@ import collections
import tensorflow as tf
def _apply_position_mask(logits, position_mask):
def _apply_paragraph_mask(logits, paragraph_mask):
"""Applies a position mask to calculated logits."""
if tf.rank(logits) != tf.rank(position_mask):
position_mask = position_mask[:, None, :]
masked_logits = logits * (1 - position_mask) - 1e30 * position_mask
masked_logits = logits * (paragraph_mask) - 1e30 * (1 - paragraph_mask)
return tf.nn.log_softmax(masked_logits, -1), masked_logits
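A toy check of the new masking convention (a minimal sketch; the tensors are made up): `paragraph_mask` is 1 on positions that may hold the answer and 0 elsewhere, and masked positions receive a large negative logit so the softmax assigns them roughly zero probability.

```python
# Toy demonstration of _apply_paragraph_mask's masking arithmetic.
import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 3.0, 4.0]])
paragraph_mask = tf.constant([[0.0, 1.0, 1.0, 0.0]])  # only middle two valid

masked_logits = logits * paragraph_mask - 1e30 * (1 - paragraph_mask)
log_probs = tf.nn.log_softmax(masked_logits, -1)
print(tf.exp(log_probs).numpy())  # ~[0., 0.27, 0.73, 0.]
```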
......@@ -137,8 +135,8 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
def __init__(self,
input_width,
start_n_top,
end_n_top,
start_n_top=5,
end_n_top=5,
activation='tanh',
dropout_rate=0.,
initializer='glorot_uniform',
......@@ -152,6 +150,8 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
'end_n_top': end_n_top,
'dropout_rate': dropout_rate,
}
if start_n_top <= 1:
raise ValueError('`start_n_top` must be greater than 1.')
self._start_n_top = start_n_top
self._end_n_top = end_n_top
self.start_logits_dense = tf.keras.layers.Dense(
......@@ -210,16 +210,12 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
end_logits = self.end_logits_layer_norm(end_logits)
end_logits = self.end_logits_output_dense(end_logits)
end_logits = tf.squeeze(end_logits)
if tf.rank(end_logits) > 2:
# shape = [B, S, K] -> [B, K, S]
end_logits = tf.transpose(end_logits, [0, 2, 1])
return end_logits
def call(self,
sequence_data,
class_index,
position_mask=None,
paragraph_mask=None,
start_positions=None,
training=False):
"""Implements call().
......@@ -234,31 +230,35 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
sequence_data: The input sequence data of shape
(batch_size, seq_length, input_width).
class_index: The class indices of the inputs of shape (batch_size,).
position_mask: Invalid position mask such as query and special symbols
paragraph_mask: Mask of valid paragraph positions, i.e. 1 for tokens that
may contain the answer and 0 for the query and special symbols
(e.g. PAD, SEP, CLS), of shape (batch_size, seq_length).
start_positions: The start positions of each example of shape
(batch_size,).
training: Whether or not this is the training phase.
Returns:
A dictionary with the keys 'cls_logits' and
- (if training) 'start_log_probs', 'end_log_probs'.
- (if inference/beam search) 'start_top_log_probs', 'start_top_index',
'end_top_log_probs', 'end_top_index'.
A dictionary with the keys 'start_predictions', 'end_predictions',
'start_logits', 'end_logits'.
If inference, then 'start_top_predictions', 'start_top_index',
'end_top_predictions', 'end_top_index' are also included.
"""
paragraph_mask = tf.cast(paragraph_mask, dtype=sequence_data.dtype)
class_index = tf.reshape(class_index, [-1])
seq_length = tf.shape(sequence_data)[1]
start_logits = self.start_logits_dense(sequence_data)
start_logits = tf.squeeze(start_logits, -1)
start_log_probs, masked_start_logits = _apply_position_mask(
start_logits, position_mask)
start_predictions, masked_start_logits = _apply_paragraph_mask(
start_logits, paragraph_mask)
compute_with_beam_search = not training or start_positions is None
if compute_with_beam_search:
# Compute end logits using beam search.
start_top_log_probs, start_top_index = tf.nn.top_k(
start_log_probs, k=self._start_n_top)
start_top_predictions, start_top_index = tf.nn.top_k(
start_predictions, k=self._start_n_top)
start_index = tf.one_hot(
start_top_index, depth=seq_length, axis=-1, dtype=tf.float32)
# start_index: [batch_size, end_n_top, seq_length]
......@@ -272,8 +272,13 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
[1, 1, self._start_n_top, 1])
end_input = tf.concat([end_input, start_features], axis=-1)
# end_input: [batch_size, seq_length, end_n_top, 2*input_width]
paragraph_mask = paragraph_mask[:, None, :]
end_logits = self.end_logits(end_input)
# Note: this will fail if start_n_top is not >= 1.
end_logits = tf.transpose(end_logits, [0, 2, 1])
else:
start_positions = tf.reshape(start_positions, -1)
start_positions = tf.reshape(start_positions, [-1])
start_index = tf.one_hot(
start_positions, depth=seq_length, axis=-1, dtype=tf.float32)
# start_index: [batch_size, seq_length]
......@@ -285,24 +290,28 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
end_input = tf.concat([sequence_data, start_features],
axis=-1)
# end_input: [batch_size, seq_length, 2*input_width]
end_logits = self.end_logits(end_input)
end_log_probs, _ = _apply_position_mask(end_logits, position_mask)
output_dict = {}
if training:
output_dict['start_log_probs'] = start_log_probs
output_dict['end_log_probs'] = end_log_probs
else:
end_top_log_probs, end_top_index = tf.nn.top_k(
end_log_probs, k=self._end_n_top)
end_top_log_probs = tf.reshape(end_top_log_probs,
[-1, self._start_n_top * self._end_n_top])
end_top_index = tf.reshape(end_top_index,
[-1, self._start_n_top * self._end_n_top])
output_dict['start_top_log_probs'] = start_top_log_probs
end_logits = self.end_logits(end_input)
end_predictions, masked_end_logits = _apply_paragraph_mask(
end_logits, paragraph_mask)
output_dict = dict(
start_predictions=start_predictions,
end_predictions=end_predictions,
start_logits=masked_start_logits,
end_logits=masked_end_logits)
if not training:
end_top_predictions, end_top_index = tf.nn.top_k(
end_predictions, k=self._end_n_top)
end_top_predictions = tf.reshape(
end_top_predictions,
[-1, self._start_n_top * self._end_n_top])
end_top_index = tf.reshape(
end_top_index,
[-1, self._start_n_top * self._end_n_top])
output_dict['start_top_predictions'] = start_top_predictions
output_dict['start_top_index'] = start_top_index
output_dict['end_top_log_probs'] = end_top_log_probs
output_dict['end_top_predictions'] = end_top_predictions
output_dict['end_top_index'] = end_top_index
# get the representation of CLS
......
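Before the test changes, a standalone sketch of the beam-search bookkeeping used in the inference branch above: per-start end scores of shape [batch, start_n_top, seq_length] are flattened to [batch, start_n_top * end_n_top]. The shapes and values here are illustrative only.

```python
# Illustrative beam-search bookkeeping, mirroring the top_k/reshape above.
import tensorflow as tf

batch, seq_len, start_n_top, end_n_top = 2, 8, 2, 3
start_scores = tf.random.uniform((batch, seq_len))
start_top, start_top_index = tf.nn.top_k(start_scores, k=start_n_top)

# One row of end scores per selected start candidate.
end_scores = tf.random.uniform((batch, start_n_top, seq_len))
end_top, end_top_index = tf.nn.top_k(end_scores, k=end_n_top)
end_top = tf.reshape(end_top, [-1, start_n_top * end_n_top])
end_top_index = tf.reshape(end_top_index, [-1, start_n_top * end_n_top])

print(start_top.shape, end_top.shape)  # (2, 2) (2, 6)
```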
......@@ -13,13 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Tests for span_labeling network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
......@@ -181,39 +174,38 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
hidden_size = 4
sequence_data = np.random.uniform(
size=(batch_size, seq_length, hidden_size)).astype('float32')
position_mask = np.random.uniform(
paragraph_mask = np.random.uniform(
size=(batch_size, seq_length)).astype('float32')
class_index = np.random.uniform(size=(batch_size)).astype('uint8')
start_positions = np.zeros(shape=(batch_size)).astype('uint8')
layer = span_labeling.XLNetSpanLabeling(
input_width=hidden_size,
start_n_top=1,
end_n_top=1,
start_n_top=2,
end_n_top=2,
activation='tanh',
dropout_rate=0.,
initializer='glorot_uniform')
output = layer(sequence_data=sequence_data,
class_index=class_index,
position_mask=position_mask,
paragraph_mask=paragraph_mask,
start_positions=start_positions,
training=True)
expected_keys = {
'start_log_probs', 'end_log_probs', 'class_logits',
'start_logits', 'end_logits', 'class_logits', 'start_predictions',
'end_predictions',
}
self.assertSetEqual(expected_keys, set(output.keys()))
@parameterized.named_parameters(
('top_1', 1),
('top_n', 5))
def test_basic_invocation_beam_search(self, top_n):
def test_basic_invocation_beam_search(self):
batch_size = 2
seq_length = 8
hidden_size = 4
top_n = 5
sequence_data = np.random.uniform(
size=(batch_size, seq_length, hidden_size)).astype('float32')
position_mask = np.random.uniform(
paragraph_mask = np.random.uniform(
size=(batch_size, seq_length)).astype('float32')
class_index = np.random.uniform(size=(batch_size)).astype('uint8')
......@@ -226,11 +218,12 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
initializer='glorot_uniform')
output = layer(sequence_data=sequence_data,
class_index=class_index,
position_mask=position_mask,
paragraph_mask=paragraph_mask,
training=False)
expected_keys = {
'start_top_log_probs', 'end_top_log_probs', 'class_logits',
'start_top_index', 'end_top_index',
'start_top_predictions', 'end_top_predictions', 'class_logits',
'start_top_index', 'end_top_index', 'start_logits',
'end_logits', 'start_predictions', 'end_predictions'
}
self.assertSetEqual(expected_keys, set(output.keys()))
......@@ -243,7 +236,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
sequence_data = tf.keras.Input(shape=(seq_length, hidden_size),
dtype=tf.float32)
class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
position_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
paragraph_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
start_positions = tf.keras.Input(shape=(), dtype=tf.int32)
layer = span_labeling.XLNetSpanLabeling(
......@@ -256,27 +249,27 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
output = layer(sequence_data=sequence_data,
class_index=class_index,
position_mask=position_mask,
paragraph_mask=paragraph_mask,
start_positions=start_positions)
model = tf.keras.Model(
inputs={
'sequence_data': sequence_data,
'class_index': class_index,
'position_mask': position_mask,
'paragraph_mask': paragraph_mask,
'start_positions': start_positions,
},
outputs=output)
sequence_data = tf.random.uniform(
shape=(batch_size, seq_length, hidden_size), dtype=tf.float32)
position_mask = tf.random.uniform(
paragraph_mask = tf.random.uniform(
shape=(batch_size, seq_length), dtype=tf.float32)
class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
start_positions = tf.random.uniform(
shape=(batch_size,), maxval=5, dtype=tf.int32)
inputs = dict(sequence_data=sequence_data,
position_mask=position_mask,
paragraph_mask=paragraph_mask,
class_index=class_index,
start_positions=start_positions)
......
......@@ -629,6 +629,7 @@ class XLNetBase(tf.keras.layers.Layer):
"enabled. Please enable `two_stream` to enable two "
"stream attention.")
dtype = input_mask.dtype if input_mask is not None else tf.float32
query_attention_mask, content_attention_mask = _compute_attention_mask(
input_mask=input_mask,
permutation_mask=permutation_mask,
......@@ -636,7 +637,7 @@ class XLNetBase(tf.keras.layers.Layer):
seq_length=seq_length,
memory_length=memory_length,
batch_size=batch_size,
dtype=tf.float32)
dtype=dtype)
relative_position_encoding = _compute_positional_encoding(
attention_type=self._attention_type,
position_encoding_layer=self.position_encoding,
......
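A tiny sketch of what the dtype plumbing above buys (illustrative only; the internals of `_compute_attention_mask` are not reproduced here): the attention masks now inherit the input mask's dtype rather than being hard-coded to float32, which matters when the input mask arrives in a lower precision.

```python
# Illustrative only: the attention mask now follows the input mask's dtype.
import tensorflow as tf

input_mask = tf.ones((2, 4), dtype=tf.float16)   # e.g. mixed-precision input
dtype = input_mask.dtype if input_mask is not None else tf.float32
attention_mask = tf.cast(input_mask[:, None, None, :], dtype=dtype)
print(attention_mask.dtype)  # float16 instead of a hard-coded float32
```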
......@@ -14,9 +14,9 @@
# limitations under the License.
# ==============================================================================
"""Question answering task."""
import collections
import json
import os
from typing import List, Optional
from absl import logging
import dataclasses
......@@ -58,6 +58,17 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
validation_data: cfg.DataConfig = cfg.DataConfig()
@dataclasses.dataclass
class RawAggregatedResult:
"""Raw representation for SQuAD predictions."""
unique_id: int
start_logits: List[float]
end_logits: List[float]
start_indexes: Optional[List[int]] = None
end_indexes: Optional[List[int]] = None
class_logits: Optional[float] = None
@task_factory.register_task_cls(QuestionAnsweringConfig)
class QuestionAnsweringTask(base_task.Task):
"""Task object for question answering."""
......@@ -91,7 +102,6 @@ class QuestionAnsweringTask(base_task.Task):
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
# Currently, we only supports bert-style question answering finetuning.
return models.BertSpanLabeler(
network=encoder_network,
initializer=tf.keras.initializers.TruncatedNormal(
......@@ -147,6 +157,7 @@ class QuestionAnsweringTask(base_task.Task):
kwargs['do_lower_case'] = params.do_lower_case
kwargs['tokenizer'] = tokenization.FullSentencePieceTokenizer(
sp_model_file=params.vocab_file)
kwargs['xlnet_format'] = self.task_config.model.encoder.type == 'xlnet'
elif params.tokenization == 'WordPiece':
kwargs['tokenizer'] = tokenization.FullTokenizer(
vocab_file=params.vocab_file, do_lower_case=params.do_lower_case)
......@@ -176,7 +187,8 @@ class QuestionAnsweringTask(base_task.Task):
input_type_ids=dummy_ids)
y = dict(
start_positions=tf.constant(0, dtype=tf.int32),
end_positions=tf.constant(1, dtype=tf.int32))
end_positions=tf.constant(1, dtype=tf.int32),
is_impossible=tf.constant(0, dtype=tf.int32))
return (x, y)
dataset = tf.data.Dataset.range(1)
......@@ -235,25 +247,22 @@ class QuestionAnsweringTask(base_task.Task):
}
return logs
raw_aggregated_result = collections.namedtuple(
'RawResult', ['unique_id', 'start_logits', 'end_logits'])
def aggregate_logs(self, state=None, step_outputs=None):
assert step_outputs is not None, 'Got no logs from self.validation_step.'
if state is None:
state = []
for unique_ids, start_logits, end_logits in zip(
step_outputs['unique_ids'], step_outputs['start_logits'],
step_outputs['end_logits']):
u_ids, s_logits, e_logits = (unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy())
for values in zip(u_ids, s_logits, e_logits):
state.append(
self.raw_aggregated_result(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist()))
for outputs in zip(step_outputs['unique_ids'],
step_outputs['start_logits'],
step_outputs['end_logits']):
numpy_values = [
output.numpy() for output in outputs if output is not None]
for values in zip(*numpy_values):
state.append(RawAggregatedResult(
unique_id=values[0],
start_logits=values[1],
end_logits=values[2]))
return state
def reduce_aggregated_logs(self, aggregated_logs):
......@@ -299,6 +308,127 @@ class QuestionAnsweringTask(base_task.Task):
return eval_metrics
@dataclasses.dataclass
class XLNetQuestionAnsweringConfig(QuestionAnsweringConfig):
"""The config for the XLNet variation of QuestionAnswering."""
pass
@task_factory.register_task_cls(XLNetQuestionAnsweringConfig)
class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
"""XLNet variant of the Question Answering Task.
The main differences include:
- The encoder is an `XLNetBase` class.
- The `SpanLabeling` head is an instance of `XLNetSpanLabeling` which
predicts start/end positions and impossibility score. During inference,
it predicts the top N scores and indexes.
"""
def build_model(self):
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
encoder_network = utils.get_encoder_from_hub(
self.task_config.hub_module_url)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
return models.XLNetSpanLabeler(
network=encoder_network,
start_n_top=self.task_config.n_best_size,
end_n_top=self.task_config.n_best_size,
initializer=tf.keras.initializers.RandomNormal(
stddev=encoder_cfg.initializer_range))
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
start_positions = labels['start_positions']
end_positions = labels['end_positions']
is_impossible = labels['is_impossible']
is_impossible = tf.cast(tf.reshape(is_impossible, [-1]), tf.float32)
start_logits = model_outputs['start_logits']
end_logits = model_outputs['end_logits']
class_logits = model_outputs['class_logits']
start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
start_positions, start_logits)
end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
end_positions, end_logits)
is_impossible_loss = tf.keras.losses.binary_crossentropy(
is_impossible, class_logits, from_logits=True)
loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
loss += tf.reduce_mean(is_impossible_loss) / 2
return loss
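A numeric sketch of the loss combination in `build_losses` above, on tiny made-up tensors: the start and end cross-entropies are averaged, then half of the answerability (binary cross-entropy) loss is added.

```python
# Toy reproduction of the XLNet QA loss arithmetic; all values are made up.
import tensorflow as tf

start_logits = tf.constant([[2.0, 0.5, 0.1]])
end_logits = tf.constant([[0.1, 2.0, 0.5]])
start_positions = tf.constant([0])
end_positions = tf.constant([1])
is_impossible = tf.constant([0.0])
class_logits = tf.constant([-1.0])  # logit for "no answer"

start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=start_positions, logits=start_logits)
end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=end_positions, logits=end_logits)
is_impossible_loss = tf.keras.losses.binary_crossentropy(
    is_impossible, class_logits, from_logits=True)

loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
loss += tf.reduce_mean(is_impossible_loss) / 2
print(float(loss))
```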
def process_metrics(self, metrics, labels, model_outputs):
metrics = dict([(metric.name, metric) for metric in metrics])
start_logits = model_outputs['start_logits']
end_logits = model_outputs['end_logits']
metrics['start_position_accuracy'].update_state(labels['start_positions'],
start_logits)
metrics['end_position_accuracy'].update_state(labels['end_positions'],
end_logits)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
start_logits = model_outputs['start_logits']
end_logits = model_outputs['end_logits']
compiled_metrics.update_state(
y_true=labels, # labels has keys 'start_positions' and 'end_positions'.
y_pred={
'start_positions': start_logits,
'end_positions': end_logits,
})
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
features, _ = inputs
unique_ids = features.pop('unique_ids')
model_outputs = self.inference_step(features, model)
start_top_predictions = model_outputs['start_top_predictions']
end_top_predictions = model_outputs['end_top_predictions']
start_indexes = model_outputs['start_top_index']
end_indexes = model_outputs['end_top_index']
class_logits = model_outputs['class_logits']
logs = {
self.loss: 0.0, # TODO(lehou): compute the real validation loss.
'unique_ids': unique_ids,
'start_top_predictions': start_top_predictions,
'end_top_predictions': end_top_predictions,
'start_indexes': start_indexes,
'end_indexes': end_indexes,
'class_logits': class_logits,
}
return logs
def aggregate_logs(self, state=None, step_outputs=None):
assert step_outputs is not None, 'Got no logs from self.validation_step.'
if state is None:
state = []
for outputs in zip(step_outputs['unique_ids'],
step_outputs['start_top_predictions'],
step_outputs['end_top_predictions'],
step_outputs['start_indexes'],
step_outputs['end_indexes'],
step_outputs['class_logits']):
numpy_values = [
output.numpy() for output in outputs]
for (unique_id, start_top_predictions, end_top_predictions, start_indexes,
end_indexes, class_logits) in zip(*numpy_values):
state.append(RawAggregatedResult(
unique_id=unique_id,
start_logits=start_top_predictions.tolist(),
end_logits=end_top_predictions.tolist(),
start_indexes=start_indexes.tolist(),
end_indexes=end_indexes.tolist(),
class_logits=class_logits))
return state
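A compact sketch of the aggregation pattern used in both `aggregate_logs` implementations above: zip the per-batch outputs, convert to numpy, then zip again across examples and collect one record per example. The arrays below are fabricated stand-ins for a single validation step, and plain dicts stand in for `RawAggregatedResult`.

```python
# Fabricated stand-in for one validation step's outputs (batch of 2, top-2).
import numpy as np

step_outputs = {
    'unique_ids': [np.array([11, 12])],
    'start_top_predictions': [np.random.uniform(size=(2, 2))],
    'end_top_predictions': [np.random.uniform(size=(2, 4))],
    'start_indexes': [np.random.randint(8, size=(2, 2))],
    'end_indexes': [np.random.randint(8, size=(2, 4))],
    'class_logits': [np.random.uniform(size=(2,))],
}

state = []
for outputs in zip(step_outputs['unique_ids'],
                   step_outputs['start_top_predictions'],
                   step_outputs['end_top_predictions'],
                   step_outputs['start_indexes'],
                   step_outputs['end_indexes'],
                   step_outputs['class_logits']):
  numpy_values = list(outputs)  # already numpy arrays in this toy example
  for (unique_id, start_preds, end_preds, start_idx,
       end_idx, class_logit) in zip(*numpy_values):
    state.append(dict(unique_id=int(unique_id),
                      start_logits=start_preds.tolist(),
                      end_logits=end_preds.tolist(),
                      start_indexes=start_idx.tolist(),
                      end_indexes=end_idx.tolist(),
                      class_logits=float(class_logit)))
print(len(state))  # one aggregated record per example in the batch
```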
def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
model: tf.keras.Model):
"""Predicts on the input data.
......