diff --git a/official/nlp/xlnet/data_utils.py b/official/nlp/xlnet/data_utils.py
index d86c66b06be506bfca7ee90862dba519c4e4bdca..dcb18ce9d3b818dbee5425b1aad6bfee9be5081e 100644
--- a/official/nlp/xlnet/data_utils.py
+++ b/official/nlp/xlnet/data_utils.py
@@ -168,7 +168,7 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training):
   return dataset
 
 
-def _get_input_iterator(input_fn, strategy):
+def get_input_iterator(input_fn, strategy):
   """Returns distributed dataset iterator."""
 
   # When training with TPU pods, datasets needs to be cloned across
diff --git a/official/nlp/xlnet/preprocess_squad_data.py b/official/nlp/xlnet/preprocess_squad_data.py
index 2a4302913219773e1f06761364949425578438bc..59c8944697348f12b185399463978c170b4ee46b 100644
--- a/official/nlp/xlnet/preprocess_squad_data.py
+++ b/official/nlp/xlnet/preprocess_squad_data.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import pickle
 import random
 
 from absl import app
@@ -54,16 +53,11 @@ flags.DEFINE_bool(
 FLAGS = flags.FLAGS
 
 
-def _get_spm_basename():
-  spm_basename = os.path.basename(FLAGS.spiece_model_file)
-  return spm_basename
-
-
 def preprocess():
   """Preprocesses SQUAD data."""
   sp_model = spm.SentencePieceProcessor()
   sp_model.Load(FLAGS.spiece_model_file)
-  spm_basename = _get_spm_basename()
+  spm_basename = os.path.basename(FLAGS.spiece_model_file)
   if FLAGS.create_train_data:
     train_rec_file = os.path.join(
         FLAGS.output_dir,
@@ -97,39 +91,10 @@ def preprocess():
   if FLAGS.create_eval_data:
     eval_examples = squad_utils.read_squad_examples(
         FLAGS.predict_file, is_training=False)
-
-    eval_rec_file = os.path.join(
-        FLAGS.output_dir,
-        "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename,
-                                                   FLAGS.max_seq_length,
-                                                   FLAGS.max_query_length))
-    eval_feature_file = os.path.join(
-        FLAGS.output_dir,
-        "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
-                                                      FLAGS.max_seq_length,
-                                                      FLAGS.max_query_length))
-
-    eval_writer = squad_utils.FeatureWriter(
-        filename=eval_rec_file, is_training=False)
-    eval_features = []
-
-    def append_feature(feature):
-      eval_features.append(feature)
-      eval_writer.process_feature(feature)
-
-    squad_utils.convert_examples_to_features(
-        examples=eval_examples,
-        sp_model=sp_model,
-        max_seq_length=FLAGS.max_seq_length,
-        doc_stride=FLAGS.doc_stride,
-        max_query_length=FLAGS.max_query_length,
-        is_training=False,
-        output_fn=append_feature,
-        uncased=FLAGS.uncased)
-    eval_writer.close()
-
-    with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
-      pickle.dump(eval_features, fout)
+    squad_utils.create_eval_data(spm_basename, sp_model, eval_examples,
+                                 FLAGS.max_seq_length, FLAGS.max_query_length,
+                                 FLAGS.doc_stride, FLAGS.uncased,
+                                 FLAGS.output_dir)
 
 
 def main(_):
diff --git a/official/nlp/xlnet/run_squad.py b/official/nlp/xlnet/run_squad.py
index 4ea60633e22f325e97a65fcc9fa009acab9312f8..c00ee567d08d3e7b67fed4dd74805f81120a0e33 100644
--- a/official/nlp/xlnet/run_squad.py
+++ b/official/nlp/xlnet/run_squad.py
@@ -30,6 +30,7 @@
 from absl import logging
 import tensorflow as tf
 # pylint: disable=unused-import
+import sentencepiece as spm
 from official.nlp import xlnet_config
 from official.nlp import xlnet_modeling as modeling
 from official.nlp.xlnet import common_flags
@@ -51,6 +52,12 @@ flags.DEFINE_string(
 flags.DEFINE_integer(
     "n_best_size", default=5, help="n best size for predictions.")
 flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
+# Data preprocessing config
+flags.DEFINE_string(
+    "spiece_model_file", default=None, help="Sentence Piece model path.")
+flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
+flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
+flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")
 
 FLAGS = flags.FLAGS
 
@@ -92,23 +99,23 @@ class InputFeatures(object):
 
 
 # pylint: disable=unused-argument
-def run_evaluation(strategy,
-                   test_input_fn,
-                   eval_steps,
-                   input_meta_data,
-                   model,
-                   step,
-                   eval_summary_writer=None):
+def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
+                   original_data, eval_steps, input_meta_data, model,
+                   current_step, eval_summary_writer):
   """Run evaluation for SQUAD task.
 
   Args:
     strategy: distribution strategy.
     test_input_fn: input function for evaluation data.
+    eval_examples: tf.Examples of the evaluation set.
+    eval_features: Feature objects of the evaluation set.
+    original_data: The original json data for the evaluation set.
     eval_steps: total number of evaluation steps.
     input_meta_data: input meta data.
     model: keras model object.
-    step: current training step.
+    current_step: current training step.
     eval_summary_writer: summary writer used to record evaluation metrics.
+
+  Returns:
+    A float metric, F1 score.
   """
@@ -127,15 +134,8 @@
         _test_step_fn, args=(next(test_iterator),))
     return res, unique_ids
 
-  # pylint: disable=protected-access
-  test_iterator = data_utils._get_input_iterator(test_input_fn, strategy)
-  # pylint: enable=protected-access
+  test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
   cur_results = []
-  eval_examples = squad_utils.read_squad_examples(
-      input_meta_data["predict_file"], is_training=False)
-  with tf.io.gfile.GFile(input_meta_data["predict_file"]) as f:
-    orig_data = json.load(f)["data"]
-
   for _ in range(eval_steps):
     results, unique_ids = _run_evaluation(test_iterator)
     unique_ids = strategy.experimental_local_results(unique_ids)
@@ -187,21 +187,20 @@
                                            "null_odds.json")
 
   results = squad_utils.write_predictions(
-      eval_examples, input_meta_data["eval_features"], cur_results,
-      input_meta_data["n_best_size"], input_meta_data["max_answer_length"],
-      output_prediction_file, output_nbest_file, output_null_log_odds_file,
-      orig_data, input_meta_data["start_n_top"], input_meta_data["end_n_top"])
+      eval_examples, eval_features, cur_results, input_meta_data["n_best_size"],
+      input_meta_data["max_answer_length"], output_prediction_file,
+      output_nbest_file, output_null_log_odds_file, original_data,
+      input_meta_data["start_n_top"], input_meta_data["end_n_top"])
 
   # Log current results.
   log_str = "Result | "
   for key, val in results.items():
     log_str += "{} {} | ".format(key, val)
   logging.info(log_str)
-  if eval_summary_writer:
-    with eval_summary_writer.as_default():
-      tf.summary.scalar("best_f1", results["best_f1"], step=step)
-      tf.summary.scalar("best_exact", results["best_exact"], step=step)
-      eval_summary_writer.flush()
+  with eval_summary_writer.as_default():
+    tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
+    tf.summary.scalar("best_exact", results["best_exact"], step=current_step)
+    eval_summary_writer.flush()
   return results["best_f1"]
 
 
@@ -254,24 +253,33 @@
   input_meta_data["end_n_top"] = FLAGS.end_n_top
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["predict_dir"] = FLAGS.predict_dir
-  input_meta_data["predict_file"] = FLAGS.predict_file
   input_meta_data["n_best_size"] = FLAGS.n_best_size
   input_meta_data["max_answer_length"] = FLAGS.max_answer_length
-  input_meta_data["test_feature_path"] = FLAGS.test_feature_path
   input_meta_data["test_batch_size"] = FLAGS.test_batch_size
   input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                strategy.num_replicas_in_sync)
   input_meta_data["mem_len"] = FLAGS.mem_len
   model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                FLAGS.start_n_top, FLAGS.end_n_top)
-
-  logging.info("start reading pickle file...")
-  with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
-    eval_features = pickle.load(f)
-
-  logging.info("finishing reading pickle file...")
-  input_meta_data["eval_features"] = eval_features
+  eval_examples = squad_utils.read_squad_examples(
+      FLAGS.predict_file, is_training=False)
+  if FLAGS.test_feature_path:
+    logging.info("start reading pickle file...")
+    with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
+      eval_features = pickle.load(f)
+    logging.info("finishing reading pickle file...")
+  else:
+    sp_model = spm.SentencePieceProcessor()
+    sp_model.Load(FLAGS.spiece_model_file)
+    spm_basename = os.path.basename(FLAGS.spiece_model_file)
+    eval_features = squad_utils.create_eval_data(
+        spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
+        FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)
+
+  with tf.io.gfile.GFile(FLAGS.predict_file) as f:
+    original_data = json.load(f)["data"]
 
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
+                              eval_examples, eval_features, original_data,
                               eval_steps, input_meta_data)
 
   training_utils.train(
diff --git a/official/nlp/xlnet/squad_utils.py b/official/nlp/xlnet/squad_utils.py
index fe4f3e621eb940075edd8d9a87cd0bf04f0f83c5..efab6da6f80658213317e13dee86b09b2cb94c63 100644
--- a/official/nlp/xlnet/squad_utils.py
+++ b/official/nlp/xlnet/squad_utils.py
@@ -23,6 +23,8 @@ import collections
 import gc
 import json
 import math
+import os
+import pickle
 import re
 import string
 
@@ -922,3 +924,50 @@ class FeatureWriter(object):
 
   def close(self):
     self._writer.close()
+
+
+def create_eval_data(spm_basename,
+                     sp_model,
+                     eval_examples,
+                     max_seq_length,
+                     max_query_length,
+                     doc_stride,
+                     uncased,
+                     output_dir=None):
+  """Creates evaluation tfrecords."""
+  eval_features = []
+  eval_writer = None
+  if output_dir:
+    eval_rec_file = os.path.join(
+        output_dir,
+        "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename, max_seq_length,
+                                                   max_query_length))
+    eval_feature_file = os.path.join(
+        output_dir,
+        "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
+                                                      max_seq_length,
+                                                      max_query_length))
+
+    eval_writer = FeatureWriter(filename=eval_rec_file, is_training=False)
+
+  def append_feature(feature):
+    eval_features.append(feature)
+    if eval_writer:
+      eval_writer.process_feature(feature)
+
+  convert_examples_to_features(
+      examples=eval_examples,
+      sp_model=sp_model,
+      max_seq_length=max_seq_length,
+      doc_stride=doc_stride,
+      max_query_length=max_query_length,
+      is_training=False,
+      output_fn=append_feature,
+      uncased=uncased)
+
+  if eval_writer:
+    eval_writer.close()
+    with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
+      pickle.dump(eval_features, fout)
+
+  return eval_features
diff --git a/official/nlp/xlnet/training_utils.py b/official/nlp/xlnet/training_utils.py
index 361b3fb4b54819b3dab129563685c884750ca303..cdb6d2b06bc3fda65dd15d42d7661668fd3b49bb 100644
--- a/official/nlp/xlnet/training_utils.py
+++ b/official/nlp/xlnet/training_utils.py
@@ -110,9 +110,7 @@ def train(
         "`learning_rate_fn` are required parameters.")
   if not model_dir:
     raise TypeError("Model directory must be specified.")
-  # pylint: disable=protected-access
-  train_iterator = data_utils._get_input_iterator(train_input_fn, strategy)
-  # pylint: enable=protected-access
+  train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)
  if not tf.io.gfile.exists(model_dir):
    tf.io.gfile.mkdir(model_dir)
  # Create summary writers
diff --git a/official/requirements.txt b/official/requirements.txt
index 26e5b89d1425c67e766366306b3bb2039568ce7a..76106e774d2f0b83bc075b7a8b89c5416a459eda 100644
--- a/official/requirements.txt
+++ b/official/requirements.txt
@@ -8,8 +8,9 @@ pandas>=0.22.0
 psutil>=5.4.3
 py-cpuinfo>=3.3.0
 scipy>=0.19.1
+tensorflow-hub>=0.6.0
 typing
-tensorflow-hub
+sentencepiece
 Cython
 matplotlib
 opencv-python-headless
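
Note for reviewers (not part of the patch): the sketch below is a minimal illustration of how the two eval-feature paths fit together after this change. The file paths are placeholders; the function, argument, and flag names come from the diff above, and the numeric values mirror the new flag defaults in run_squad.py.

    import sentencepiece as spm
    from official.nlp.xlnet import squad_utils

    # Load the SentencePiece model, as run_squad.py now does when
    # --test_feature_path is unset.
    sp_model = spm.SentencePieceProcessor()
    sp_model.Load("/path/to/spiece.model")  # placeholder path

    eval_examples = squad_utils.read_squad_examples(
        "/path/to/dev-v2.0.json", is_training=False)  # placeholder path

    # Without output_dir, create_eval_data only returns features in memory;
    # this is the new on-the-fly fallback in run_squad.main().
    eval_features = squad_utils.create_eval_data(
        "spiece.model", sp_model, eval_examples,
        max_seq_length=512, max_query_length=64, doc_stride=128,
        uncased=False)

    # With output_dir set, the same helper also writes the .tf_record and
    # .pkl files, which is the path preprocess_squad_data.py takes.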