From ad1a37c9b6935bd2c48f8e74cf65a40b8037a2c6 Mon Sep 17 00:00:00 2001
From: Hongkun Yu
Date: Fri, 11 Oct 2019 10:36:21 -0700
Subject: [PATCH] Internal changes.

PiperOrigin-RevId: 274201399
---
 .../bert/tf1_to_keras_checkpoint_converter.py | 25 ++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/official/nlp/bert/tf1_to_keras_checkpoint_converter.py b/official/nlp/bert/tf1_to_keras_checkpoint_converter.py
index 962ad96a1..3f0b13467 100644
--- a/official/nlp/bert/tf1_to_keras_checkpoint_converter.py
+++ b/official/nlp/bert/tf1_to_keras_checkpoint_converter.py
@@ -50,6 +50,8 @@ flags.DEFINE_integer(
     "The number of attention heads, used to reshape variables. If it is -1, "
     "we do not reshape variables."
 )
+flags.DEFINE_boolean("use_v2_names", False,
+                     "Whether to use BERT_V2_NAME_REPLACEMENTS.")
 
 # Mapping between old <=> new names. The source pattern in original variable
 # name will be replaced by destination pattern.
@@ -70,9 +72,30 @@ BERT_NAME_REPLACEMENTS = [
     ("pooler/dense", "pooler_transform"),
 ]
 
+BERT_V2_NAME_REPLACEMENTS = [
+    ("bert/", ""),
+    ("encoder", "transformer"),
+    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
+    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
+    ("embeddings/position_embeddings", "position_embedding/embeddings"),
+    ("embeddings/LayerNorm", "embeddings/layer_norm"),
+    ("attention/self", "self_attention"),
+    ("attention/output/dense", "self_attention_output"),
+    ("attention/output/LayerNorm", "self_attention_layer_norm"),
+    ("intermediate/dense", "intermediate"),
+    ("output/dense", "output"),
+    ("output/LayerNorm", "output_layer_norm"),
+    ("pooler/dense", "pooler_transform"),
+]
+
 
 def _bert_name_replacement(var_name):
-  for src_pattern, tgt_pattern in BERT_NAME_REPLACEMENTS:
+  """Gets the variable name replacement."""
+  if FLAGS.use_v2_names:
+    name_replacements = BERT_V2_NAME_REPLACEMENTS
+  else:
+    name_replacements = BERT_NAME_REPLACEMENTS
+  for src_pattern, tgt_pattern in name_replacements:
     if src_pattern in var_name:
       old_var_name = var_name
       var_name = var_name.replace(src_pattern, tgt_pattern)
-- 
GitLab