run_classifier.py 18.4 KB
Newer Older
A
A. Unique TensorFlower 已提交
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
F
Frederick Liu 已提交
14

15
"""BERT classification or regression finetuning runner in TF 2.x."""
16

17
import functools
18 19
import json
import math
A
A. Unique TensorFlower 已提交
20
import os
21

H
Hongkun Yu 已提交
22
# Import libraries
23 24 25
from absl import app
from absl import flags
from absl import logging
L
Le Hou 已提交
26
import gin
27
import tensorflow as tf
28
from official.common import distribute_utils
L
Le Hou 已提交
29 30 31 32 33
from official.legacy.bert import bert_models
from official.legacy.bert import common_flags
from official.legacy.bert import configs as bert_configs
from official.legacy.bert import input_pipeline
from official.legacy.bert import model_saving_utils
34
from official.modeling import performance
35
from official.nlp import optimization
36
from official.utils.misc import keras_utils
37 38

# Command-line flags specific to the classifier runner. Flags shared with the
# other BERT runners are registered by `common_flags.define_common_bert_flags`.
flags.DEFINE_enum(
    name='mode',
    default='train_and_eval',
    enum_values=['train_and_eval', 'export_only', 'predict'],
    help=('One of {"train_and_eval", "export_only", "predict"}. '
          '`train_and_eval`: trains the model and evaluates in the meantime. '
          '`export_only`: will take the latest checkpoint inside model_dir and '
          'export a `SavedModel`. `predict`: takes a checkpoint and restores '
          'the model to output predictions on the test set.'))

# Input locations.
flags.DEFINE_string(
    name='train_data_path',
    default=None,
    help='Path to training data for BERT classifier.')
flags.DEFINE_string(
    name='eval_data_path',
    default=None,
    help='Path to evaluation data for BERT classifier.')
flags.DEFINE_string(
    name='input_meta_data_path',
    default=None,
    help=('Path to file that contains meta data about input to be used for '
          'training and evaluation.'))
flags.DEFINE_integer(
    name='train_data_size',
    default=None,
    help=('Number of training samples to use. If None, uses the full train '
          'data. (default: None).'))
flags.DEFINE_string(
    name='predict_checkpoint_path',
    default=None,
    help='Path to the checkpoint for predictions.')

# Training / evaluation schedule.
flags.DEFINE_integer(
    name='num_eval_per_epoch',
    default=1,
    help=('Number of evaluations per epoch. The purpose of this flag is to '
          'provide more granular evaluation scores and checkpoints. For '
          'example, if original data has N samples and num_eval_per_epoch is '
          'n, then each epoch will be evaluated every N/n samples.'))
flags.DEFINE_integer(
    name='train_batch_size', default=32, help='Batch size for training.')
flags.DEFINE_integer(
    name='eval_batch_size', default=32, help='Batch size for evaluation.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS

# Maps the `label_type` entry of the input meta data to the label dtype.
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}

73

