提交 02d78796 编写于 作者: C Chen Chen 提交者: A. Unique TensorFlower

Branch a run_squad.py for albert.

PiperOrigin-RevId: 296933982
上级 676b1914
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from absl import app
from absl import flags
import tensorflow as tf
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
flags.DEFINE_string(
'sp_model_file', None,
'The path to the sentence piece model. Used by sentence piece tokenizer '
'employed by ALBERT.')
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
"""Runs bert squad training."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
tokenizer = tokenization.FullSentencePieceTokenizer(
sp_model_file=FLAGS.sp_model_file)
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
bert_config, squad_lib_sp)
def export_squad(model_export_path, input_meta_data):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
Raises:
Export path is not specified, got an empty string or None.
"""
bert_config = albert_configs.AlbertConfig.from_json_file(
FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
if FLAGS.mode == 'export_only':
export_squad(FLAGS.model_export_path, input_meta_data)
return
# Configures cluster spec for multi-worker distribution strategy.
if FLAGS.num_gpus > 0:
_ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
FLAGS.task_index)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.distribution_strategy,
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):
train_squad(strategy, input_meta_data, run_eagerly=FLAGS.run_eagerly)
if FLAGS.mode in ('predict', 'train_and_predict'):
predict_squad(strategy, input_meta_data)
if __name__ == '__main__':
flags.mark_flag_as_required('bert_config_file')
flags.mark_flag_as_required('model_dir')
app.run(main)
......@@ -66,10 +66,6 @@ def define_common_bert_flags():
flags.DEFINE_string(
'hub_module_url', None, 'TF-Hub path/url to Bert module. '
'If specified, init_checkpoint flag should not be used.')
flags.DEFINE_enum(
'model_type', 'bert', ['bert', 'albert'],
'Specifies the type of the model. '
'If "bert", will use canonical BERT; if "albert", will use ALBERT model.')
flags.DEFINE_bool('hub_module_trainable', True,
'True to make keras layers in the hub module trainable.')
......
......@@ -19,361 +19,44 @@ from __future__ import division
from __future__ import print_function
import json
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.nlp import optimization
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import configs as bert_configs
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import run_squad_helper
from official.nlp.bert import tokenization
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
flags.DEFINE_enum(
'mode', 'train_and_predict',
['train_and_predict', 'train', 'predict', 'export_only'],
'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'`train`: only trains the model. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.')
flags.DEFINE_string('vocab_file', None,
'The vocabulary file that the BERT model was trained on.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be printed. '
'A number of warnings are expected for a normal SQuAD evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one another.')
flags.DEFINE_string(
'sp_model_file', None,
'The path to the sentence piece model. Used by sentence piece tokenizer '
'employed by ALBERT.')
common_flags.define_common_bert_flags()
# More flags can be found in run_squad_helper.
run_squad_helper.define_common_squad_flags()
FLAGS = flags.FLAGS
MODEL_CLASSES = {
'bert': (bert_configs.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
'albert': (albert_configs.AlbertConfig, squad_lib_sp,
tokenization.FullSentencePieceTokenizer),
}
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_factor
return total_loss
def get_loss_fn(loss_factor=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=loss_factor)
return _loss_fn
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
squad_lib = MODEL_CLASSES[FLAGS.model_type][1]
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield squad_lib.RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
' strategy.')
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
use_float16 = common_flags.use_float16()
if use_float16:
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
FLAGS.bert_config_file)
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
squad_model.optimizer)
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(
loss_factor=1.0 /
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
config_cls, squad_lib, tokenizer_cls = MODEL_CLASSES[FLAGS.model_type]
bert_config = config_cls.from_json_file(FLAGS.bert_config_file)
if tokenizer_cls == tokenization.FullTokenizer:
tokenizer = tokenizer_cls(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
else:
assert tokenizer_cls == tokenization.FullSentencePieceTokenizer
tokenizer = tokenizer_cls(sp_model_file=FLAGS.sp_model_file)
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
bert_config, squad_lib_wp)
def export_squad(model_export_path, input_meta_data):
......@@ -386,16 +69,8 @@ def export_squad(model_export_path, input_meta_data):
Raises:
Export path is not specified, got an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
bert_config = MODEL_CLASSES[FLAGS.model_type][0].from_json_file(
FLAGS.bert_config_file)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
def main(_):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from absl import flags
from absl import logging
import tensorflow as tf
from official.modeling import model_training_utils
from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.data import squad_lib_sp
from official.utils.misc import keras_utils
def define_common_squad_flags():
"""Defines common flags used by SQuAD tasks."""
flags.DEFINE_enum(
'mode', 'train_and_predict',
['train_and_predict', 'train', 'predict', 'export_only'],
'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'`train`: only trains the model. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
flags.DEFINE_string('train_data_path', '',
'Training data path with train tfrecords.')
flags.DEFINE_string(
'input_meta_data_path', None,
'Path to file that contains meta data about input '
'to be used for training and evaluation.')
# Model training specific flags.
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
'models and False for cased models.')
flags.DEFINE_float(
'null_score_diff_threshold', 0.0,
'If null_score - best_non_null is greater than the threshold, '
'predict null. This is only used for SQuAD v2.')
flags.DEFINE_bool(
'verbose_logging', False,
'If true, all of the warnings related to data processing will be '
'printed. A number of warnings are expected for a normal SQuAD '
'evaluation.')
flags.DEFINE_integer('predict_batch_size', 8,
'Total batch size for prediction.')
flags.DEFINE_integer(
'n_best_size', 20,
'The total number of n-best predictions to generate in the '
'nbest_predictions.json output file.')
flags.DEFINE_integer(
'max_answer_length', 30,
'The maximum length of an answer that can be generated. This is needed '
'because the start and end predictions are not conditioned on one '
'another.')
common_flags.define_common_bert_flags()
FLAGS = flags.FLAGS
def squad_loss_fn(start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=1.0):
"""Returns sparse categorical crossentropy for start/end logits."""
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions, end_logits, from_logits=True)
total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
total_loss *= loss_factor
return total_loss
def get_loss_fn(loss_factor=1.0):
"""Gets a loss function for squad task."""
def _loss_fn(labels, model_outputs):
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
return squad_loss_fn(
start_positions,
end_positions,
start_logits,
end_logits,
loss_factor=loss_factor)
return _loss_fn
RawResult = collections.namedtuple('RawResult',
['unique_id', 'start_logits', 'end_logits'])
def get_raw_results(predictions):
"""Converts multi-replica predictions to RawResult."""
for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
predictions['start_logits'],
predictions['end_logits']):
for values in zip(unique_ids.numpy(), start_logits.numpy(),
end_logits.numpy()):
yield RawResult(
unique_id=values[0],
start_logits=values[1].tolist(),
end_logits=values[2].tolist())
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training):
"""Gets a closure to create a dataset.."""
def _dataset_fn(ctx=None):
"""Returns tf.data.Dataset for distributed BERT pretraining."""
batch_size = ctx.get_per_replica_batch_size(
global_batch_size) if ctx else global_batch_size
dataset = input_pipeline.create_squad_dataset(
input_file_pattern,
max_seq_length,
batch_size,
is_training=is_training,
input_pipeline_context=ctx)
return dataset
return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(
bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
@tf.function
def predict_step(iterator):
"""Predicts on distributed devices."""
def _replicated_step(inputs):
"""Replicated prediction calculation."""
x, _ = inputs
unique_ids = x.pop('unique_ids')
start_logits, end_logits = squad_model(x, training=False)
return dict(
unique_ids=unique_ids,
start_logits=start_logits,
end_logits=end_logits)
outputs = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
return tf.nest.map_structure(strategy.experimental_local_results, outputs)
all_results = []
for _ in range(num_steps):
predictions = predict_step(predict_iterator)
for result in get_raw_results(predictions):
all_results.append(result)
if len(all_results) % 100 == 0:
logging.info('Made predictions for %d records.', len(all_results))
return all_results
def train_squad(strategy,
input_meta_data,
bert_config,
custom_callbacks=None,
run_eagerly=False):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
' strategy.')
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils.set_config_v2(FLAGS.enable_xla)
use_float16 = common_flags.use_float16()
if use_float16:
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
epochs = FLAGS.num_train_epochs
num_train_examples = input_meta_data['train_data_size']
max_seq_length = input_meta_data['max_seq_length']
steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
train_input_fn = get_dataset_fn(
FLAGS.train_data_path,
max_seq_length,
FLAGS.train_batch_size,
is_training=True)
def _get_squad_model():
"""Get Squad model and optimizer."""
squad_model, core_model = bert_models.squad_model(
bert_config,
max_seq_length,
hub_module_url=FLAGS.hub_module_url,
hub_module_trainable=FLAGS.hub_module_trainable)
squad_model.optimizer = optimization.create_optimizer(
FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
if use_float16:
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
# in compile() with the "mixed_float16" policy, but since we do not call
# compile(), we must wrap the optimizer manually.
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
squad_model.optimizer)
return squad_model, core_model
# The original BERT model does not scale the loss by
# 1/num_replicas_in_sync. It could be an accident. So, in order to use
# the same hyper parameter, we do the same thing here by keeping each
# replica loss as it is.
loss_fn = get_loss_fn(
loss_factor=1.0 /
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
model_training_utils.run_customized_training_loop(
strategy=strategy,
model_fn=_get_squad_model,
loss_fn=loss_fn,
model_dir=FLAGS.model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks)
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
# Whether data should be in Ver 2.0 format.
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
eval_writer = squad_lib.FeatureWriter(
filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
is_training=False)
eval_features = []
def _append_feature(feature, is_padding):
if not is_padding:
eval_features.append(feature)
eval_writer.process_feature(feature)
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
kwargs = dict(
examples=eval_examples,
tokenizer=tokenizer,
max_seq_length=input_meta_data['max_seq_length'],
doc_stride=doc_stride,
max_query_length=max_query_length,
is_training=False,
output_fn=_append_feature,
batch_size=FLAGS.predict_batch_size)
# squad_lib_sp requires one more argument 'do_lower_case'.
if squad_lib == squad_lib_sp:
kwargs['do_lower_case'] = FLAGS.do_lower_case
dataset_size = squad_lib.convert_examples_to_features(**kwargs)
eval_writer.close()
logging.info('***** Running predictions *****')
logging.info(' Num orig examples = %d', len(eval_examples))
logging.info(' Num split examples = %d', len(eval_features))
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging)
def export_squad(model_export_path, input_meta_data, bert_config):
"""Exports a trained model as a `SavedModel` for inference.
Args:
model_export_path: a string specifying the path to the SavedModel directory.
input_meta_data: dictionary containing meta data about input and model.
bert_config: Bert configuration file to define core bert layers.
Raises:
Export path is not specified, got an empty string or None.
"""
if not model_export_path:
raise ValueError('Export path is not specified: %s' % model_export_path)
# Export uses float32 for now, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
squad_model, _ = bert_models.squad_model(bert_config,
input_meta_data['max_seq_length'])
model_saving_utils.export_bert_model(
model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
......@@ -490,10 +490,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
def write_predictions(all_examples,
all_features,
all_results,
......
......@@ -562,10 +562,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
def write_predictions(all_examples,
all_features,
all_results,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册