diff --git a/official/modeling/model_training_utils.py b/official/modeling/model_training_utils.py
index aa732181addc6d9217b9cc78e08b4f0e5dab9f88..69c90c0f485aac57ccead27bbd5f2600e9d189c8 100644
--- a/official/modeling/model_training_utils.py
+++ b/official/modeling/model_training_utils.py
@@ -94,7 +94,8 @@ def run_customized_training_loop(
     metric_fn=None,
     init_checkpoint=None,
     custom_callbacks=None,
-    run_eagerly=False):
+    run_eagerly=False,
+    sub_model_export_name=None):
   """Run BERT pretrain model training using low-level API.

   Arguments:
@@ -131,6 +132,11 @@ def run_customized_training_loop(
         methods are invoked during training.
       run_eagerly: Whether to run model training in pure eager execution. This
         should be disable for TPUStrategy.
+      sub_model_export_name: If not None, will export `sub_model` returned by
+        `model_fn` into checkpoint files. The name of an intermediate checkpoint
+        file is {sub_model_export_name}_step_{step}.ckpt and the last
+        checkpoint's name is {sub_model_export_name}.ckpt;
+        if None, `sub_model` will not be exported as a checkpoint.

   Returns:
       Trained model.
@@ -139,6 +145,8 @@ def run_customized_training_loop(
       ValueError: (1) When model returned by `model_fn` does not have optimizer
         attribute or when required parameters are set to none. (2) eval args are
         not specified correctly. (3) metric_fn must be a callable if specified.
+        (4) sub_model_export_name is specified, but `sub_model` returned
+        by `model_fn` is None.
   """

   if _sentinel is not None:
@@ -191,6 +199,10 @@ def run_customized_training_loop(
     if not hasattr(model, 'optimizer'):
       raise ValueError('User should set optimizer attribute to model '
                        'inside `model_fn`.')
+    if sub_model_export_name and sub_model is None:
+      raise ValueError('sub_model_export_name is specified as %s, but '
+                       'sub_model is None.' % sub_model_export_name)
+
     optimizer = model.optimizer
     use_float16 = isinstance(
         optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
@@ -326,6 +338,9 @@ def run_customized_training_loop(

     # Training loop starts here.
     checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+    sub_model_checkpoint = tf.train.Checkpoint(
+        model=sub_model) if sub_model_export_name else None
+
     latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
     if latest_checkpoint_file:
       logging.info(
@@ -382,7 +397,10 @@ def run_customized_training_loop(
       if current_step < total_training_steps:
         _save_checkpoint(checkpoint, model_dir,
                          checkpoint_name.format(step=current_step))
-
+        if sub_model_export_name:
+          _save_checkpoint(
+              sub_model_checkpoint, model_dir,
+              '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
       if eval_input_fn:
         logging.info('Running evaluation after step: %s.', current_step)
         _run_evaluation(current_step,
@@ -393,6 +411,9 @@ def run_customized_training_loop(

     _save_checkpoint(checkpoint, model_dir,
                      checkpoint_name.format(step=current_step))
+    if sub_model_export_name:
+      _save_checkpoint(sub_model_checkpoint, model_dir,
+                       '%s.ckpt' % sub_model_export_name)

     if eval_input_fn:
       logging.info('Running final evaluation after training is complete.')
diff --git a/official/nlp/bert/model_saving_utils.py b/official/nlp/bert/model_saving_utils.py
index ef101052ff93c42a5145a69b6d709800483d88a0..e8e8fa841b2a1e85af7a79cd71eb0b77ab5b6b76 100644
--- a/official/nlp/bert/model_saving_utils.py
+++ b/official/nlp/bert/model_saving_utils.py
@@ -77,37 +77,6 @@ def export_bert_model(model_export_path: typing.Text,
   model.save(model_export_path, include_optimizer=False, save_format='tf')


-def export_pretraining_checkpoint(
-    checkpoint_dir: typing.Text,
-    model: tf.keras.Model,
-    checkpoint_name: typing.Optional[
-        typing.Text] = 'pretrained/bert_model.ckpt'):
-  """Exports BERT model for as a checkpoint without optimizer.
-
-  Arguments:
-    checkpoint_dir: Path to where training model checkpoints are stored.
-    model: Keras model object to export.
-    checkpoint_name: File name or suffix path to export pretrained checkpoint.
-
-  Raises:
-    ValueError when either checkpoint_dir or model is not specified.
-  """
-  if not checkpoint_dir:
-    raise ValueError('checkpoint_dir must be specified.')
-  if not isinstance(model, tf.keras.Model):
-    raise ValueError('model must be a tf.keras.Model object.')
-
-  checkpoint = tf.train.Checkpoint(model=model)
-  latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
-  assert latest_checkpoint_file
-  logging.info('Checkpoint file %s found and restoring from '
-               'checkpoint', latest_checkpoint_file)
-  status = checkpoint.restore(latest_checkpoint_file)
-  status.assert_existing_objects_matched().expect_partial()
-  saved_path = checkpoint.save(os.path.join(checkpoint_dir, checkpoint_name))
-  logging.info('Exporting the model as a new TF checkpoint: %s', saved_path)
-
-
 class BertModelCheckpoint(tf.keras.callbacks.Callback):
   """Keras callback that saves model at the end of every epoch."""

diff --git a/official/nlp/bert/run_pretraining.py b/official/nlp/bert/run_pretraining.py
index b54bf34ef746398c61a37c020a0b767a241f4555..c36cd0f37cb4caa000250f222cb430e04f58f598 100644
--- a/official/nlp/bert/run_pretraining.py
+++ b/official/nlp/bert/run_pretraining.py
@@ -126,16 +126,9 @@ def run_customized_training(strategy,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
-      epochs=epochs)
+      epochs=epochs,
+      sub_model_export_name='pretrained/bert_model')

-  # Creates the BERT core model outside distribution strategy scope.
-  _, core_model = bert_models.pretrain_model(bert_config, max_seq_length,
-                                             max_predictions_per_seq)
-
-  # Restores the core model from model checkpoints and get a new checkpoint only
-  # contains the core model.
-  model_saving_utils.export_pretraining_checkpoint(
-      checkpoint_dir=model_dir, model=core_model)
   return trained_model


diff --git a/official/nlp/bert_models.py b/official/nlp/bert_models.py
index 2e948c07fcc01ddb11704109f73f6f7793446023..a0aad35a625932743894740cd93d54239dfe878b 100644
--- a/official/nlp/bert_models.py
+++ b/official/nlp/bert_models.py
@@ -18,139 +18,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import copy
 import tensorflow as tf
 import tensorflow_hub as hub

 from official.modeling import tf_utils
-from official.nlp import bert_modeling as modeling
+from official.nlp.modeling import losses
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier
+from official.nlp.modeling.networks import bert_pretrainer
 from official.nlp.modeling.networks import bert_span_labeler


-def gather_indexes(sequence_tensor, positions):
-  """Gathers the vectors at the specific positions.
-
-  Args:
-    sequence_tensor: Sequence output of `BertModel` layer of shape
-      (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
-      hidden units of `BertModel` layer.
-    positions: Positions ids of tokens in sequence to mask for pretraining of
-      with dimension (batch_size, max_predictions_per_seq) where
-      `max_predictions_per_seq` is maximum number of tokens to mask out and
-      predict per each sequence.
-
-  Returns:
-    Masked out sequence tensor of shape (batch_size * max_predictions_per_seq,
-    num_hidden).
-  """
-  sequence_shape = tf_utils.get_shape_list(
-      sequence_tensor, name='sequence_output_tensor')
-  batch_size = sequence_shape[0]
-  seq_length = sequence_shape[1]
-  width = sequence_shape[2]
-
-  flat_offsets = tf.keras.backend.reshape(
-      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
-  flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
-  flat_sequence_tensor = tf.keras.backend.reshape(
-      sequence_tensor, [batch_size * seq_length, width])
-  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
-
-  return output_tensor
-
-
-class BertPretrainLayer(tf.keras.layers.Layer):
-  """Wrapper layer for pre-training a BERT model.
-
-  This layer wraps an existing `bert_layer` which is a Keras Layer.
-  It outputs `sequence_output` from TransformerBlock sub-layer and
-  `sentence_output` which are suitable for feeding into a BertPretrainLoss
-  layer. This layer can be used along with an unsupervised input to
-  pre-train the embeddings for `bert_layer`.
- """ - - def __init__(self, - config, - bert_layer, - initializer=None, - float_type=tf.float32, - **kwargs): - super(BertPretrainLayer, self).__init__(**kwargs) - self.config = copy.deepcopy(config) - self.float_type = float_type - - self.embedding_table = bert_layer.embedding_lookup.embeddings - self.num_next_sentence_label = 2 - if initializer: - self.initializer = initializer - else: - self.initializer = tf.keras.initializers.TruncatedNormal( - stddev=self.config.initializer_range) - - def build(self, unused_input_shapes): - """Implements build() for the layer.""" - self.output_bias = self.add_weight( - shape=[self.config.vocab_size], - name='predictions/output_bias', - initializer=tf.keras.initializers.Zeros()) - self.lm_dense = tf.keras.layers.Dense( - self.config.hidden_size, - activation=tf_utils.get_activation(self.config.hidden_act), - kernel_initializer=self.initializer, - name='predictions/transform/dense') - self.lm_layer_norm = tf.keras.layers.LayerNormalization( - axis=-1, epsilon=1e-12, name='predictions/transform/LayerNorm') - - # Next sentence binary classification dense layer including bias to match - # TF1.x BERT variable shapes. - with tf.name_scope('seq_relationship'): - self.next_seq_weights = self.add_weight( - shape=[self.num_next_sentence_label, self.config.hidden_size], - name='output_weights', - initializer=self.initializer) - self.next_seq_bias = self.add_weight( - shape=[self.num_next_sentence_label], - name='output_bias', - initializer=tf.keras.initializers.Zeros()) - super(BertPretrainLayer, self).build(unused_input_shapes) - - def __call__(self, - pooled_output, - sequence_output=None, - masked_lm_positions=None, - **kwargs): - inputs = tf_utils.pack_inputs( - [pooled_output, sequence_output, masked_lm_positions]) - return super(BertPretrainLayer, self).__call__(inputs, **kwargs) - - def call(self, inputs): - """Implements call() for the layer.""" - unpacked_inputs = tf_utils.unpack_inputs(inputs) - pooled_output = unpacked_inputs[0] - sequence_output = unpacked_inputs[1] - masked_lm_positions = unpacked_inputs[2] - - mask_lm_input_tensor = gather_indexes(sequence_output, masked_lm_positions) - lm_output = self.lm_dense(mask_lm_input_tensor) - lm_output = self.lm_layer_norm(lm_output) - lm_output = tf.matmul(lm_output, self.embedding_table, transpose_b=True) - lm_output = tf.nn.bias_add(lm_output, self.output_bias) - lm_output = tf.nn.log_softmax(lm_output, axis=-1) - - logits = tf.matmul(pooled_output, self.next_seq_weights, transpose_b=True) - logits = tf.nn.bias_add(logits, self.next_seq_bias) - sentence_output = tf.nn.log_softmax(logits, axis=-1) - return (lm_output, sentence_output) - - class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): """Returns layer that computes custom loss and metrics for pretraining.""" - def __init__(self, bert_config, **kwargs): + def __init__(self, vocab_size, **kwargs): super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs) - self.config = copy.deepcopy(bert_config) + self._vocab_size = vocab_size + self.config = { + 'vocab_size': vocab_size, + } def __call__(self, lm_output, @@ -167,8 +54,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): self).__call__(inputs, **kwargs) def _add_metrics(self, lm_output, lm_labels, lm_label_weights, - lm_per_example_loss, sentence_output, sentence_labels, - sentence_per_example_loss): + lm_example_loss, sentence_output, sentence_labels, + next_sentence_loss): """Adds metrics.""" masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy( lm_labels, 
@@ -178,8 +65,6 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     self.add_metric(
         masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

-    lm_example_loss = tf.reshape(lm_per_example_loss, [-1])
-    lm_example_loss = tf.reduce_mean(lm_example_loss * lm_label_weights)
     self.add_metric(lm_example_loss, name='lm_example_loss',
                     aggregation='mean')
     next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
@@ -189,9 +74,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
         name='next_sentence_accuracy',
         aggregation='mean')

-    next_sentence_mean_loss = tf.reduce_mean(sentence_per_example_loss)
     self.add_metric(
-        next_sentence_mean_loss, name='next_sentence_loss', aggregation='mean')
+        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

   def call(self, inputs):
     """Implements call() for the layer."""
@@ -199,31 +83,21 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
     lm_output = unpacked_inputs[0]
     sentence_output = unpacked_inputs[1]
     lm_label_ids = unpacked_inputs[2]
-    lm_label_ids = tf.keras.backend.reshape(lm_label_ids, [-1])
-    lm_label_ids_one_hot = tf.keras.backend.one_hot(lm_label_ids,
-                                                    self.config.vocab_size)
     lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3], tf.float32)
-    lm_label_weights = tf.keras.backend.reshape(lm_label_weights, [-1])
-    lm_per_example_loss = -tf.keras.backend.sum(
-        lm_output * lm_label_ids_one_hot, axis=[-1])
-    numerator = tf.keras.backend.sum(lm_label_weights * lm_per_example_loss)
-    denominator = tf.keras.backend.sum(lm_label_weights) + 1e-5
-    mask_label_loss = numerator / denominator
-
     sentence_labels = unpacked_inputs[4]
-    sentence_labels = tf.keras.backend.reshape(sentence_labels, [-1])
-    sentence_label_one_hot = tf.keras.backend.one_hot(sentence_labels, 2)
-    per_example_loss_sentence = -tf.keras.backend.sum(
-        sentence_label_one_hot * sentence_output, axis=-1)
-    sentence_loss = tf.keras.backend.mean(per_example_loss_sentence)
+
+    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
+    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
+        labels=sentence_labels, predictions=sentence_output)
     loss = mask_label_loss + sentence_loss

+    batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0], [1])
     # TODO(hongkuny): Avoids the hack and switches add_loss.
-    final_loss = tf.fill(
-        tf.keras.backend.shape(per_example_loss_sentence), loss)
+    final_loss = tf.fill(batch_shape, loss)

     self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
-                      lm_per_example_loss, sentence_output, sentence_labels,
-                      per_example_loss_sentence)
+                      mask_label_loss, sentence_output, sentence_labels,
+                      sentence_loss)
     return final_loss

@@ -268,13 +142,12 @@ def pretrain_model(bert_config,
     seq_length: Maximum sequence length of the training data.
     max_predictions_per_seq: Maximum number of tokens in sequence to mask out
      and use for pretraining.
-    initializer: Initializer for weights in BertPretrainLayer.
+    initializer: Initializer for weights in BertPretrainer.

   Returns:
       Pretraining model as well as core BERT submodel from which to save
       weights after pretraining.
""" - input_word_ids = tf.keras.layers.Input( shape=(seq_length,), name='input_word_ids', dtype=tf.int32) input_mask = tf.keras.layers.Input( @@ -285,38 +158,34 @@ def pretrain_model(bert_config, shape=(max_predictions_per_seq,), name='masked_lm_positions', dtype=tf.int32) + masked_lm_ids = tf.keras.layers.Input( + shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32) masked_lm_weights = tf.keras.layers.Input( shape=(max_predictions_per_seq,), name='masked_lm_weights', dtype=tf.int32) next_sentence_labels = tf.keras.layers.Input( shape=(1,), name='next_sentence_labels', dtype=tf.int32) - masked_lm_ids = tf.keras.layers.Input( - shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32) - bert_submodel_name = 'bert_model' - bert_submodel = modeling.get_bert_model( - input_word_ids, - input_mask, - input_type_ids, - name=bert_submodel_name, - config=bert_config) - pooled_output = bert_submodel.outputs[0] - sequence_output = bert_submodel.outputs[1] - - pretrain_layer = BertPretrainLayer( - bert_config, - bert_submodel.get_layer(bert_submodel_name), + transformer_encoder = _get_transformer_encoder(bert_config, seq_length) + if initializer is None: + initializer = tf.keras.initializers.TruncatedNormal( + stddev=bert_config.initializer_range) + pretrainer_model = bert_pretrainer.BertPretrainer( + network=transformer_encoder, + num_classes=2, # The next sentence prediction label has two classes. + num_token_predictions=max_predictions_per_seq, initializer=initializer, - name='cls') - lm_output, sentence_output = pretrain_layer(pooled_output, sequence_output, - masked_lm_positions) + output='predictions') + + lm_output, sentence_output = pretrainer_model( + [input_word_ids, input_mask, input_type_ids, masked_lm_positions]) - pretrain_loss_layer = BertPretrainLossAndMetricLayer(bert_config) + pretrain_loss_layer = BertPretrainLossAndMetricLayer( + vocab_size=bert_config.vocab_size) output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids, masked_lm_weights, next_sentence_labels) - - return tf.keras.Model( + keras_model = tf.keras.Model( inputs={ 'input_word_ids': input_word_ids, 'input_mask': input_mask, @@ -326,7 +195,8 @@ def pretrain_model(bert_config, 'masked_lm_weights': masked_lm_weights, 'next_sentence_labels': next_sentence_labels, }, - outputs=output_loss), bert_submodel + outputs=output_loss) + return keras_model, transformer_encoder class BertSquadLogitsLayer(tf.keras.layers.Layer):