74
def get_loss_fn(num_classes):
  """Builds the softmax cross-entropy loss used for classification tasks.

  Args:
    num_classes: Number of output classes; used as the one-hot depth.

  Returns:
    A callable `(labels, logits) -> scalar tensor` returning the mean
    per-example cross-entropy. Labels are flattened and cast to int32 before
    one-hot encoding. Sample weights are not supported.
  """

  def classification_loss_fn(labels, logits):
    """Mean softmax cross-entropy of `logits` against `labels`."""
    flat_labels = tf.cast(tf.reshape(labels, [-1]), dtype=tf.int32)
    targets = tf.one_hot(flat_labels, depth=num_classes, dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    # Only the target class contributes to each example's negative
    # log-likelihood, since all other one-hot entries are zero.
    per_example_loss = -tf.reduce_sum(targets * log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)

  return classification_loss_fn


T
Tianqi Liu 已提交
90 91 92 93
def get_dataset_fn(input_file_pattern,
                   max_seq_length,
                   global_batch_size,
                   is_training,
                   label_type=tf.int64,
                   include_sample_weights=False,
                   num_samples=None):
  """Returns a closure that builds the classifier input dataset.

  Args:
    input_file_pattern: Glob pattern matching the input files.
    max_seq_length: Maximum sequence length, forwarded to the input pipeline.
    global_batch_size: Batch size summed over all replicas.
    is_training: Whether the dataset is built for training.
    label_type: dtype of the label feature.
    include_sample_weights: Forwarded to the input pipeline. Presumably it
      controls whether sample weights are emitted; confirm in input_pipeline.
    num_samples: Optional value forwarded as `num_samples` to the input
      pipeline.

  Returns:
    A function taking an optional distribution input context and returning
    the dataset built by `input_pipeline.create_classifier_dataset`.
  """

  def _dataset_fn(ctx=None):
    """Builds the dataset, using the per-replica batch size when `ctx` is set."""
    if ctx:
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    else:
      batch_size = global_batch_size
    input_files = tf.io.gfile.glob(input_file_pattern)
    return input_pipeline.create_classifier_dataset(
        input_files,
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx,
        label_type=label_type,
        include_sample_weights=include_sample_weights,
        num_samples=num_samples)

  return _dataset_fn


A
A. Unique TensorFlower 已提交
117 118 119 120 121 122 123 124 125 126 127
def run_bert_classifier(strategy,
                        bert_config,
                        input_meta_data,
                        model_dir,
                        epochs,
                        steps_per_epoch,
                        steps_per_loop,
                        eval_steps,
                        warmup_steps,
                        initial_lr,
                        init_checkpoint,
                        train_input_fn,
                        eval_input_fn,
                        training_callbacks=True,
                        custom_callbacks=None,
                        custom_metrics=None):
  """Trains a BERT classifier or regressor with the Keras compile/fit API.

  The task is treated as regression when `input_meta_data` has no
  `num_labels` entry or it equals 1; otherwise as classification.

  Args:
    strategy: Distribution strategy used for training.
    bert_config: BERT configuration for the core encoder.
    input_meta_data: Dict with at least `max_seq_length`, and optionally
      `num_labels`.
    model_dir: Directory for checkpoints and summaries.
    epochs: Number of training epochs.
    steps_per_epoch: Training steps per epoch.
    steps_per_loop: Steps executed per `compile` execution.
    eval_steps: Validation steps per evaluation.
    warmup_steps: Learning-rate warmup steps.
    initial_lr: Initial learning rate.
    init_checkpoint: Optional checkpoint used to initialize the encoder.
    train_input_fn: Callable returning the training dataset.
    eval_input_fn: Optional callable returning the validation dataset.
    training_callbacks: Whether to add summary and checkpoint callbacks.
    custom_callbacks: Optional extra Keras callbacks.
    custom_metrics: Optional metric factory (or list of them) that overrides
      the default metric.

  Returns:
    The `(model, stats)` tuple produced by `run_keras_compile_fit`.
  """
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data.get('num_labels', 1)
  is_regression = num_classes == 1

  def _get_classifier_model():
    """Builds the classifier and core model and attaches the optimizer."""
    classifier_model, core_model = bert_models.classifier_model(
        bert_config,
        num_classes,
        max_seq_length,
        hub_module_url=FLAGS.hub_module_url,
        hub_module_trainable=FLAGS.hub_module_trainable)
    total_train_steps = steps_per_epoch * epochs
    optimizer = optimization.create_optimizer(
        initial_lr, total_train_steps, warmup_steps, FLAGS.end_lr,
        FLAGS.optimizer_type)
    classifier_model.optimizer = performance.configure_optimizer(
        optimizer, use_float16=common_flags.use_float16())
    return classifier_model, core_model

  # Regression uses the Keras MSE loss object, which accepts an optional
  # `sample_weight` (e.g. coming from the dataset). The custom classification
  # loss from `get_loss_fn` does not take sample weights.
  if is_regression:
    loss_fn = tf.keras.losses.MeanSquaredError()
  else:
    loss_fn = get_loss_fn(num_classes)

  # Metrics are passed as zero-argument factories rather than instances, so
  # they are created later in the correct device and strategy scope.
  if custom_metrics:
    metric_fn = custom_metrics
  elif is_regression:
    metric_fn = functools.partial(
        tf.keras.metrics.MeanSquaredError, 'mean_squared_error',
        dtype=tf.float32)
  else:
    metric_fn = functools.partial(
        tf.keras.metrics.SparseCategoricalAccuracy, 'accuracy',
        dtype=tf.float32)

  logging.info('Training using TF 2.x Keras compile/fit API with '
               'distribution strategy.')
  return run_keras_compile_fit(
      model_dir, strategy, _get_classifier_model, train_input_fn,
      eval_input_fn, loss_fn, metric_fn, init_checkpoint, epochs,
      steps_per_epoch, steps_per_loop, eval_steps,
      training_callbacks=training_callbacks,
      custom_callbacks=custom_callbacks)
196 197


A
A. Unique TensorFlower 已提交
198 199 200 201 202 203 204 205 206 207
def run_keras_compile_fit(model_dir,
                          strategy,
                          model_fn,
                          train_input_fn,
                          eval_input_fn,
                          loss_fn,
                          metric_fn,
                          init_checkpoint,
                          epochs,
                          steps_per_epoch,
                          steps_per_loop,
                          eval_steps,
                          training_callbacks=True,
                          custom_callbacks=None):
  """Runs BERT classifier model using Keras compile/fit API.

  Args:
    model_dir: Directory for checkpoints; summaries go to
      `<model_dir>/summaries`.
    strategy: Distribution strategy whose scope wraps model creation,
      compilation and training.
    model_fn: Callable returning `(model, sub_model)`. `model.optimizer` must
      already be set, because it is read back here.
    train_input_fn: Callable returning the training dataset.
    eval_input_fn: Optional callable returning the validation dataset.
    loss_fn: Loss passed to `model.compile`.
    metric_fn: A zero-argument metric factory or a list/tuple of them.
    init_checkpoint: Optional checkpoint read into `sub_model` before
      training.
    epochs: Number of epochs passed to `fit`.
    steps_per_epoch: Training steps per epoch.
    steps_per_loop: Used as `steps_per_execution` in `compile`.
    eval_steps: Validation steps per evaluation.
    training_callbacks: If True, TensorBoard and checkpoint callbacks are
      added.
    custom_callbacks: Optional sequence of extra Keras callbacks. The caller's
      sequence is never modified.

  Returns:
    A `(model, stats)` tuple. `stats` always has `total_training_steps`. It
    also has the final `train_loss` and `eval_metrics` (last `val_accuracy`)
    when those keys appear in the training history.
  """
  with strategy.scope():
    training_dataset = train_input_fn()
    evaluation_dataset = eval_input_fn() if eval_input_fn else None
    bert_model, sub_model = model_fn()
    optimizer = bert_model.optimizer

    if init_checkpoint:
      # NOTE(review): the sub-model is registered under both `model` and
      # `encoder`, presumably so checkpoints written with either key can be
      # restored. Confirm before changing.
      checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
      checkpoint.read(init_checkpoint).assert_existing_objects_matched()

    if not isinstance(metric_fn, (list, tuple)):
      metric_fn = [metric_fn]
    bert_model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=[fn() for fn in metric_fn],
        steps_per_execution=steps_per_loop)

    summary_dir = os.path.join(model_dir, 'summaries')
    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
    checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_dir,
        max_to_keep=None,
        step_counter=optimizer.iterations,
        checkpoint_interval=0)
    checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)

    # Build a fresh list instead of extending `custom_callbacks` in place.
    # Otherwise the caller's list is mutated, and repeated calls accumulate
    # duplicate summary/checkpoint callbacks.
    callbacks = list(custom_callbacks) if custom_callbacks is not None else None
    if training_callbacks:
      callbacks = (callbacks or []) + [summary_callback, checkpoint_callback]

    history = bert_model.fit(
        x=training_dataset,
        validation_data=evaluation_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_steps=eval_steps,
        callbacks=callbacks)
    stats = {'total_training_steps': steps_per_epoch * epochs}
    if 'loss' in history.history:
      stats['train_loss'] = history.history['loss'][-1]
    if 'val_accuracy' in history.history:
      stats['eval_metrics'] = history.history['val_accuracy'][-1]
    return bert_model, stats
