Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
diluosixu
bert
提交
bee6030e
B
bert
项目概览
diluosixu
/
bert
与 Fork 源项目一致
从无法访问的项目Fork
通知
4
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
B
bert
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
bee6030e
编写于
2月 07, 2019
作者:
J
Jacob Devlin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adding TF Hub support
上级
f39e881b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
294 additions
and
16 deletions
+294
-16
README.md
README.md
+5
-0
create_pretraining_data.py
create_pretraining_data.py
+1
-1
modeling.py
modeling.py
+12
-14
run_classifier_with_tfhub.py
run_classifier_with_tfhub.py
+275
-0
tokenization_test.py
tokenization_test.py
+1
-1
未找到文件。
README.md
浏览文件 @
bee6030e
# BERT
**\*\*\*\*\* New February 7th, 2019: TfHub Module \*\*\*\*\***
BERT has been uploaded to
[
TensorFlow Hub
](
https://tfhub.dev
)
. See
`run_classifier_with_tfhub.py`
for an example of how to use the TF Hub module.
**
\*\*\*\*\*
New November 23rd, 2018: Un-normalized multilingual model + Thai +
Mongolian
\*\*\*\*\*
**
...
...
create_pretraining_data.py
浏览文件 @
bee6030e
...
...
@@ -20,8 +20,8 @@ from __future__ import print_function
import
collections
import
random
import
tensorflow
as
tf
import
tokenization
import
tensorflow
as
tf
flags
=
tf
.
flags
...
...
modeling.py
浏览文件 @
bee6030e
...
...
@@ -23,6 +23,7 @@ import copy
import
json
import
math
import
re
import
numpy
as
np
import
six
import
tensorflow
as
tf
...
...
@@ -133,7 +134,7 @@ class BertModel(object):
input_ids
,
input_mask
=
None
,
token_type_ids
=
None
,
use_one_hot_embeddings
=
Tru
e
,
use_one_hot_embeddings
=
Fals
e
,
scope
=
None
):
"""Constructor for BertModel.
...
...
@@ -145,9 +146,7 @@ class BertModel(object):
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
it is much faster if this is True, on the CPU or GPU, it is faster if
this is False.
embeddings or tf.embedding_lookup() for the word embeddings.
scope: (optional) variable scope. Defaults to "bert".
Raises:
...
...
@@ -262,20 +261,20 @@ class BertModel(object):
return
self
.
embedding_table
def gelu(x):
  """Gaussian Error Linear Unit (tanh approximation).

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  # tanh-based approximation of the Gaussian CDF, as in the original paper.
  inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
  gate = 0.5 * (1.0 + tf.tanh(inner))
  return x * gate
