Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
49ba237d
M
Models
项目概览
曾经的那一瞬间
/
Models
大约 1 年 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
49ba237d
编写于
10月 18, 2019
作者:
A
A. Unique TensorFlower
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Internal change
PiperOrigin-RevId: 275578662
上级
e6750c5d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
214 additions
and
163 deletions
+214
-163
official/nlp/bert/tf1_checkpoint_converter_lib.py
official/nlp/bert/tf1_checkpoint_converter_lib.py
+195
-0
official/nlp/bert/tf1_to_keras_checkpoint_converter.py
official/nlp/bert/tf1_to_keras_checkpoint_converter.py
+19
-163
未找到文件。
official/nlp/bert/tf1_checkpoint_converter_lib.py
0 → 100644
浏览文件 @
49ba237d
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
# TF 1.x
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
# Each entry rewrites a TF1 Estimator-era BERT variable-name fragment to the
# corresponding (v1) Keras BERT layer name.
BERT_NAME_REPLACEMENTS = (
    ("bert", "bert_model"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings",
     "embedding_postprocessor/type_embeddings"),
    ("embeddings/position_embeddings",
     "embedding_postprocessor/position_embeddings"),
    ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
)
# Name-fragment rewrites for the V2 Keras BERT modeling code. Unlike the v1
# table above, the leading "bert/" scope is stripped and "encoder" becomes
# "transformer"; the cls/* entries also remap the pretraining-head variables.
BERT_V2_NAME_REPLACEMENTS = (
    ("bert/", ""),
    ("encoder", "transformer"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
    ("embeddings/position_embeddings", "position_embedding/embeddings"),
    ("embeddings/LayerNorm", "embeddings/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
    ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
    ("cls/seq_relationship/output_weights",
     "predictions/transform/logits/kernel"),
)
# Tensor transpositions to apply during conversion, keyed by a substring of
# the ORIGINAL (pre-replacement) variable name. The v1 naming scheme needs
# none; for v2, the next-sentence-prediction output weights are transposed
# (axes (1, 0)) to match the Keras dense-kernel layout.
BERT_PERMUTATIONS = ()

BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)
def _bert_name_replacement(var_name, name_replacements):
  """Returns `var_name` with every matching source pattern replaced.

  Args:
    var_name: Original checkpoint variable name.
    name_replacements: Iterable of (src_pattern, tgt_pattern) pairs; each
      pair whose source occurs in the name is applied in order.

  Returns:
    The (possibly repeatedly) rewritten variable name.
  """
  for src, tgt in name_replacements:
    if src not in var_name:
      continue
    previous = var_name
    var_name = var_name.replace(src, tgt)
    tf.logging.info("Converted: %s --> %s", previous, var_name)
  return var_name
def
_has_exclude_patterns
(
name
,
exclude_patterns
):
"""Checks if a string contains substrings that match patterns to exclude."""
for
p
in
exclude_patterns
:
if
p
in
name
:
return
True
return
False
def _get_permutation(name, permutations):
  """Checks whether a variable requires transposition by pattern matching.

  Args:
    name: Original variable name to inspect.
    permutations: Iterable of (src_pattern, permutation) pairs.

  Returns:
    The axis permutation for the first matching pattern, or None.
  """
  matched = next(
      (perm for pattern, perm in permutations if pattern in name), None)
  if matched is not None:
    tf.logging.info("Permuted: %s --> %s", name, matched)
  return matched
def
_get_new_shape
(
name
,
shape
,
num_heads
):
"""Checks whether a variable requires reshape by pattern matching."""
if
"attention/output/dense/kernel"
in
name
:
return
tuple
([
num_heads
,
shape
[
0
]
//
num_heads
,
shape
[
1
]])
if
"attention/output/dense/bias"
in
name
:
return
shape
patterns
=
[
"attention/self/query"
,
"attention/self/value"
,
"attention/self/key"
]
for
pattern
in
patterns
:
if
pattern
in
name
:
if
"kernel"
in
name
:
return
tuple
([
shape
[
0
],
num_heads
,
shape
[
1
]
//
num_heads
])
if
"bias"
in
name
:
return
tuple
([
num_heads
,
shape
[
0
]
//
num_heads
])
return
None
def create_v2_checkpoint(model, src_checkpoint, output_path):
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint.

  Args:
    model: A Keras model whose variables can be matched by name against the
      V1 checkpoint (restore is verified with
      assert_existing_objects_matched, which raises on a mismatch).
    src_checkpoint: Path to the source name-based (TF1) checkpoint.
    output_path: Path prefix where the object-based (TF2) checkpoint is
      written via tf.train.Checkpoint.save.
  """
  # Uses streaming-restore in eager model to read V1 name-based checkpoints.
  model.load_weights(src_checkpoint).assert_existing_objects_matched()
  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint.save(output_path)
def convert(checkpoint_from_path,
            checkpoint_to_path,
            num_heads,
            name_replacements,
            permutations,
            exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Reads every variable from the source checkpoint, renames / reshapes /
  permutes it as configured, and writes the result to a new checkpoint.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    num_heads: The number of heads of the model; <= 0 disables reshaping.
    name_replacements: A list of tuples of the form (match_str, replace_str)
      describing variable names to adjust.
    permutations: A list of tuples of the form (match_str, permutation)
      describing permutations to apply to given variables. Note that match_str
      should match the original variable name, not the replaced one.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.
  """
  # Fix: the previous docstring claimed two dictionaries were returned, but
  # this function writes the converted checkpoint as a side effect and
  # returns None.
  with tf.Graph().as_default():
    tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    # Maps new variable names to tf.Variable objects for the Saver.
    new_variable_map = {}
    # Maps old names to new names, kept only for the summary log below.
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name,
                                                    exclude_patterns):
        continue
      # Get the original tensor data.
      tensor = reader.get_tensor(var_name)
      # Look up the new variable name, if any.
      new_var_name = _bert_name_replacement(var_name, name_replacements)
      # See if we need to reshape the underlying tensor.
      new_shape = None
      if num_heads > 0:
        new_shape = _get_new_shape(var_name, tensor.shape, num_heads)
      if new_shape:
        # Fix: log message previously misspelled "Variable" as "Veriable".
        tf.logging.info("Variable %s has a shape change from %s to %s",
                        var_name, tensor.shape, new_shape)
        tensor = np.reshape(tensor, new_shape)
      # See if we need to permute the underlying tensor.
      permutation = _get_permutation(var_name, permutations)
      if permutation:
        tensor = np.transpose(tensor, permutation)
      # Create a new variable with the possibly-reshaped or transposed tensor.
      var = tf.Variable(tensor, name=var_name)
      # Save the variable into the new variable map.
      new_variable_map[new_var_name] = var
      # Keep a list of converter variables for sanity checking.
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    saver = tf.train.Saver(new_variable_map)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
      saver.save(sess, checkpoint_to_path)

  tf.logging.info("Summary:")
  tf.logging.info("  Converted %d variable name(s).", len(new_variable_map))
  tf.logging.info("  Converted: %s", str(conversion_map))