A
A. Unique TensorFlower 已提交
262 263


H
Hongkun Yu 已提交
264 265 266
def get_predictions_and_labels(strategy,
                               trained_model,
                               eval_input_fn,
                               is_regression=False,
                               return_probs=False):
  """Obtains predictions of trained model on evaluation data.

  Note that list of labels is returned along with the predictions because the
  order changes on distributing dataset over TPU pods.

  Args:
    strategy: Distribution strategy.
    trained_model: Trained model with preloaded weights.
    eval_input_fn: Input function for evaluation data.
    is_regression: Whether it is a regression task.
    return_probs: Whether to return probabilities of classes.

  Returns:
    predictions: List of predictions, one entry per example:
      - with `return_probs=True`: the softmax probabilities as a list
        (the raw model outputs for regression);
      - otherwise, for classification: the argmax class id as an int;
      - otherwise, for regression: the flattened model output value(s).
    labels: List of gold labels corresponding to predictions.
  """

  @tf.function
  def test_step(iterator):
    """Computes predictions on distributed devices."""

    def _test_step_fn(inputs):
      """Replicated predictions."""
      inputs, labels = inputs
      logits = trained_model(inputs, training=False)
      if is_regression:
        return logits, labels
      return tf.nn.softmax(logits), labels

    outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
    # Unpack per-replica values into tuples of local (per-shard) tensors.
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    outputs)
    labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
    return outputs, labels

  def _to_predictions(shard_outputs):
    """Converts one shard of model outputs to a list of Python predictions."""
    if return_probs:
      return shard_outputs.numpy().tolist()
    if is_regression:
      # The regression head has a single output column, so argmax over it
      # would be 0 for every example. Return the predicted values instead.
      return tf.reshape(shard_outputs, [-1]).numpy().tolist()
    return tf.math.argmax(shard_outputs, axis=1).numpy().tolist()

  def _run_evaluation(test_iterator):
    """Runs evaluation steps until the input is exhausted."""
    preds, golds = [], []
    try:
      with tf.experimental.async_scope():
        while True:
          outputs, labels = test_step(test_iterator)
          for shard_outputs, shard_labels in zip(outputs, labels):
            preds.extend(_to_predictions(shard_outputs))
            golds.extend(shard_labels.numpy().tolist())
    except (StopIteration, tf.errors.OutOfRangeError):
      # End of the evaluation data; clear the pending async error state.
      tf.experimental.async_clear_error()
    return preds, golds

  test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
  predictions, labels = _run_evaluation(test_iter)
  return predictions, labels