def
get_activation
(
activation_string
):
...
...
@@ -394,8 +393,7 @@ def embedding_lookup(input_ids,
initializer_range: float. Embedding initialization range.
word_embedding_name: string. Name of the embedding table.
use_one_hot_embeddings: bool. If True, use one-hot method for word
embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
for TPUs.
embeddings. If False, use `tf.gather()`.
Returns:
float Tensor of shape [batch_size, seq_length, embedding_size].
...
...
@@ -413,12 +411,12 @@ def embedding_lookup(input_ids,
shape
=
[
vocab_size
,
embedding_size
],
initializer
=
create_initializer
(
initializer_range
))
flat_input_ids
=
tf
.
reshape
(
input_ids
,
[
-
1
])
if
use_one_hot_embeddings
:
flat_input_ids
=
tf
.
reshape
(
input_ids
,
[
-
1
])
one_hot_input_ids
=
tf
.
one_hot
(
flat_input_ids
,
depth
=
vocab_size
)
output
=
tf
.
matmul
(
one_hot_input_ids
,
embedding_table
)
else
:
output
=
tf
.
nn
.
embedding_lookup
(
embedding_table
,
input_ids
)
output
=
tf
.
gather
(
embedding_table
,
flat_
input_ids
)
input_shape
=
get_shape_list
(
input_ids
)
...
...
run_classifier_with_tfhub.py
0 → 100644
浏览文件 @
bee6030e
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner with TF-Hub."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Project-local modules. NOTE(review): run_classifier presumably also defines
# most of the command-line flags referenced in main() (data_dir, task_name,
# batch sizes, ...) since they are not defined in this file — confirm.
import optimization
import run_classifier
import tokenization
import tensorflow as tf
import tensorflow_hub as hub

flags = tf.flags

FLAGS = flags.FLAGS

# The only flag declared in this file; all others come from the imports above.
flags.DEFINE_string(
    "bert_hub_module_handle", None,
    "Handle for the BERT TF-Hub module.")
def create_model(is_training, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Creates a classification model."""
  # Hub modules may expose a training-specific graph variant via tags.
  module_tags = set()
  if is_training:
    module_tags.add("train")

  bert_module = hub.Module(
      FLAGS.bert_hub_module_handle, tags=module_tags, trainable=True)
  bert_outputs = bert_module(
      inputs={
          "input_ids": input_ids,
          "input_mask": input_mask,
          "segment_ids": segment_ids,
      },
      signature="tokens",
      as_dict=True)

  # In the demo, we are doing a simple classification task on the entire
  # segment, so we classify from the pooled output.
  #
  # If you want to use the token-level output, use
  # bert_outputs["sequence_output"] instead.
  pooled = bert_outputs["pooled_output"]
  hidden_size = pooled.shape[-1].value

  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):
    if is_training:
      # keep_prob=0.9, i.e. 0.1 dropout.
      pooled = tf.nn.dropout(pooled, keep_prob=0.9)

    logits = tf.nn.bias_add(
        tf.matmul(pooled, output_weights, transpose_b=True), output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    # Cross-entropy against one-hot labels, averaged over the batch.
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

    return (loss, per_example_loss, logits)
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps, use_tpu):
  """Returns `model_fn` closure for TPUEstimator.

  Args:
    num_labels: int, number of output classes.
    learning_rate: float, learning rate passed to the optimizer.
    num_train_steps: total number of training steps (may be None when not
      training).
    num_warmup_steps: number of warmup steps (may be None when not training).
    use_tpu: bool, forwarded to `optimization.create_optimizer`.

  Returns:
    A `model_fn(features, labels, mode, params)` callable suitable for
    `tf.contrib.tpu.TPUEstimator`.
  """

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    # Log incoming feature names/shapes to aid debugging of input pipelines.
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # Build the TF-Hub-backed classifier; dropout is only enabled when
    # is_training is True (see create_model).
    (total_loss, per_example_loss, logits) = create_model(
        is_training, input_ids, input_mask, segment_ids, label_ids, num_labels)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps,
          use_tpu)

      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(per_example_loss, label_ids, logits):
        # Receives the tensors listed in `eval_metrics` below and returns a
        # dict of streaming metric (value, update_op) pairs.
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(label_ids, predictions)
        loss = tf.metrics.mean(per_example_loss)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

      eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics)
    else:
      # PREDICT mode is not implemented in this runner.
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    return output_spec

  return model_fn
def create_tokenizer_from_hub_module():
  """Get the vocab file and casing info from the Hub module."""
  # The tokenization metadata lives inside the module's graph, so instantiate
  # it in a throwaway graph and run a session just to fetch the two constants.
  with tf.Graph().as_default():
    bert_module = hub.Module(FLAGS.bert_hub_module_handle)
    info = bert_module(signature="tokenization_info", as_dict=True)
    with tf.Session() as session:
      vocab_file, do_lower_case = session.run(
          [info["vocab_file"], info["do_lower_case"]])

  return tokenization.FullTokenizer(
      vocab_file=vocab_file, do_lower_case=do_lower_case)
def main(_):
  """Trains and/or evaluates a BERT classifier built from a TF-Hub module.

  Flow: pick the task processor, build the tokenizer from the Hub module,
  construct a TPUEstimator, then run training and/or evaluation per flags.
  Writes eval results to `<output_dir>/eval_results.txt`.

  Raises:
    ValueError: if neither --do_train nor --do_eval is set, or --task_name
      is unknown.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Maps --task_name (lowercased) to its data processor class.
  processors = {
      "cola": run_classifier.ColaProcessor,
      "mnli": run_classifier.MnliProcessor,
      "mrpc": run_classifier.MrpcProcessor,
  }

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()
  label_list = processor.get_labels()

  # Vocab and casing come from the Hub module itself, not from flags.
  tokenizer = create_tokenizer_from_hub_module()

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = model_fn_builder(
      num_labels=len(label_list),
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  if FLAGS.do_train:
    train_features = run_classifier.convert_examples_to_features(
        train_examples, label_list, FLAGS.max_seq_length, tokenizer)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = run_classifier.input_fn_builder(
        features=train_features,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_eval:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    eval_features = run_classifier.convert_examples_to_features(
        eval_examples, label_list, FLAGS.max_seq_length, tokenizer)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      # Eval will be slightly WRONG on the TPU because it will truncate
      # the last batch.
      eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)

    # Drop the ragged final batch only on TPU (fixed batch shapes required
    # there).  Idiom fix: was `True if FLAGS.use_tpu else False`.
    eval_drop_remainder = bool(FLAGS.use_tpu)
    eval_input_fn = run_classifier.input_fn_builder(
        features=eval_features,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      tf.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))
if __name__ == "__main__":
  # Fail fast if required flags are missing.  NOTE(review): data_dir,
  # task_name and output_dir are presumably declared by the imported
  # run_classifier module; only bert_hub_module_handle is declared here —
  # confirm.
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("bert_hub_module_handle")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
tokenization_test.py
浏览文件 @
bee6030e
...
...
@@ -18,9 +18,9 @@ from __future__ import print_function
import
os
import
tempfile
import
tokenization
import
six
import
tensorflow
as
tf
import
tokenization
class
TokenizationTest
(
tf
.
test
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录