From 49ba237d35d2a049be7bede596f4b29fd85cfe28 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 18 Oct 2019 17:40:38 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 275578662 --- .../nlp/bert/tf1_checkpoint_converter_lib.py | 195 ++++++++++++++++++ .../bert/tf1_to_keras_checkpoint_converter.py | 182 ++-------------- 2 files changed, 214 insertions(+), 163 deletions(-) create mode 100644 official/nlp/bert/tf1_checkpoint_converter_lib.py diff --git a/official/nlp/bert/tf1_checkpoint_converter_lib.py b/official/nlp/bert/tf1_checkpoint_converter_lib.py new file mode 100644 index 000000000..4f555d915 --- /dev/null +++ b/official/nlp/bert/tf1_checkpoint_converter_lib.py @@ -0,0 +1,195 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow.compat.v1 as tf # TF 1.x + +# Mapping between old <=> new names. The source pattern in original variable +# name will be replaced by destination pattern. +BERT_NAME_REPLACEMENTS = ( + ("bert", "bert_model"), + ("embeddings/word_embeddings", "word_embeddings/embeddings"), + ("embeddings/token_type_embeddings", + "embedding_postprocessor/type_embeddings"), + ("embeddings/position_embeddings", + "embedding_postprocessor/position_embeddings"), + ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"), + ("attention/self", "self_attention"), + ("attention/output/dense", "self_attention_output"), + ("attention/output/LayerNorm", "self_attention_layer_norm"), + ("intermediate/dense", "intermediate"), + ("output/dense", "output"), + ("output/LayerNorm", "output_layer_norm"), + ("pooler/dense", "pooler_transform"), +) + +BERT_V2_NAME_REPLACEMENTS = ( + ("bert/", ""), + ("encoder", "transformer"), + ("embeddings/word_embeddings", "word_embeddings/embeddings"), + ("embeddings/token_type_embeddings", "type_embeddings/embeddings"), + ("embeddings/position_embeddings", "position_embedding/embeddings"), + ("embeddings/LayerNorm", "embeddings/layer_norm"), + ("attention/self", "self_attention"), + ("attention/output/dense", "self_attention_output"), + ("attention/output/LayerNorm", "self_attention_layer_norm"), + ("intermediate/dense", "intermediate"), + ("output/dense", "output"), + ("output/LayerNorm", "output_layer_norm"), + ("pooler/dense", "pooler_transform"), + ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"), + ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"), + ("cls/seq_relationship/output_weights", + "predictions/transform/logits/kernel"), +) + +BERT_PERMUTATIONS = () + +BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),) + + +def _bert_name_replacement(var_name, name_replacements): + """Gets the variable name replacement.""" + for src_pattern, tgt_pattern in name_replacements: + if src_pattern in var_name: + old_var_name = var_name + var_name = var_name.replace(src_pattern, tgt_pattern) + tf.logging.info("Converted: %s --> %s", old_var_name, var_name) + return var_name + + +def _has_exclude_patterns(name, exclude_patterns): + """Checks if a string contains substrings that match patterns to exclude.""" + for p in exclude_patterns: + if p in name: + return True + return False + + +def _get_permutation(name, permutations): + """Checks whether a variable requires transposition by pattern matching.""" + for src_pattern, permutation in permutations: + if src_pattern in name: + tf.logging.info("Permuted: %s --> %s", name, permutation) + return permutation + + return None + + +def _get_new_shape(name, shape, num_heads): + """Checks whether a variable requires reshape by pattern matching.""" + if "attention/output/dense/kernel" in name: + return tuple([num_heads, shape[0] // num_heads, shape[1]]) + if "attention/output/dense/bias" in name: + return shape + + patterns = [ + "attention/self/query", "attention/self/value", "attention/self/key" + ] + for pattern in patterns: + if pattern in name: + if "kernel" in name: + return tuple([shape[0], num_heads, shape[1] // num_heads]) + if "bias" in name: + return tuple([num_heads, shape[0] // num_heads]) + return None + + +def create_v2_checkpoint(model, src_checkpoint, output_path): + """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint.""" + # Uses streaming-restore in eager model to read V1 name-based checkpoints. + model.load_weights(src_checkpoint).assert_existing_objects_matched() + checkpoint = tf.train.Checkpoint(model=model) + checkpoint.save(output_path) + + +def convert(checkpoint_from_path, + checkpoint_to_path, + num_heads, + name_replacements, + permutations, + exclude_patterns=None): + """Migrates the names of variables within a checkpoint. + + Args: + checkpoint_from_path: Path to source checkpoint to be read in. + checkpoint_to_path: Path to checkpoint to be written out. + num_heads: The number of heads of the model. + name_replacements: A list of tuples of the form (match_str, replace_str) + describing variable names to adjust. + permutations: A list of tuples of the form (match_str, permutation) + describing permutations to apply to given variables. Note that match_str + should match the original variable name, not the replaced one. + exclude_patterns: A list of string patterns to exclude variables from + checkpoint conversion. + + Returns: + A dictionary that maps the new variable names to the Variable objects. + A dictionary that maps the old variable names to the new variable names. + """ + with tf.Graph().as_default(): + tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path) + reader = tf.train.NewCheckpointReader(checkpoint_from_path) + name_shape_map = reader.get_variable_to_shape_map() + new_variable_map = {} + conversion_map = {} + for var_name in name_shape_map: + if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns): + continue + # Get the original tensor data. + tensor = reader.get_tensor(var_name) + + # Look up the new variable name, if any. + new_var_name = _bert_name_replacement(var_name, name_replacements) + + # See if we need to reshape the underlying tensor. + new_shape = None + if num_heads > 0: + new_shape = _get_new_shape(var_name, tensor.shape, num_heads) + if new_shape: + tf.logging.info("Veriable %s has a shape change from %s to %s", + + var_name, tensor.shape, new_shape) + tensor = np.reshape(tensor, new_shape) + + # See if we need to permute the underlying tensor. + permutation = _get_permutation(var_name, permutations) + if permutation: + tensor = np.transpose(tensor, permutation) + + # Create a new variable with the possibly-reshaped or transposed tensor. + var = tf.Variable(tensor, name=var_name) + + # Save the variable into the new variable map. + new_variable_map[new_var_name] = var + + # Keep a list of converter variables for sanity checking. + if new_var_name != var_name: + conversion_map[var_name] = new_var_name + + saver = tf.train.Saver(new_variable_map) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path) + saver.save(sess, checkpoint_to_path) + + tf.logging.info("Summary:") + tf.logging.info(" Converted %d variable name(s).", len(new_variable_map)) + tf.logging.info(" Converted: %s", str(conversion_map)) diff --git a/official/nlp/bert/tf1_to_keras_checkpoint_converter.py b/official/nlp/bert/tf1_to_keras_checkpoint_converter.py index e0a2fa87e..398f14c0f 100644 --- a/official/nlp/bert/tf1_to_keras_checkpoint_converter.py +++ b/official/nlp/bert/tf1_to_keras_checkpoint_converter.py @@ -29,8 +29,10 @@ from __future__ import division from __future__ import print_function from absl import app -import numpy as np + import tensorflow as tf # TF 1.x +from third_party.tensorflow_models.official.nlp.bert import tf1_checkpoint_converter_lib + flags = tf.flags @@ -50,174 +52,28 @@ flags.DEFINE_integer( "The number of attention heads, used to reshape variables. If it is -1, " "we do not reshape variables." ) -flags.DEFINE_boolean("use_v2_names", False, - "Whether to use BERT_V2_NAME_REPLACEMENTS.") - -# Mapping between old <=> new names. The source pattern in original variable -# name will be replaced by destination pattern. -BERT_NAME_REPLACEMENTS = [ - ("bert", "bert_model"), - ("embeddings/word_embeddings", "word_embeddings/embeddings"), - ("embeddings/token_type_embeddings", - "embedding_postprocessor/type_embeddings"), - ("embeddings/position_embeddings", - "embedding_postprocessor/position_embeddings"), - ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"), - ("attention/self", "self_attention"), - ("attention/output/dense", "self_attention_output"), - ("attention/output/LayerNorm", "self_attention_layer_norm"), - ("intermediate/dense", "intermediate"), - ("output/dense", "output"), - ("output/LayerNorm", "output_layer_norm"), - ("pooler/dense", "pooler_transform"), -] - -BERT_V2_NAME_REPLACEMENTS = [ - ("bert/", ""), - ("encoder", "transformer"), - ("embeddings/word_embeddings", "word_embeddings/embeddings"), - ("embeddings/token_type_embeddings", "type_embeddings/embeddings"), - ("embeddings/position_embeddings", "position_embedding/embeddings"), - ("embeddings/LayerNorm", "embeddings/layer_norm"), - ("attention/self", "self_attention"), - ("attention/output/dense", "self_attention_output"), - ("attention/output/LayerNorm", "self_attention_layer_norm"), - ("intermediate/dense", "intermediate"), - ("output/dense", "output"), - ("output/LayerNorm", "output_layer_norm"), - ("pooler/dense", "pooler_transform"), - ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"), - ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"), - ("cls/seq_relationship/output_weights", - "predictions/transform/logits/kernel"), -] - - -def _bert_name_replacement(var_name): - """Gets the variable name replacement.""" - if FLAGS.use_v2_names: - name_replacements = BERT_V2_NAME_REPLACEMENTS - else: - name_replacements = BERT_NAME_REPLACEMENTS - for src_pattern, tgt_pattern in name_replacements: - if src_pattern in var_name: - old_var_name = var_name - var_name = var_name.replace(src_pattern, tgt_pattern) - tf.logging.info("Converted: %s --> %s", old_var_name, var_name) - return var_name - - -def _has_exclude_patterns(name, exclude_patterns): - """Checks if a string contains substrings that match patterns to exclude.""" - for p in exclude_patterns: - if p in name: - return True - return False - - -def _get_permutation(name): - """Checks whether a variable requires transposition by pattern matching.""" - if not FLAGS.use_v2_names: - return None - - if "cls/seq_relationship/output_weights" in name: - return (1, 0) - - return None - - -def _get_new_shape(name, shape, num_heads): - """Checks whether a variable requires reshape by pattern matching.""" - if "attention/output/dense/kernel" in name: - return tuple([num_heads, shape[0] // num_heads, shape[1]]) - if "attention/output/dense/bias" in name: - return shape - - patterns = [ - "attention/self/query", "attention/self/value", "attention/self/key" - ] - for pattern in patterns: - if pattern in name: - if "kernel" in name: - return tuple([shape[0], num_heads, shape[1] // num_heads]) - if "bias" in name: - return tuple([num_heads, shape[0] // num_heads]) - return None - - -def convert_names(checkpoint_from_path, - checkpoint_to_path, - exclude_patterns=None): - """Migrates the names of variables within a checkpoint. - - Args: - checkpoint_from_path: Path to source checkpoint to be read in. - checkpoint_to_path: Path to checkpoint to be written out. - exclude_patterns: A list of string patterns to exclude variables from - checkpoint conversion. - - Returns: - A dictionary that maps the new variable names to the Variable objects. - A dictionary that maps the old variable names to the new variable names. - """ - with tf.Graph().as_default(): - tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path) - reader = tf.train.NewCheckpointReader(checkpoint_from_path) - name_shape_map = reader.get_variable_to_shape_map() - new_variable_map = {} - conversion_map = {} - for var_name in name_shape_map: - if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns): - continue - # Get the original tensor data. - tensor = reader.get_tensor(var_name) - - # Look up the new variable name, if any. - new_var_name = _bert_name_replacement(var_name) - - # See if we need to reshape the underlying tensor. - new_shape = None - if FLAGS.num_heads > 0: - new_shape = _get_new_shape(var_name, tensor.shape, FLAGS.num_heads) - if new_shape: - tf.logging.info("Veriable %s has a shape change from %s to %s", - - var_name, tensor.shape, new_shape) - tensor = np.reshape(tensor, new_shape) - - # See if we need to permute the underlying tensor. - permutation = _get_permutation(var_name) - if permutation: - tensor = np.transpose(tensor, permutation) - - # Create a new variable with the possibly-reshaped or transposed tensor. - var = tf.Variable(tensor, name=var_name) - - # Save the variable into the new variable map. - new_variable_map[new_var_name] = var - - # Keep a list of converter variables for sanity checking. - if new_var_name != var_name: - conversion_map[var_name] = new_var_name - - saver = tf.train.Saver(new_variable_map) - - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path) - saver.save(sess, checkpoint_to_path) - - tf.logging.info("Summary:") - tf.logging.info(" Converted %d variable name(s).", len(new_variable_map)) - tf.logging.info(" Converted: %s", str(conversion_map)) +flags.DEFINE_boolean( + "create_v2_checkpoint", False, + "Whether to create a checkpoint compatible with KerasBERT V2 modeling code." +) def main(_): exclude_patterns = None if FLAGS.exclude_patterns: exclude_patterns = FLAGS.exclude_patterns.split(",") - convert_names(FLAGS.checkpoint_from_path, FLAGS.checkpoint_to_path, - exclude_patterns) + + if FLAGS.create_v2_checkpoint: + name_replacements = tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS + permutations = tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS + else: + name_replacements = tf1_checkpoint_converter_lib.BERT_NAME_REPLACEMENTS + permutations = tf1_checkpoint_converter_lib.BERT_PERMUTATIONS + + tf1_checkpoint_converter_lib.convert(FLAGS.checkpoint_from_path, + FLAGS.checkpoint_to_path, + FLAGS.num_heads, name_replacements, + permutations, exclude_patterns) if __name__ == "__main__": -- GitLab