H
Hongkun Yu 已提交
330 331
def export_classifier(model_export_path, input_meta_data, bert_config,
                      model_dir):
  """Exports a trained model as a `SavedModel` for inference.

  Side effect: the global Keras mixed-precision policy is set to float32
  before the model is rebuilt, so export uses float32 even if training used
  mixed precision.

  Args:
    model_export_path: a string specifying the path to the SavedModel directory.
    input_meta_data: dictionary containing meta data about input and model.
    bert_config: Bert configuration file to define core bert layers.
    model_dir: The directory where the model weights and training/evaluation
      summaries are stored.

  Raises:
    ValueError: If `model_export_path` or `model_dir` is empty or None.
  """
  if not model_export_path:
    raise ValueError('Export path is not specified: %s' % model_export_path)
  if not model_dir:
    raise ValueError('Model directory is not specified: %s' % model_dir)

  # Export uses float32 for now, even if training uses mixed precision. This
  # changes the process-wide Keras policy, not just this model.
  tf.keras.mixed_precision.set_global_policy('float32')
  classifier_model = bert_models.classifier_model(
      bert_config,
      input_meta_data.get('num_labels', 1),
      hub_module_url=FLAGS.hub_module_url,
      hub_module_trainable=False)[0]

  # NOTE(review): `checkpoint_dir` presumably makes export_bert_model restore
  # trained weights from `model_dir`. Confirm in model_saving_utils.
  model_saving_utils.export_bert_model(
      model_export_path, model=classifier_model, checkpoint_dir=model_dir)
359 360


H
Hongkun Yu 已提交
361 362
def run_bert(strategy,
             input_meta_data,
             model_config,
             train_input_fn=None,
             eval_input_fn=None,
             init_checkpoint=None,
             custom_callbacks=None,
             custom_metrics=None):
  """Run BERT training.

  Derives the epoch and step schedule from `input_meta_data` and flags, then
  trains with `run_bert_classifier`. If `--model_export_path` is set, the
  trained model is also exported.

  Args:
    strategy: Distribution strategy; must not be None.
    input_meta_data: Dict with `train_data_size`, `eval_data_size` and the
      keys used by `run_bert_classifier`.
    model_config: BERT configuration for the encoder.
    train_input_fn: Callable returning the training dataset.
    eval_input_fn: Optional callable returning the validation dataset.
    init_checkpoint: Optional checkpoint that overrides `--init_checkpoint`.
    custom_callbacks: Optional list of extra Keras callbacks. It is not
      modified.
    custom_metrics: Optional metric factory (or list) for evaluation.

  Returns:
    The trained Keras model.

  Raises:
    ValueError: If `strategy` is not specified.
  """
  # Fail fast, before any process-wide session/precision settings change.
  if not strategy:
    raise ValueError('Distribution strategy has not been specified.')

  # Enables XLA in Session Config. Should not be set for TPU.
  keras_utils.set_session_config(FLAGS.enable_xla)
  performance.set_mixed_precision_policy(common_flags.dtype())

  # Each Keras "epoch" covers 1/num_eval_per_epoch of the data, so evaluation
  # and checkpointing happen num_eval_per_epoch times per pass over the data.
  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
  train_data_size = (
      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
  if FLAGS.train_data_size:
    train_data_size = min(train_data_size, FLAGS.train_data_size)
    logging.info('Updated train_data_size: %s', train_data_size)
  steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
  # Roughly 10% of all training steps are used for learning-rate warmup.
  warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
  eval_steps = int(
      math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))

  # Work on a copy so that appending TimeHistory never mutates the caller's
  # list. Repeated calls would otherwise accumulate callbacks.
  custom_callbacks = list(custom_callbacks) if custom_callbacks else []

  if FLAGS.log_steps:
    custom_callbacks.append(
        keras_utils.TimeHistory(
            batch_size=FLAGS.train_batch_size,
            log_steps=FLAGS.log_steps,
            logdir=FLAGS.model_dir))

  trained_model, _ = run_bert_classifier(
      strategy,
      model_config,
      input_meta_data,
      FLAGS.model_dir,
      epochs,
      steps_per_epoch,
      FLAGS.steps_per_loop,
      eval_steps,
      warmup_steps,
      FLAGS.learning_rate,
      init_checkpoint or FLAGS.init_checkpoint,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)

  if FLAGS.model_export_path:
    model_saving_utils.export_bert_model(
        FLAGS.model_export_path, model=trained_model)
  return trained_model