official/nlp/bert/tf1_to_keras_checkpoint_converter.py
浏览文件 @
49ba237d
...
...
@@ -29,8 +29,10 @@ from __future__ import division
from
__future__
import
print_function
from
absl
import
app
import
numpy
as
np
import
tensorflow
as
tf
# TF 1.x
from
third_party.tensorflow_models.official.nlp.bert
import
tf1_checkpoint_converter_lib
flags
=
tf
.
flags
...
...
@@ -50,174 +52,28 @@ flags.DEFINE_integer(
"The number of attention heads, used to reshape variables. If it is -1, "
"we do not reshape variables."
)
# Selects between the v1 and v2 BERT naming schemes when rewriting
# checkpoint variable names (see BERT_V2_NAME_REPLACEMENTS below).
flags.DEFINE_boolean("use_v2_names", False,
                     "Whether to use BERT_V2_NAME_REPLACEMENTS.")
# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
# NOTE(review): these lists duplicate the constants in
# tf1_checkpoint_converter_lib; this file appears to be a diff view in which
# these are the pre-refactor copies slated for deletion — confirm before
# relying on them.
BERT_NAME_REPLACEMENTS = [
    ("bert", "bert_model"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings",
     "embedding_postprocessor/type_embeddings"),
    ("embeddings/position_embeddings",
     "embedding_postprocessor/position_embeddings"),
    ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
]

