Commit e93afea8 authored by Chen Chen, committed by A. Unique TensorFlower

Support running prediction on the question answering (SQuAD) task.

PiperOrigin-RevId: 324703765
Parent 7ebdee5f
@@ -17,8 +17,10 @@
import collections
import json
import os
from absl import logging
import dataclasses
import orbit
import tensorflow as tf
import tensorflow_hub as hub
@@ -84,6 +86,10 @@ class QuestionAnsweringTask(base_task.Task):
self._tf_record_input_path, self._eval_examples, self._eval_features = (
self._preprocess_eval_data(params.validation_data))
def set_preprocessed_eval_input_path(self, eval_input_path):
"""Sets the path to the preprocessed eval data."""
self._tf_record_input_path = eval_input_path
def build_model(self):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
@@ -242,10 +248,6 @@ class QuestionAnsweringTask(base_task.Task):
step_outputs['end_logits']):
u_ids, s_logits, e_logits = (
unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
if u_ids.size == 1:
u_ids = [u_ids]
s_logits = [s_logits]
e_logits = [e_logits]
for values in zip(u_ids, s_logits, e_logits):
state.append(self.raw_aggregated_result(
unique_id=values[0],
@@ -291,3 +293,46 @@ class QuestionAnsweringTask(base_task.Task):
eval_metrics = {'exact_match': eval_metrics['exact_match'],
'final_f1': eval_metrics['final_f1']}
return eval_metrics
def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
model: tf.keras.Model):
"""Predicts on the input data.
Args:
task: A `QuestionAnsweringTask` object.
params: A `cfg.DataConfig` object.
model: A keras.Model.
Returns:
A tuple of `all_predictions`, `all_nbest` and `scores_diff`, which
are dict and can be written to json files including prediction json file,
nbest json file and null_odds json file.
"""
tf_record_input_path, eval_examples, eval_features = (
task._preprocess_eval_data(params)) # pylint: disable=protected-access
# `tf_record_input_path` will overwrite `params.input_path` when
# `task.build_inputs()` is called.
task.set_preprocessed_eval_input_path(tf_record_input_path)
def predict_step(inputs):
"""Replicated prediction calculation."""
return task.validation_step(inputs, model)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params)
aggregated_outputs = utils.predict(predict_step, task.aggregate_logs, dataset)
all_predictions, all_nbest, scores_diff = (
task.squad_lib.postprocess_output(
eval_examples,
eval_features,
aggregated_outputs,
task.task_config.n_best_size,
task.task_config.max_answer_length,
task.task_config.validation_data.do_lower_case,
version_2_with_negative=(params.version_2_with_negative),
null_score_diff_threshold=task.task_config.null_score_diff_threshold,
verbose=False))
return all_predictions, all_nbest, scores_diff
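For illustration, a minimal sketch of driving this new entry point end to end. The `config` object and the output file name are assumptions for the example, not part of this commit; the config construction mirrors the unit test below.

# Hypothetical usage sketch: run SQuAD prediction and dump the results.
import json

task = QuestionAnsweringTask(config)  # `config` is an assumed QuestionAnsweringConfig
model = task.build_model()
all_predictions, all_nbest, scores_diff = predict(
    task, config.validation_data, model)
with tf.io.gfile.GFile('predictions.json', 'w') as f:  # assumed output path
  json.dump(all_predictions, f, indent=2)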
@@ -81,6 +81,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
val_dataset = task.build_inputs(config.validation_data)
val_iterator = iter(val_dataset)
logs = task.validation_step(next(val_iterator), model, metrics=metrics)
# Mock `logs` as if it came from a single replica.
logs = {x: (logs[x],) for x in logs}
logs = task.aggregate_logs(step_outputs=logs)
metrics = task.reduce_aggregated_logs(logs)
self.assertIn("final_f1", metrics)
@@ -160,6 +162,27 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
validation_data=self._get_validation_data_config())
self._run_task(config)
@parameterized.named_parameters(("squad1", False), ("squad2", True))
def test_predict(self, version_2_with_negative):
validation_data = self._get_validation_data_config(
version_2_with_negative=version_2_with_negative)
config = question_answering.QuestionAnsweringConfig(
model=question_answering.ModelConfig(encoder=self._encoder_config),
train_data=self._train_data_config,
validation_data=validation_data)
task = question_answering.QuestionAnsweringTask(config)
model = task.build_model()
all_predictions, all_nbest, scores_diff = question_answering.predict(
task, validation_data, model)
self.assertLen(all_predictions, 1)
self.assertLen(all_nbest, 1)
if version_2_with_negative:
self.assertLen(scores_diff, 1)
else:
self.assertEmpty(scores_diff)
if __name__ == "__main__":
tf.test.main()
@@ -245,34 +245,25 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig,
"""
is_regression = task.task_config.model.num_classes == 1
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
outputs = task.inference_step(x, model)
if is_regression:
return outputs
else:
return tf.argmax(outputs, axis=-1)
outputs = tf.distribute.get_strategy().run(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
def reduce_fn(state, outputs):
def predict_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
outputs = task.inference_step(x, model)
if is_regression:
return outputs
else:
return tf.argmax(outputs, axis=-1)
def aggregate_fn(state, outputs):
"""Concatenates model's outputs."""
if state is None:
state = {'predictions': []}
for per_replica_batch_predictions in outputs:
state.extend(per_replica_batch_predictions)
state['predictions'].extend(per_replica_batch_predictions)
return state
loop_fn = orbit.utils.create_loop_fn(predict_step)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params)
# Set `num_steps` to -1 to exhaust the dataset.
predictions = loop_fn(
iter(dataset), num_steps=-1, state=[], reduce_fn=reduce_fn)
return predictions
outputs = utils.predict(predict_step, aggregate_fn, dataset)
return outputs['predictions']
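A brief usage sketch of the simplified entry point, assuming `task`, `model`, and a validation data config already exist (illustrative only, not part of the diff):

# Illustrative: per-example class ids (or regression scores) as a flat list.
predictions = predict(task, config.validation_data, model)
print(len(predictions), predictions[:5])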
@@ -232,30 +232,25 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
sentence id of the corresponding example.
"""
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, y = inputs
sentence_ids = x.pop('sentence_id')
outputs = task.inference_step(x, model)
predict_ids = outputs['predict_ids']
label_mask = tf.greater_equal(y, 0)
return dict(
predict_ids=predict_ids,
label_mask=label_mask,
sentence_ids=sentence_ids)
outputs = tf.distribute.get_strategy().run(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
def reduce_fn(state, outputs):
def predict_step(inputs):
"""Replicated prediction calculation."""
x, y = inputs
sentence_ids = x.pop('sentence_id')
outputs = task.inference_step(x, model)
predict_ids = outputs['predict_ids']
label_mask = tf.greater_equal(y, 0)
return dict(
predict_ids=predict_ids,
label_mask=label_mask,
sentence_ids=sentence_ids)
def aggregate_fn(state, outputs):
"""Concatenates model's outputs."""
cur_predict_ids, cur_sentence_ids = state
if state is None:
state = {'predict_ids': [], 'sentence_ids': []}
cur_predict_ids = state['predict_ids']
cur_sentence_ids = state['sentence_ids']
for batch_predict_ids, batch_label_mask, batch_sentence_ids in zip(
outputs['predict_ids'], outputs['label_mask'],
outputs['sentence_ids']):
@@ -269,12 +264,9 @@ def predict(task: TaggingTask, params: cfg.DataConfig,
# Skip the padding label.
if tmp_label_mask[i]:
cur_predict_ids[-1].append(tmp_predict_ids[i])
return cur_predict_ids, cur_sentence_ids
return state
loop_fn = orbit.utils.create_loop_fn(predict_step)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs, params)
# Set `num_steps` to -1 to exhaust the dataset.
predict_ids, sentence_ids = loop_fn(
iter(dataset), num_steps=-1, state=([], []), reduce_fn=reduce_fn)
return predict_ids, sentence_ids
outputs = utils.predict(predict_step, aggregate_fn, dataset)
return outputs['predict_ids'], outputs['sentence_ids']
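A short sketch of consuming the tagging outputs; `label_list` is an assumed id-to-tag mapping and is not defined in this commit:

# Illustrative: group predicted tags by their originating sentence id.
predict_ids, sentence_ids = predict(task, config.validation_data, model)
for ids, sentence_id in zip(predict_ids, sentence_ids):
  tags = [label_list[i] for i in ids]  # `label_list` is hypothetical
  print(sentence_id, tags)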
@@ -14,6 +14,9 @@
# limitations under the License.
# ==============================================================================
"""Common utils for tasks."""
from typing import Any, Callable
import orbit
import tensorflow as tf
import tensorflow_hub as hub
@@ -32,3 +35,34 @@ def get_encoder_from_hub(hub_module: str) -> tf.keras.Model:
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[sequence_output, pooled_output])
def predict(predict_step_fn: Callable[[Any], Any],
aggregate_fn: Callable[[Any, Any], Any],
dataset: tf.data.Dataset):
"""Runs prediction.
Args:
predict_step_fn: A callable such as `def predict_step(inputs)`, where
`inputs` are input tensors.
aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
`value` is the outputs from `predict_step_fn`.
dataset: A `tf.data.Dataset` object.
Returns:
The aggregated predictions.
"""
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
outputs = tf.distribute.get_strategy().run(
predict_step_fn, args=(next(iterator),))
return tf.nest.map_structure(
tf.distribute.get_strategy().experimental_local_results, outputs)
loop_fn = orbit.utils.create_loop_fn(predict_step)
# Set `num_steps` to -1 to exhaust the dataset.
outputs = loop_fn(
iter(dataset), num_steps=-1, state=None, reduce_fn=aggregate_fn) # pytype: disable=wrong-arg-types
return outputs
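To make the helper's contract concrete, a self-contained toy sketch (the dataset and both callables are invented for illustration): `predict_step_fn` runs under the active distribution strategy, so `aggregate_fn` receives the per-replica tuples produced by `experimental_local_results`.

# Toy example: square each batch and flatten per-replica results into a list.
def square_step(batch):
  return batch * batch

def collect_fn(state, value):
  state = [] if state is None else state
  for per_replica in value:  # one entry per replica
    state.extend(per_replica.numpy().tolist())
  return state

toy_dataset = tf.data.Dataset.range(8).batch(4)
results = predict(square_step, collect_fn, toy_dataset)  # [0, 1, 4, 9, ...]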