提交 ad1a37c9 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Internal changes.

PiperOrigin-RevId: 274201399
上级 854fdb7f
......@@ -50,6 +50,8 @@ flags.DEFINE_integer(
"The number of attention heads, used to reshape variables. If it is -1, "
"we do not reshape variables."
)
# Selects the rename table: when this flag is set, _bert_name_replacement uses
# BERT_V2_NAME_REPLACEMENTS; otherwise it uses BERT_NAME_REPLACEMENTS.
flags.DEFINE_boolean("use_v2_names", False,
"Whether to use BERT_V2_NAME_REPLACEMENTS.")
# Mapping between old <=> new names. Each source pattern found in the original
# variable name is replaced by the corresponding destination pattern.
......@@ -70,9 +72,30 @@ BERT_NAME_REPLACEMENTS = [
("pooler/dense", "pooler_transform"),
]
# (source, target) substring pairs used when the `use_v2_names` flag is set.
# The visible part of _bert_name_replacement walks these pairs in order and
# rewrites the name as it goes, so a more specific pattern (for example
# "attention/output/dense") has to come before the shorter pattern it contains
# ("output/dense"). Do not reorder entries.
BERT_V2_NAME_REPLACEMENTS = [
    # Remove "bert/" and rename "encoder" to "transformer".
    ("bert/", ""),
    ("encoder", "transformer"),
    # Embedding variables.
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
    ("embeddings/position_embeddings", "position_embedding/embeddings"),
    ("embeddings/LayerNorm", "embeddings/layer_norm"),
    # Attention variables (must precede the generic "output/..." entries).
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    # Intermediate and output variables.
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    # Pooler.
    ("pooler/dense", "pooler_transform"),
]
def _bert_name_replacement(var_name):
for src_pattern, tgt_pattern in BERT_NAME_REPLACEMENTS:
"""Gets the variable name replacement."""
if FLAGS.use_v2_names:
name_replacements = BERT_V2_NAME_REPLACEMENTS
else:
name_replacements = BERT_NAME_REPLACEMENTS
for src_pattern, tgt_pattern in name_replacements:
if src_pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(src_pattern, tgt_pattern)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册