420

421
def custom_main(custom_callbacks=None, custom_metrics=None):
  """Run classification or regression.

  Parses gin configs and the input meta data, then dispatches on `FLAGS.mode`:
    - `export_only`: exports a SavedModel via `export_classifier`.
    - `predict`: restores a checkpoint and writes one tab-separated line of
      predictions per eval example to `<model_dir>/test_results.tsv`.
    - `train_and_eval`: trains (and evaluates) via `run_bert`.

  Args:
    custom_callbacks: list of tf.keras.Callbacks passed to training loop.
    custom_metrics: list of metrics passed to the training loop.

  Raises:
    ValueError: If `FLAGS.mode` is unsupported, or if predict mode finds no
      checkpoint to restore.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
  include_sample_weights = input_meta_data.get('has_sample_weights', False)

  if not FLAGS.model_dir:
    FLAGS.model_dir = '/tmp/bert20/'

  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.mode == 'export_only':
    export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
                      FLAGS.model_dir)
    return

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      tpu_address=FLAGS.tpu)
  eval_input_fn = get_dataset_fn(
      FLAGS.eval_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.eval_batch_size,
      is_training=False,
      label_type=label_type,
      include_sample_weights=include_sample_weights)

  if FLAGS.mode == 'predict':
    num_labels = input_meta_data.get('num_labels', 1)
    with strategy.scope():
      # Build with the same hub module as training and `export_classifier`,
      # so checkpoints of hub-based models restore into a matching
      # architecture.
      classifier_model = bert_models.classifier_model(
          bert_config,
          num_labels,
          hub_module_url=FLAGS.hub_module_url,
          hub_module_trainable=False)[0]
      checkpoint = tf.train.Checkpoint(model=classifier_model)
      latest_checkpoint_file = (
          FLAGS.predict_checkpoint_path or
          tf.train.latest_checkpoint(FLAGS.model_dir))
      # Explicit check rather than `assert`, so it still fires under -O.
      if not latest_checkpoint_file:
        raise ValueError(
            'No checkpoint found for prediction: set --predict_checkpoint_path '
            'or make sure %s contains a checkpoint.' % FLAGS.model_dir)
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()
      preds, _ = get_predictions_and_labels(
          strategy,
          classifier_model,
          eval_input_fn,
          is_regression=(num_labels == 1),
          return_probs=True)
    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
      logging.info('***** Predict results *****')
      for probabilities in preds:
        output_line = '\t'.join(
            str(class_probability)
            for class_probability in probabilities) + '\n'
        writer.write(output_line)
    return

  if FLAGS.mode != 'train_and_eval':
    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
  train_input_fn = get_dataset_fn(
      FLAGS.train_data_path,
      input_meta_data['max_seq_length'],
      FLAGS.train_batch_size,
      is_training=True,
      label_type=label_type,
      include_sample_weights=include_sample_weights,
      num_samples=FLAGS.train_data_size)
  run_bert(
      strategy,
      input_meta_data,
      bert_config,
      train_input_fn,
      eval_input_fn,
      custom_callbacks=custom_callbacks,
      custom_metrics=custom_metrics)
505 506 507


def main(_):
  """absl entry point: runs `custom_main` without custom callbacks/metrics."""
  del _  # Unused: remaining argv passed in by `app.run`.
  custom_main(custom_callbacks=None, custom_metrics=None)
509 510 511 512 513


if __name__ == '__main__':
  # These flags have no usable defaults for a command-line run.
  for required_flag in ('bert_config_file', 'input_meta_data_path',
                        'model_dir'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)