# v2 naming scheme: strips the "bert/" scope, renames "encoder" to
# "transformer", and remaps the cls/* pretraining-head variables.
BERT_V2_NAME_REPLACEMENTS = [
    ("bert/", ""),
    ("encoder", "transformer"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
    ("embeddings/position_embeddings", "position_embedding/embeddings"),
    ("embeddings/LayerNorm", "embeddings/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
    ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
    ("cls/seq_relationship/output_weights",
     "predictions/transform/logits/kernel"),
]
def _bert_name_replacement(var_name):
  """Gets the variable name replacement.

  Chooses the v1 or v2 replacement table based on FLAGS.use_v2_names, then
  applies every matching (src, tgt) pair in order.
  """
  name_replacements = (
      BERT_V2_NAME_REPLACEMENTS if FLAGS.use_v2_names
      else BERT_NAME_REPLACEMENTS)
  for src, tgt in name_replacements:
    if src not in var_name:
      continue
    previous = var_name
    var_name = var_name.replace(src, tgt)
    tf.logging.info("Converted: %s --> %s", previous, var_name)
  return var_name
def
_has_exclude_patterns
(
name
,
exclude_patterns
):
"""Checks if a string contains substrings that match patterns to exclude."""
for
p
in
exclude_patterns
:
if
p
in
name
:
return
True
return
False
def _get_permutation(name):
  """Checks whether a variable requires transposition by pattern matching.

  Only the v2 naming scheme transposes anything: the next-sentence-prediction
  output weights get axes (1, 0).
  """
  if FLAGS.use_v2_names and "cls/seq_relationship/output_weights" in name:
    return (1, 0)
  return None
def
_get_new_shape
(
name
,
shape
,
num_heads
):
"""Checks whether a variable requires reshape by pattern matching."""
if
"attention/output/dense/kernel"
in
name
:
return
tuple
([
num_heads
,
shape
[
0
]
//
num_heads
,
shape
[
1
]])
if
"attention/output/dense/bias"
in
name
:
return
shape
patterns
=
[
"attention/self/query"
,
"attention/self/value"
,
"attention/self/key"
]
for
pattern
in
patterns
:
if
pattern
in
name
:
if
"kernel"
in
name
:
return
tuple
([
shape
[
0
],
num_heads
,
shape
[
1
]
//
num_heads
])
if
"bias"
in
name
:
return
tuple
([
num_heads
,
shape
[
0
]
//
num_heads
])
return
None
def convert_names(checkpoint_from_path, checkpoint_to_path,
                  exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Reads every variable from the source checkpoint, renames / reshapes /
  permutes it according to the module FLAGS, and writes the result to a
  new checkpoint.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.
  """
  # Fix: the previous docstring claimed two dictionaries were returned, but
  # this function writes the converted checkpoint as a side effect and
  # returns None.
  with tf.Graph().as_default():
    tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    # Maps new variable names to tf.Variable objects for the Saver.
    new_variable_map = {}
    # Maps old names to new names, kept only for the summary log below.
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name,
                                                    exclude_patterns):
        continue
      # Get the original tensor data.
      tensor = reader.get_tensor(var_name)
      # Look up the new variable name, if any.
      new_var_name = _bert_name_replacement(var_name)
      # See if we need to reshape the underlying tensor.
      new_shape = None
      if FLAGS.num_heads > 0:
        new_shape = _get_new_shape(var_name, tensor.shape, FLAGS.num_heads)
      if new_shape:
        # Fix: log message previously misspelled "Variable" as "Veriable".
        tf.logging.info("Variable %s has a shape change from %s to %s",
                        var_name, tensor.shape, new_shape)
        tensor = np.reshape(tensor, new_shape)
      # See if we need to permute the underlying tensor.
      permutation = _get_permutation(var_name)
      if permutation:
        tensor = np.transpose(tensor, permutation)
      # Create a new variable with the possibly-reshaped or transposed tensor.
      var = tf.Variable(tensor, name=var_name)
      # Save the variable into the new variable map.
      new_variable_map[new_var_name] = var
      # Keep a list of converter variables for sanity checking.
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    saver = tf.train.Saver(new_variable_map)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
      saver.save(sess, checkpoint_to_path)

  tf.logging.info("Summary:")
  tf.logging.info("  Converted %d variable name(s).", len(new_variable_map))
  tf.logging.info("  Converted: %s", str(conversion_map))
# When set, conversion uses the V2 name-replacement and permutation tables
# from tf1_checkpoint_converter_lib (see main below).
flags.DEFINE_boolean(
    "create_v2_checkpoint", False,
    "Whether to create a checkpoint compatible with KerasBERT V2 modeling "
    "code.")
def main(_):
  """Script entry point: converts the checkpoint per the command-line flags."""
  # Parse the comma-separated exclusion list, if any.
  exclude_patterns = None
  if FLAGS.exclude_patterns:
    exclude_patterns = FLAGS.exclude_patterns.split(",")
  # NOTE(review): this file was reconstructed from a diff view; the
  # convert_names() call below appears to be the pre-refactor path that the
  # change replaces with tf1_checkpoint_converter_lib.convert() — confirm
  # against the repository before relying on both calls running.
  convert_names(FLAGS.checkpoint_from_path, FLAGS.checkpoint_to_path,
                exclude_patterns)
  # Select the v1 or v2 conversion tables from the shared library.
  if FLAGS.create_v2_checkpoint:
    name_replacements = tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS
    permutations = tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS
  else:
    name_replacements = tf1_checkpoint_converter_lib.BERT_NAME_REPLACEMENTS
    permutations = tf1_checkpoint_converter_lib.BERT_PERMUTATIONS
  tf1_checkpoint_converter_lib.convert(FLAGS.checkpoint_from_path,
                                       FLAGS.checkpoint_to_path,
                                       FLAGS.num_heads, name_replacements,
                                       permutations, exclude_patterns)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录