Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
diluosixu
bert
提交
60454702
B
bert
项目概览
diluosixu
/
bert
与 Fork 源项目一致
从无法访问的项目Fork
通知
4
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
B
bert
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
60454702
编写于
11月 15, 2018
作者:
J
Jacob Devlin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adding SQuAD 2.0 support
上级
9d81f96d
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
268 addition
and
95 deletion
+268
-95
README.md
README.md
+79
-1
modeling.py
modeling.py
+27
-34
modeling_test.py
modeling_test.py
+1
-0
run_classifier.py
run_classifier.py
+1
-2
run_squad.py
run_squad.py
+159
-57
tokenization_test.py
tokenization_test.py
+1
-1
未找到文件。
README.md
浏览文件 @
60454702
# BERT
**\*\*\*\*\* New November 15th, 2018: SOTA SQuAD 2.0 System \*\*\*\*\***
We released code changes to reproduce our 83% F1 SQuAD 2.0 system, which is
currently 1st place on the leaderboard by 3%. See the SQuAD 2.0 section of the
README for details.
**
\*\*\*\*\*
New November 5th, 2018: Third-party PyTorch and Chainer versions of
BERT available
\*\*\*\*\*
**
...
...
@@ -342,7 +348,7 @@ python run_classifier.py \
--output_dir
=
/tmp/mrpc_output/
```
### SQuAD
### SQuAD
1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
benchmark dataset. BERT (at the time of the release) obtains state-of-the-art
...
...
@@ -435,6 +441,78 @@ If you fine-tune for one epoch on
be even better, but you will need to convert TriviaQA into the SQuAD json
format.
### SQuAD 2.0
This model is also implemented and documented in
`run_squad.py`
.
To run on SQuAD 2.0, you will first need to download the dataset. The necessary
files can be found here:
*
[
train-v2.0.json
](
https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
)
*
[
dev-v2.0.json
](
https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
)
*
[
evaluate-v2.0.py
](
https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
)
Download these to some directory
`$SQUAD_DIR`
.
On Cloud TPU you can run with BERT-Large as follows:
```
shell
python run_squad.py
\
--vocab_file
=
$BERT_LARGE_DIR
/vocab.txt
\
--bert_config_file
=
$BERT_LARGE_DIR
/bert_config.json
\
--init_checkpoint
=
$BERT_LARGE_DIR
/bert_model.ckpt
\
--do_train
=
True
\
--train_file
=
$SQUAD_DIR
/train-v1.1.json
\
--do_predict
=
True
\
--predict_file
=
$SQUAD_DIR
/dev-v1.1.json
\
--train_batch_size
=
24
\
--learning_rate
=
3e-5
\
--num_train_epochs
=
2.0
\
--max_seq_length
=
384
\
--doc_stride
=
128
\
--output_dir
=
gs://some_bucket/squad_large/
\
--use_tpu
=
True
\
--tpu_name
=
$TPU_NAME
\
--version_2_with_negative
=
True
```
We assume you have copied everything from the output directory to a local
directory called ./squad/. The initial dev set predictions will be at
./squad/predictions.json and the differences between the score of no answer ("")
and the best non-null answer for each question will be in the file
./squad/null_odds.json
Run this script to tune a threshold for predicting null versus non-null answers:
python $SQUAD_DIR/evaluate-v2.0.py $SQUAD_DIR/dev-v2.0.json
./squad/predictions.json --na-prob-file ./squad/null_odds.json
Assume the script outputs "best_f1_thresh" THRESH. (Typical values are between
-1.0 and -5.0). You can now re-run the model to generate predictions with the
derived threshold or alternatively you can extract the appropriate answers from
./squad/nbest_predictions.json.
```
shell
python run_squad.py
\
--vocab_file
=
$BERT_LARGE_DIR
/vocab.txt
\
--bert_config_file
=
$BERT_LARGE_DIR
/bert_config.json
\
--init_checkpoint
=
$BERT_LARGE_DIR
/bert_model.ckpt
\
--do_train
=
False
\
--train_file
=
$SQUAD_DIR
/train-v1.1.json
\
--do_predict
=
True
\
--predict_file
=
$SQUAD_DIR
/dev-v1.1.json
\
--train_batch_size
=
24
\
--learning_rate
=
3e-5
\
--num_train_epochs
=
2.0
\
--max_seq_length
=
384
\
--doc_stride
=
128
\
--output_dir
=
gs://some_bucket/squad_large/
\
--use_tpu
=
True
\
--tpu_name
=
$TPU_NAME
\
--version_2_with_negative
=
True
\
--null_score_diff_threshold
=
$THRESH
```
### Out-of-memory issues
All experiments in the paper were fine-tuned on a Cloud TPU, which has 64GB of
...
...
modeling.py
浏览文件 @
60454702
...
...
@@ -469,11 +469,6 @@ def embedding_postprocessor(input_tensor,
seq_length
=
input_shape
[
1
]
width
=
input_shape
[
2
]
if
seq_length
>
max_position_embeddings
:
raise
ValueError
(
"The seq length (%d) cannot be greater than "
"`max_position_embeddings` (%d)"
%
(
seq_length
,
max_position_embeddings
))
output
=
input_tensor
if
use_token_type
:
...
...
@@ -494,37 +489,35 @@ def embedding_postprocessor(input_tensor,
output
+=
token_type_embeddings
if
use_position_embeddings
:
full_position_embeddings
=
tf
.
get_variable
(
name
=
position_embedding_name
,
shape
=
[
max_position_embeddings
,
width
],
initializer
=
create_initializer
(
initializer_range
))
# Since the position embedding table is a learned variable, we create it
# using a (long) sequence length `max_position_embeddings`. The actual
# sequence length might be shorter than this, for faster training of
# tasks that do not have long sequences.
#
# So `full_position_embeddings` is effectively an embedding table
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# perform a slice.
if
seq_length
<
max_position_embeddings
:
assert_op
=
tf
.
assert_less_equal
(
seq_length
,
max_position_embeddings
)
with
tf
.
control_dependencies
([
assert_op
]):
full_position_embeddings
=
tf
.
get_variable
(
name
=
position_embedding_name
,
shape
=
[
max_position_embeddings
,
width
],
initializer
=
create_initializer
(
initializer_range
))
# Since the position embedding table is a learned variable, we create it
# using a (long) sequence length `max_position_embeddings`. The actual
# sequence length might be shorter than this, for faster training of
# tasks that do not have long sequences.
#
# So `full_position_embeddings` is effectively an embedding table
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# perform a slice.
position_embeddings
=
tf
.
slice
(
full_position_embeddings
,
[
0
,
0
],
[
seq_length
,
-
1
])
else
:
position_embeddings
=
full_position_embeddings
num_dims
=
len
(
output
.
shape
.
as_list
())
# Only the last two dimensions are relevant (`seq_length` and `width`), so
# we broadcast among the first dimensions, which is typically just
# the batch size.
position_broadcast_shape
=
[]
for
_
in
range
(
num_dims
-
2
):
position_broadcast_shape
.
append
(
1
)
position_broadcast_shape
.
extend
([
seq_length
,
width
])
position_embeddings
=
tf
.
reshape
(
position_embeddings
,
position_broadcast_shape
)
output
+=
position_embeddings
num_dims
=
len
(
output
.
shape
.
as_list
())
# Only the last two dimensions are relevant (`seq_length` and `width`), so
# we broadcast among the first dimensions, which is typically just
# the batch size.
position_broadcast_shape
=
[]
for
_
in
range
(
num_dims
-
2
):
position_broadcast_shape
.
append
(
1
)
position_broadcast_shape
.
extend
([
seq_length
,
width
])
position_embeddings
=
tf
.
reshape
(
position_embeddings
,
position_broadcast_shape
)
output
+=
position_embeddings
output
=
layer_norm_and_dropout
(
output
,
dropout_prob
)
return
output
...
...
modeling_test.py
浏览文件 @
60454702
...
...
@@ -164,6 +164,7 @@ class BertModelTest(tf.test.TestCase):
graph
=
sess
.
graph
ignore_strings
=
[
"^.*/assert_less_equal/.*$"
,
"^.*/dilation_rate$"
,
"^.*/Tensordot/concat$"
,
"^.*/Tensordot/concat/axis$"
,
...
...
run_classifier.py
浏览文件 @
60454702
...
...
@@ -607,9 +607,8 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
num_labels
,
use_one_hot_embeddings
)
tvars
=
tf
.
trainable_variables
()
initialized_variable_names
=
{}
scaffold_fn
=
None
initialized_variable_names
=
[]
if
init_checkpoint
:
(
assignment_map
,
initialized_variable_names
)
=
modeling
.
get_assignment_map_from_checkpoint
(
tvars
,
init_checkpoint
)
...
...
run_squad.py
浏览文件 @
60454702
...
...
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD."""
"""Run BERT on SQuAD
1.1 and SQuAD 2.0
."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -145,9 +145,20 @@ flags.DEFINE_bool(
"If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation."
)
flags
.
DEFINE_bool
(
"version_2_with_negative"
,
False
,
"If true, the SQuAD examples contain some that do not have an answer."
)
flags
.
DEFINE_float
(
"null_score_diff_threshold"
,
0.0
,
"If null_score - best_non_null is greater than the threshold predict null."
)
class
SquadExample
(
object
):
"""A single training/test example for simple sequence classification."""
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def
__init__
(
self
,
qas_id
,
...
...
@@ -155,13 +166,15 @@ class SquadExample(object):
doc_tokens
,
orig_answer_text
=
None
,
start_position
=
None
,
end_position
=
None
):
end_position
=
None
,
is_impossible
=
False
):
self
.
qas_id
=
qas_id
self
.
question_text
=
question_text
self
.
doc_tokens
=
doc_tokens
self
.
orig_answer_text
=
orig_answer_text
self
.
start_position
=
start_position
self
.
end_position
=
end_position
self
.
is_impossible
=
is_impossible
def
__str__
(
self
):
return
self
.
__repr__
()
...
...
@@ -176,6 +189,8 @@ class SquadExample(object):
s
+=
", start_position: %d"
%
(
self
.
start_position
)
if
self
.
start_position
:
s
+=
", end_position: %d"
%
(
self
.
end_position
)
if
self
.
start_position
:
s
+=
", is_impossible: %r"
%
(
self
.
is_impossible
)
return
s
...
...
@@ -193,7 +208,8 @@ class InputFeatures(object):
input_mask
,
segment_ids
,
start_position
=
None
,
end_position
=
None
):
end_position
=
None
,
is_impossible
=
None
):
self
.
unique_id
=
unique_id
self
.
example_index
=
example_index
self
.
doc_span_index
=
doc_span_index
...
...
@@ -205,6 +221,7 @@ class InputFeatures(object):
self
.
segment_ids
=
segment_ids
self
.
start_position
=
start_position
self
.
end_position
=
end_position
self
.
is_impossible
=
is_impossible
def
read_squad_examples
(
input_file
,
is_training
):
...
...
@@ -241,29 +258,40 @@ def read_squad_examples(input_file, is_training):
start_position
=
None
end_position
=
None
orig_answer_text
=
None
is_impossible
=
False
if
is_training
:
if
len
(
qa
[
"answers"
])
!=
1
:
if
FLAGS
.
version_2_with_negative
:
is_impossible
=
qa
[
"is_impossible"
]
if
(
len
(
qa
[
"answers"
])
!=
1
)
and
(
not
is_impossible
):
raise
ValueError
(
"For training, each question should have exactly 1 answer."
)
answer
=
qa
[
"answers"
][
0
]
orig_answer_text
=
answer
[
"text"
]
answer_offset
=
answer
[
"answer_start"
]
answer_length
=
len
(
orig_answer_text
)
start_position
=
char_to_word_offset
[
answer_offset
]
end_position
=
char_to_word_offset
[
answer_offset
+
answer_length
-
1
]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
actual_text
=
" "
.
join
(
doc_tokens
[
start_position
:(
end_position
+
1
)])
cleaned_answer_text
=
" "
.
join
(
tokenization
.
whitespace_tokenize
(
orig_answer_text
))
if
actual_text
.
find
(
cleaned_answer_text
)
==
-
1
:
tf
.
logging
.
warning
(
"Could not find answer: '%s' vs. '%s'"
,
actual_text
,
cleaned_answer_text
)
continue
if
not
is_impossible
:
answer
=
qa
[
"answers"
][
0
]
orig_answer_text
=
answer
[
"text"
]
answer_offset
=
answer
[
"answer_start"
]
answer_length
=
len
(
orig_answer_text
)
start_position
=
char_to_word_offset
[
answer_offset
]
end_position
=
char_to_word_offset
[
answer_offset
+
answer_length
-
1
]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
actual_text
=
" "
.
join
(
doc_tokens
[
start_position
:(
end_position
+
1
)])
cleaned_answer_text
=
" "
.
join
(
tokenization
.
whitespace_tokenize
(
orig_answer_text
))
if
actual_text
.
find
(
cleaned_answer_text
)
==
-
1
:
tf
.
logging
.
warning
(
"Could not find answer: '%s' vs. '%s'"
,
actual_text
,
cleaned_answer_text
)
continue
else
:
start_position
=
-
1
end_position
=
-
1
orig_answer_text
=
""
example
=
SquadExample
(
qas_id
=
qas_id
,
...
...
@@ -271,8 +299,10 @@ def read_squad_examples(input_file, is_training):
doc_tokens
=
doc_tokens
,
orig_answer_text
=
orig_answer_text
,
start_position
=
start_position
,
end_position
=
end_position
)
end_position
=
end_position
,
is_impossible
=
is_impossible
)
examples
.
append
(
example
)
return
examples
...
...
@@ -301,7 +331,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
tok_start_position
=
None
tok_end_position
=
None
if
is_training
:
if
is_training
and
example
.
is_impossible
:
tok_start_position
=
-
1
tok_end_position
=
-
1
if
is_training
and
not
example
.
is_impossible
:
tok_start_position
=
orig_to_tok_index
[
example
.
start_position
]
if
example
.
end_position
<
len
(
example
.
doc_tokens
)
-
1
:
tok_end_position
=
orig_to_tok_index
[
example
.
end_position
+
1
]
-
1
...
...
@@ -373,19 +406,27 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
start_position
=
None
end_position
=
None
if
is_training
:
if
is_training
and
not
example
.
is_impossible
:
# For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict.
doc_start
=
doc_span
.
start
doc_end
=
doc_span
.
start
+
doc_span
.
length
-
1
out_of_span
=
False
if
(
example
.
start_position
<
doc_start
or
example
.
end_position
<
doc_start
or
example
.
start_position
>
doc_end
or
example
.
end_position
>
doc_end
):
continue
out_of_span
=
True
if
out_of_span
:
start_position
=
0
end_position
=
0
else
:
doc_offset
=
len
(
query_tokens
)
+
2
start_position
=
tok_start_position
-
doc_start
+
doc_offset
end_position
=
tok_end_position
-
doc_start
+
doc_offset
doc_offset
=
len
(
query_tokens
)
+
2
start_position
=
tok_start_position
-
doc_start
+
doc_offset
end_position
=
tok_end_position
-
doc_start
+
doc_offset
if
is_training
and
example
.
is_impossible
:
start_position
=
0
end_position
=
0
if
example_index
<
20
:
tf
.
logging
.
info
(
"*** Example ***"
)
...
...
@@ -404,7 +445,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
"input_mask: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
input_mask
]))
tf
.
logging
.
info
(
"segment_ids: %s"
%
" "
.
join
([
str
(
x
)
for
x
in
segment_ids
]))
if
is_training
:
if
is_training
and
example
.
is_impossible
:
tf
.
logging
.
info
(
"impossible example"
)
if
is_training
and
not
example
.
is_impossible
:
answer_text
=
" "
.
join
(
tokens
[
start_position
:(
end_position
+
1
)])
tf
.
logging
.
info
(
"start_position: %d"
%
(
start_position
))
tf
.
logging
.
info
(
"end_position: %d"
%
(
end_position
))
...
...
@@ -422,7 +465,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
input_mask
=
input_mask
,
segment_ids
=
segment_ids
,
start_position
=
start_position
,
end_position
=
end_position
)
end_position
=
end_position
,
is_impossible
=
example
.
is_impossible
)
# Run callback
output_fn
(
feature
)
...
...
@@ -697,8 +741,8 @@ RawResult = collections.namedtuple("RawResult",
def
write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
max_answer_length
,
do_lower_case
,
output_prediction_file
,
output_nbest_file
):
"""Write final predictions to the json file."""
output_nbest_file
,
output_null_log_odds_file
):
"""Write final predictions to the json file
and log-odds of null if needed
."""
tf
.
logging
.
info
(
"Writing predictions to: %s"
%
(
output_prediction_file
))
tf
.
logging
.
info
(
"Writing nbest to: %s"
%
(
output_nbest_file
))
...
...
@@ -716,15 +760,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
all_predictions
=
collections
.
OrderedDict
()
all_nbest_json
=
collections
.
OrderedDict
()
scores_diff_json
=
collections
.
OrderedDict
()
for
(
example_index
,
example
)
in
enumerate
(
all_examples
):
features
=
example_index_to_features
[
example_index
]
prelim_predictions
=
[]
# keep track of the minimum score of null start+end of position 0
score_null
=
1000000
# large and positive
min_null_feature_index
=
0
# the paragraph slice with min mull score
null_start_logit
=
0
# the start logit at the slice with min null score
null_end_logit
=
0
# the end logit at the slice with min null score
for
(
feature_index
,
feature
)
in
enumerate
(
features
):
result
=
unique_id_to_result
[
feature
.
unique_id
]
start_indexes
=
_get_best_indexes
(
result
.
start_logits
,
n_best_size
)
end_indexes
=
_get_best_indexes
(
result
.
end_logits
,
n_best_size
)
# if we could have irrelevant answers, get the min score of irrelevant
if
FLAGS
.
version_2_with_negative
:
feature_null_score
=
result
.
start_logits
[
0
]
+
result
.
end_logits
[
0
]
if
feature_null_score
<
score_null
:
score_null
=
feature_null_score
min_null_feature_index
=
feature_index
null_start_logit
=
result
.
start_logits
[
0
]
null_end_logit
=
result
.
end_logits
[
0
]
for
start_index
in
start_indexes
:
for
end_index
in
end_indexes
:
# We could hypothetically create invalid predictions, e.g., predict
...
...
@@ -753,6 +811,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
start_logit
=
result
.
start_logits
[
start_index
],
end_logit
=
result
.
end_logits
[
end_index
]))
if
FLAGS
.
version_2_with_negative
:
prelim_predictions
.
append
(
_PrelimPrediction
(
feature_index
=
min_null_feature_index
,
start_index
=
0
,
end_index
=
0
,
start_logit
=
null_start_logit
,
end_logit
=
null_end_logit
))
prelim_predictions
=
sorted
(
prelim_predictions
,
key
=
lambda
x
:
(
x
.
start_logit
+
x
.
end_logit
),
...
...
@@ -767,33 +833,44 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
if
len
(
nbest
)
>=
n_best_size
:
break
feature
=
features
[
pred
.
feature_index
]
if
pred
.
start_index
>
0
:
# this is a non-null prediction
tok_tokens
=
feature
.
tokens
[
pred
.
start_index
:(
pred
.
end_index
+
1
)]
orig_doc_start
=
feature
.
token_to_orig_map
[
pred
.
start_index
]
orig_doc_end
=
feature
.
token_to_orig_map
[
pred
.
end_index
]
orig_tokens
=
example
.
doc_tokens
[
orig_doc_start
:(
orig_doc_end
+
1
)]
tok_text
=
" "
.
join
(
tok_tokens
)
# De-tokenize WordPieces that have been split off.
tok_text
=
tok_text
.
replace
(
" ##"
,
""
)
tok_text
=
tok_text
.
replace
(
"##"
,
""
)
# Clean whitespace
tok_text
=
tok_text
.
strip
()
tok_text
=
" "
.
join
(
tok_text
.
split
())
orig_text
=
" "
.
join
(
orig_tokens
)
final_text
=
get_final_text
(
tok_text
,
orig_text
,
do_lower_case
)
if
final_text
in
seen_predictions
:
continue
tok_tokens
=
feature
.
tokens
[
pred
.
start_index
:(
pred
.
end_index
+
1
)]
orig_doc_start
=
feature
.
token_to_orig_map
[
pred
.
start_index
]
orig_doc_end
=
feature
.
token_to_orig_map
[
pred
.
end_index
]
orig_tokens
=
example
.
doc_tokens
[
orig_doc_start
:(
orig_doc_end
+
1
)]
tok_text
=
" "
.
join
(
tok_tokens
)
# De-tokenize WordPieces that have been split off.
tok_text
=
tok_text
.
replace
(
" ##"
,
""
)
tok_text
=
tok_text
.
replace
(
"##"
,
""
)
# Clean whitespace
tok_text
=
tok_text
.
strip
()
tok_text
=
" "
.
join
(
tok_text
.
split
())
orig_text
=
" "
.
join
(
orig_tokens
)
final_text
=
get_final_text
(
tok_text
,
orig_text
,
do_lower_case
)
if
final_text
in
seen_predictions
:
continue
seen_predictions
[
final_text
]
=
True
else
:
final_text
=
""
seen_predictions
[
final_text
]
=
True
seen_predictions
[
final_text
]
=
True
nbest
.
append
(
_NbestPrediction
(
text
=
final_text
,
start_logit
=
pred
.
start_logit
,
end_logit
=
pred
.
end_logit
))
# if we didn't inlude the empty option in the n-best, inlcude it
if
FLAGS
.
version_2_with_negative
:
if
""
not
in
seen_predictions
:
nbest
.
append
(
_NbestPrediction
(
text
=
""
,
start_logit
=
null_start_logit
,
end_logit
=
null_end_logit
))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if
not
nbest
:
...
...
@@ -803,8 +880,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
assert
len
(
nbest
)
>=
1
total_scores
=
[]
best_non_null_entry
=
None
for
entry
in
nbest
:
total_scores
.
append
(
entry
.
start_logit
+
entry
.
end_logit
)
if
not
best_non_null_entry
:
if
entry
.
text
:
best_non_null_entry
=
entry
probs
=
_compute_softmax
(
total_scores
)
...
...
@@ -819,7 +900,18 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
assert
len
(
nbest_json
)
>=
1
all_predictions
[
example
.
qas_id
]
=
nbest_json
[
0
][
"text"
]
if
not
FLAGS
.
version_2_with_negative
:
all_predictions
[
example
.
qas_id
]
=
nbest_json
[
0
][
"text"
]
else
:
# predict "" iff the null score - the score of best non-null > threshold
score_diff
=
score_null
-
best_non_null_entry
.
start_logit
-
(
best_non_null_entry
.
end_logit
)
scores_diff_json
[
example
.
qas_id
]
=
score_diff
if
score_diff
>
FLAGS
.
null_score_diff_threshold
:
all_predictions
[
example
.
qas_id
]
=
""
else
:
all_predictions
[
example
.
qas_id
]
=
best_non_null_entry
.
text
all_nbest_json
[
example
.
qas_id
]
=
nbest_json
with
tf
.
gfile
.
GFile
(
output_prediction_file
,
"w"
)
as
writer
:
...
...
@@ -828,6 +920,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
with
tf
.
gfile
.
GFile
(
output_nbest_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
all_nbest_json
,
indent
=
4
)
+
"
\n
"
)
if
FLAGS
.
version_2_with_negative
:
with
tf
.
gfile
.
GFile
(
output_null_log_odds_file
,
"w"
)
as
writer
:
writer
.
write
(
json
.
dumps
(
scores_diff_json
,
indent
=
4
)
+
"
\n
"
)
def
get_final_text
(
pred_text
,
orig_text
,
do_lower_case
):
"""Project the tokenized prediction back to the original text."""
...
...
@@ -987,6 +1083,10 @@ class FeatureWriter(object):
if
self
.
is_training
:
features
[
"start_positions"
]
=
create_int_feature
([
feature
.
start_position
])
features
[
"end_positions"
]
=
create_int_feature
([
feature
.
end_position
])
impossible
=
0
if
feature
.
is_impossible
:
impossible
=
1
features
[
"is_impossible"
]
=
create_int_feature
([
impossible
])
tf_example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
features
))
self
.
_writer
.
write
(
tf_example
.
SerializeToString
())
...
...
@@ -1166,10 +1266,12 @@ def main(_):
output_prediction_file
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
"predictions.json"
)
output_nbest_file
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
"nbest_predictions.json"
)
output_null_log_odds_file
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
"null_odds.json"
)
write_predictions
(
eval_examples
,
eval_features
,
all_results
,
FLAGS
.
n_best_size
,
FLAGS
.
max_answer_length
,
FLAGS
.
do_lower_case
,
output_prediction_file
,
output_nbest_file
)
output_nbest_file
,
output_null_log_odds_file
)
if
__name__
==
"__main__"
:
...
...
tokenization_test.py
浏览文件 @
60454702
...
...
@@ -30,7 +30,7 @@ class TokenizationTest(tf.test.TestCase):
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"want"
,
"##want"
,
"##ed"
,
"wa"
,
"un"
,
"runn"
,
"##ing"
,
","
]
with
tempfile
.
NamedTemporaryFile
(
mode
=
'w+'
,
delete
=
False
)
as
vocab_writer
:
with
tempfile
.
NamedTemporaryFile
(
delete
=
False
)
as
vocab_writer
:
vocab_writer
.
write
(
""
.
join
([
x
+
"
\n
"
for
x
in
vocab_tokens
]))
vocab_file
=
vocab_writer
.
name
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录