Commit 24c619ff authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 277793274
Parent fc2056bc
@@ -168,7 +168,7 @@ def create_squad_dataset(file_path, seq_length, batch_size, is_training):
   return dataset
 
 
-def _get_input_iterator(input_fn, strategy):
+def get_input_iterator(input_fn, strategy):
   """Returns distributed dataset iterator."""
   # When training with TPU pods, datasets needs to be cloned across
......
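The rename from `_get_input_iterator` to `get_input_iterator` makes the helper public, which is what lets `run_squad.py` and `training_utils.py` further down drop their `pylint: disable=protected-access` wrappers. As a rough sketch of what a helper like this does under a TF 2.x distribution strategy (illustrative only; the real body is truncated above):

```python
import tensorflow as tf


def get_input_iterator(input_fn, strategy):
  """Returns a distributed dataset iterator (illustrative sketch)."""
  # input_fn builds a tf.data.Dataset; distributing it through the
  # strategy shards/clones it across replicas (e.g. TPU pod workers).
  dataset = input_fn()
  return iter(strategy.experimental_distribute_dataset(dataset))
```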
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import pickle
 import random
 
 from absl import app
@@ -54,16 +53,11 @@ flags.DEFINE_bool(
 FLAGS = flags.FLAGS
 
-def _get_spm_basename():
-  spm_basename = os.path.basename(FLAGS.spiece_model_file)
-  return spm_basename
-
 def preprocess():
   """Preprocesses SQUAD data."""
   sp_model = spm.SentencePieceProcessor()
   sp_model.Load(FLAGS.spiece_model_file)
-  spm_basename = _get_spm_basename()
+  spm_basename = os.path.basename(FLAGS.spiece_model_file)
   if FLAGS.create_train_data:
     train_rec_file = os.path.join(
         FLAGS.output_dir,
@@ -97,39 +91,10 @@ def preprocess():
   if FLAGS.create_eval_data:
     eval_examples = squad_utils.read_squad_examples(
         FLAGS.predict_file, is_training=False)
-    eval_rec_file = os.path.join(
-        FLAGS.output_dir,
-        "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename,
-                                                   FLAGS.max_seq_length,
-                                                   FLAGS.max_query_length))
-    eval_feature_file = os.path.join(
-        FLAGS.output_dir,
-        "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
-                                                      FLAGS.max_seq_length,
-                                                      FLAGS.max_query_length))
-    eval_writer = squad_utils.FeatureWriter(
-        filename=eval_rec_file, is_training=False)
-    eval_features = []
-
-    def append_feature(feature):
-      eval_features.append(feature)
-      eval_writer.process_feature(feature)
-
-    squad_utils.convert_examples_to_features(
-        examples=eval_examples,
-        sp_model=sp_model,
-        max_seq_length=FLAGS.max_seq_length,
-        doc_stride=FLAGS.doc_stride,
-        max_query_length=FLAGS.max_query_length,
-        is_training=False,
-        output_fn=append_feature,
-        uncased=FLAGS.uncased)
-    eval_writer.close()
-
-    with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
-      pickle.dump(eval_features, fout)
+    squad_utils.create_eval_data(spm_basename, sp_model, eval_examples,
+                                 FLAGS.max_seq_length, FLAGS.max_query_length,
+                                 FLAGS.doc_stride, FLAGS.uncased,
+                                 FLAGS.output_dir)
 
 
 def main(_):
......
@@ -30,6 +30,7 @@ from absl import logging
 import tensorflow as tf
 # pylint: disable=unused-import
+import sentencepiece as spm
 from official.nlp import xlnet_config
 from official.nlp import xlnet_modeling as modeling
 from official.nlp.xlnet import common_flags
@@ -51,6 +52,12 @@ flags.DEFINE_string(
 flags.DEFINE_integer(
     "n_best_size", default=5, help="n best size for predictions.")
 flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
+# Data preprocessing config
+flags.DEFINE_string(
+    "spiece_model_file", default=None, help="Sentence Piece model path.")
+flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
+flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
+flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")
 
 FLAGS = flags.FLAGS
@@ -92,23 +99,23 @@ class InputFeatures(object):
 # pylint: disable=unused-argument
-def run_evaluation(strategy,
-                   test_input_fn,
-                   eval_steps,
-                   input_meta_data,
-                   model,
-                   step,
-                   eval_summary_writer=None):
+def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
+                   original_data, eval_steps, input_meta_data, model,
+                   current_step, eval_summary_writer):
   """Run evaluation for SQUAD task.
 
   Args:
     strategy: distribution strategy.
     test_input_fn: input function for evaluation data.
+    eval_examples: tf.Examples of the evaluation set.
+    eval_features: Feature objects of the evaluation set.
+    original_data: The original json data for the evaluation set.
     eval_steps: total number of evaluation steps.
     input_meta_data: input meta data.
     model: keras model object.
-    step: current training step.
+    current_step: current training step.
     eval_summary_writer: summary writer used to record evaluation metrics.
 
   Returns:
     A float metric, F1 score.
   """
@@ -127,15 +134,8 @@
         _test_step_fn, args=(next(test_iterator),))
     return res, unique_ids
 
-  # pylint: disable=protected-access
-  test_iterator = data_utils._get_input_iterator(test_input_fn, strategy)
-  # pylint: enable=protected-access
+  test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
   cur_results = []
-  eval_examples = squad_utils.read_squad_examples(
-      input_meta_data["predict_file"], is_training=False)
-  with tf.io.gfile.GFile(input_meta_data["predict_file"]) as f:
-    orig_data = json.load(f)["data"]
-
   for _ in range(eval_steps):
     results, unique_ids = _run_evaluation(test_iterator)
     unique_ids = strategy.experimental_local_results(unique_ids)
@@ -187,21 +187,20 @@
       "null_odds.json")
 
   results = squad_utils.write_predictions(
-      eval_examples, input_meta_data["eval_features"], cur_results,
-      input_meta_data["n_best_size"], input_meta_data["max_answer_length"],
-      output_prediction_file, output_nbest_file, output_null_log_odds_file,
-      orig_data, input_meta_data["start_n_top"], input_meta_data["end_n_top"])
+      eval_examples, eval_features, cur_results, input_meta_data["n_best_size"],
+      input_meta_data["max_answer_length"], output_prediction_file,
+      output_nbest_file, output_null_log_odds_file, original_data,
+      input_meta_data["start_n_top"], input_meta_data["end_n_top"])
 
   # Log current results.
   log_str = "Result | "
   for key, val in results.items():
     log_str += "{} {} | ".format(key, val)
   logging.info(log_str)
-  if eval_summary_writer:
-    with eval_summary_writer.as_default():
-      tf.summary.scalar("best_f1", results["best_f1"], step=step)
-      tf.summary.scalar("best_exact", results["best_exact"], step=step)
-      eval_summary_writer.flush()
+  with eval_summary_writer.as_default():
+    tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
+    tf.summary.scalar("best_exact", results["best_exact"], step=current_step)
+    eval_summary_writer.flush()
 
   return results["best_f1"]
@@ -254,24 +253,33 @@ def main(unused_argv):
   input_meta_data["end_n_top"] = FLAGS.end_n_top
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["predict_dir"] = FLAGS.predict_dir
-  input_meta_data["predict_file"] = FLAGS.predict_file
   input_meta_data["n_best_size"] = FLAGS.n_best_size
   input_meta_data["max_answer_length"] = FLAGS.max_answer_length
-  input_meta_data["test_feature_path"] = FLAGS.test_feature_path
   input_meta_data["test_batch_size"] = FLAGS.test_batch_size
   input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                                strategy.num_replicas_in_sync)
   input_meta_data["mem_len"] = FLAGS.mem_len
   model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                FLAGS.start_n_top, FLAGS.end_n_top)
-  logging.info("start reading pickle file...")
-  with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
-    eval_features = pickle.load(f)
-  logging.info("finishing reading pickle file...")
-  input_meta_data["eval_features"] = eval_features
+  eval_examples = squad_utils.read_squad_examples(
+      FLAGS.predict_file, is_training=False)
+  if FLAGS.test_feature_path:
+    logging.info("start reading pickle file...")
+    with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
+      eval_features = pickle.load(f)
+    logging.info("finishing reading pickle file...")
+  else:
+    sp_model = spm.SentencePieceProcessor()
+    sp_model.Load(FLAGS.spiece_model_file)
+    spm_basename = os.path.basename(FLAGS.spiece_model_file)
+    eval_features = squad_utils.create_eval_data(
+        spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
+        FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)
+  with tf.io.gfile.GFile(FLAGS.predict_file) as f:
+    original_data = json.load(f)["data"]
+  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
+                              eval_examples, eval_features, original_data,
+                              eval_steps, input_meta_data)
   training_utils.train(
......
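The `functools.partial` above binds the seven data-dependent arguments of the new `run_evaluation` signature up front, so the training loop only has to supply the three that change during training. A hypothetical call site (the real one lives inside `training_utils.train`):

```python
# Illustrative only: eval_fn was built with functools.partial above, so
# exactly (model, current_step, eval_summary_writer) remain unbound.
best_f1 = eval_fn(model, current_step, eval_summary_writer)
```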
@@ -23,6 +23,8 @@ import collections
 import gc
 import json
 import math
+import os
+import pickle
 import re
 import string
@@ -922,3 +924,50 @@ class FeatureWriter(object):
   def close(self):
     self._writer.close()
+
+
+def create_eval_data(spm_basename,
+                     sp_model,
+                     eval_examples,
+                     max_seq_length,
+                     max_query_length,
+                     doc_stride,
+                     uncased,
+                     output_dir=None):
+  """Creates evaluation tfrecords."""
+  eval_features = []
+  eval_writer = None
+  if output_dir:
+    eval_rec_file = os.path.join(
+        output_dir,
+        "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename, max_seq_length,
+                                                   max_query_length))
+    eval_feature_file = os.path.join(
+        output_dir,
+        "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
+                                                      max_seq_length,
+                                                      max_query_length))
+    eval_writer = FeatureWriter(filename=eval_rec_file, is_training=False)
+
+  def append_feature(feature):
+    eval_features.append(feature)
+    if eval_writer:
+      eval_writer.process_feature(feature)
+
+  convert_examples_to_features(
+      examples=eval_examples,
+      sp_model=sp_model,
+      max_seq_length=max_seq_length,
+      doc_stride=doc_stride,
+      max_query_length=max_query_length,
+      is_training=False,
+      output_fn=append_feature,
+      uncased=uncased)
+  if eval_writer:
+    eval_writer.close()
+    with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
+      pickle.dump(eval_features, fout)
+
+  return eval_features
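A hedged usage sketch of the new `squad_utils.create_eval_data` helper; the import path and file paths are assumptions, and `output_dir=None` exercises the in-memory branch that `run_squad.py` relies on above:

```python
import os

import sentencepiece as spm
from official.nlp.xlnet import squad_utils  # assumed import path

SPIECE_MODEL = "/path/to/spiece.model"  # placeholder path

sp_model = spm.SentencePieceProcessor()
sp_model.Load(SPIECE_MODEL)

eval_examples = squad_utils.read_squad_examples(
    "/path/to/dev-v2.0.json", is_training=False)  # placeholder path

# With output_dir=None no tfrecord or pickle file is written; the features
# are only returned, matching the on-the-fly path in run_squad.py.
eval_features = squad_utils.create_eval_data(
    os.path.basename(SPIECE_MODEL), sp_model, eval_examples,
    max_seq_length=512, max_query_length=64, doc_stride=128,
    uncased=False, output_dir=None)
```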
@@ -110,9 +110,7 @@ def train(
         "`learning_rate_fn` are required parameters.")
   if not model_dir:
     raise TypeError("Model directory must be specified.")
-  # pylint: disable=protected-access
-  train_iterator = data_utils._get_input_iterator(train_input_fn, strategy)
-  # pylint: enable=protected-access
+  train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)
   if not tf.io.gfile.exists(model_dir):
     tf.io.gfile.mkdir(model_dir)
   # Create summary writers
......
@@ -8,8 +8,9 @@ pandas>=0.22.0
 psutil>=5.4.3
 py-cpuinfo>=3.3.0
 scipy>=0.19.1
-tensorflow-hub>=0.6.0
 typing
+tensorflow-hub
+sentencepiece
 Cython
 matplotlib
 opencv-python-headless
......