diff --git a/official/nlp/bert/tf1_to_keras_checkpoint_converter.py b/official/nlp/bert/tf1_to_keras_checkpoint_converter.py index 962ad96a1d891b59354092ebf6df65a67ea0cdd1..3f0b13467beb02f11ca60a0d1be2fa1525d7f322 100644 --- a/official/nlp/bert/tf1_to_keras_checkpoint_converter.py +++ b/official/nlp/bert/tf1_to_keras_checkpoint_converter.py @@ -50,6 +50,8 @@ flags.DEFINE_integer( "The number of attention heads, used to reshape variables. If it is -1, " "we do not reshape variables." ) +flags.DEFINE_boolean("use_v2_names", False, + "Whether to use BERT_V2_NAME_REPLACEMENTS.") # Mapping between old <=> new names. The source pattern in original variable # name will be replaced by destination pattern. @@ -70,9 +72,30 @@ BERT_NAME_REPLACEMENTS = [ ("pooler/dense", "pooler_transform"), ] +BERT_V2_NAME_REPLACEMENTS = [ + ("bert/", ""), + ("encoder", "transformer"), + ("embeddings/word_embeddings", "word_embeddings/embeddings"), + ("embeddings/token_type_embeddings", "type_embeddings/embeddings"), + ("embeddings/position_embeddings", "position_embedding/embeddings"), + ("embeddings/LayerNorm", "embeddings/layer_norm"), + ("attention/self", "self_attention"), + ("attention/output/dense", "self_attention_output"), + ("attention/output/LayerNorm", "self_attention_layer_norm"), + ("intermediate/dense", "intermediate"), + ("output/dense", "output"), + ("output/LayerNorm", "output_layer_norm"), + ("pooler/dense", "pooler_transform"), +] + def _bert_name_replacement(var_name): - for src_pattern, tgt_pattern in BERT_NAME_REPLACEMENTS: + """Applies the substring renames selected by FLAGS.use_v2_names to var_name.""" + if FLAGS.use_v2_names: + name_replacements = BERT_V2_NAME_REPLACEMENTS + else: + name_replacements = BERT_NAME_REPLACEMENTS + for src_pattern, tgt_pattern in name_replacements: if src_pattern in var_name: old_var_name = var_name var_name = var_name.replace(src_pattern, tgt_pattern)