Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
6ad2e797
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6ad2e797
编写于
9月 30, 2017
作者:
W
wangmeng28
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add data featurize scripts
上级
248cc037
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
707 addition
and
10 deletion
+707
-10
globally_normalized_reader/README.md
globally_normalized_reader/README.md
+7
-4
globally_normalized_reader/data/download.sh
globally_normalized_reader/data/download.sh
+5
-2
globally_normalized_reader/evaluate.py
globally_normalized_reader/evaluate.py
+96
-0
globally_normalized_reader/featurize.py
globally_normalized_reader/featurize.py
+308
-0
globally_normalized_reader/index.html
globally_normalized_reader/index.html
+7
-4
globally_normalized_reader/vocab.py
globally_normalized_reader/vocab.py
+284
-0
未找到文件。
globally_normalized_reader/README.md
浏览文件 @
6ad2e797
...
...
@@ -25,10 +25,13 @@ You can also visit https://github.com/baidu-research/GloballyNormalizedReader to
docker pull paddledev/paddle
```
2.
Download all necessary data by running:
```
bash
cd
data
&&
./download.sh
```
3.
**(TODO) add the preprocess and featurizer scripts.**
```
bash
cd
data
&&
./download.sh
&&
cd
..
```
3.
Preprocess and featurizer data:
```
bash
python featurize.py
--datadir
data
--outdir
data/featurized
--glove-path
data/glove.840B.300d.txt
```
# Training a Model
...
...
globally_normalized_reader/data/download.sh
浏览文件 @
6ad2e797
#!/bin/bash
wget
--no-check-certificate
https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
wget
--no-check-certificate
https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
wget
--no-check-certificate
https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
-O
train.json
wget
--no-check-certificate
https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
-O
dev.json
wget http://nlp.stanford.edu/data/glove.840B.300d.zip
unzip glove.840B.300d.zip
globally_normalized_reader/evaluate.py
0 → 100644
浏览文件 @
6ad2e797
""" Official evaluation script for v1.1 of the SQuAD dataset. """
from
__future__
import
print_function
from
collections
import
Counter
import
string
import
re
import
argparse
import
json
import
sys
def
normalize_answer
(
s
):
"""Lower text and remove punctuation, articles and extra whitespace."""
def
remove_articles
(
text
):
return
re
.
sub
(
r
'\b(a|an|the)\b'
,
' '
,
text
)
def
white_space_fix
(
text
):
return
' '
.
join
(
text
.
split
())
def
remove_punc
(
text
):
exclude
=
set
(
string
.
punctuation
)
return
''
.
join
(
ch
for
ch
in
text
if
ch
not
in
exclude
)
def
lower
(
text
):
return
text
.
lower
()
return
white_space_fix
(
remove_articles
(
remove_punc
(
lower
(
s
))))
def
f1_score
(
prediction
,
ground_truth
):
prediction_tokens
=
normalize_answer
(
prediction
).
split
()
ground_truth_tokens
=
normalize_answer
(
ground_truth
).
split
()
common
=
Counter
(
prediction_tokens
)
&
Counter
(
ground_truth_tokens
)
num_same
=
sum
(
common
.
values
())
if
num_same
==
0
:
return
0
precision
=
1.0
*
num_same
/
len
(
prediction_tokens
)
recall
=
1.0
*
num_same
/
len
(
ground_truth_tokens
)
f1
=
(
2
*
precision
*
recall
)
/
(
precision
+
recall
)
return
f1
def
exact_match_score
(
prediction
,
ground_truth
):
return
(
normalize_answer
(
prediction
)
==
normalize_answer
(
ground_truth
))
def
metric_max_over_ground_truths
(
metric_fn
,
prediction
,
ground_truths
):
scores_for_ground_truths
=
[]
for
ground_truth
in
ground_truths
:
score
=
metric_fn
(
prediction
,
ground_truth
)
scores_for_ground_truths
.
append
(
score
)
return
max
(
scores_for_ground_truths
)
def
evaluate
(
dataset
,
predictions
):
f1
=
exact_match
=
total
=
0
for
article
in
dataset
:
for
paragraph
in
article
[
'paragraphs'
]:
for
qa
in
paragraph
[
'qas'
]:
total
+=
1
if
qa
[
'id'
]
not
in
predictions
:
message
=
'Unanswered question '
+
qa
[
'id'
]
+
\
' will receive score 0.'
print
(
message
,
file
=
sys
.
stderr
)
continue
ground_truths
=
list
(
map
(
lambda
x
:
x
[
'text'
],
qa
[
'answers'
]))
prediction
=
predictions
[
qa
[
'id'
]]
exact_match
+=
metric_max_over_ground_truths
(
exact_match_score
,
prediction
,
ground_truths
)
f1
+=
metric_max_over_ground_truths
(
f1_score
,
prediction
,
ground_truths
)
exact_match
=
100.0
*
exact_match
/
total
f1
=
100.0
*
f1
/
total
return
{
'exact_match'
:
exact_match
,
'f1'
:
f1
}
if
__name__
==
'__main__'
:
expected_version
=
'1.1'
parser
=
argparse
.
ArgumentParser
(
description
=
'Evaluation for SQuAD '
+
expected_version
)
parser
.
add_argument
(
'dataset_file'
,
help
=
'Dataset file'
)
parser
.
add_argument
(
'prediction_file'
,
help
=
'Prediction File'
)
args
=
parser
.
parse_args
()
with
open
(
args
.
dataset_file
)
as
dataset_file
:
dataset_json
=
json
.
load
(
dataset_file
)
if
(
dataset_json
[
'version'
]
!=
expected_version
):
print
(
'Evaluation expects v-'
+
expected_version
+
', but got dataset with v-'
+
dataset_json
[
'version'
],
file
=
sys
.
stderr
)
dataset
=
dataset_json
[
'data'
]
with
open
(
args
.
prediction_file
)
as
prediction_file
:
predictions
=
json
.
load
(
prediction_file
)
print
(
json
.
dumps
(
evaluate
(
dataset
,
predictions
)))
globally_normalized_reader/featurize.py
0 → 100644
浏览文件 @
6ad2e797
# -*- coding: utf-8 -*-
"""
Convert the raw json data into training and validation examples.
"""
from
collections
import
Counter
import
json
import
os
import
io
import
string
import
click
import
numpy
as
np
import
ciseau
from
vocab
import
Vocab
from
evaluate
import
normalize_answer
# Constants
UNK
=
"<UNK>"
SOS
=
"<SOS>"
EOS
=
"<EOS>"
PAD
=
"<PAD>"
splits
=
[
"train"
,
"dev"
]
ARTICLES
=
{
"a"
,
"an"
,
"the"
,
"of"
}
# Keep the random embedding matrix the same between runs.
np
.
random
.
seed
(
1234
)
def
data_stream
(
path
):
""" Given a path json data in Pranav format, convert it to a stream
question/context/answers tuple."""
with
io
.
open
(
path
,
"r"
)
as
handle
:
raw_data
=
json
.
load
(
handle
)[
"data"
]
for
ex
in
raw_data
:
for
paragraph
in
ex
[
"paragraphs"
]:
context
=
paragraph
[
"context"
]
for
qa
in
paragraph
[
"qas"
]:
question
=
qa
[
"question"
]
answers
=
qa
[
"answers"
]
if
"id"
not
in
qa
:
qa_id
=
-
1
else
:
qa_id
=
qa
[
"id"
]
yield
question
,
context
,
answers
,
qa_id
def
build_vocabulary
(
datadir
,
outdir
,
glove_path
):
"""Construct the vocabulary object used throughout."""
# We're not going to backprop through the word vectors
# both train and dev words end up in the vocab.
counter
=
Counter
()
for
split
in
splits
:
datapath
=
os
.
path
.
join
(
datadir
,
split
+
".json"
)
for
question
,
context
,
_
,
_
in
data_stream
(
datapath
):
for
word
in
ciseau
.
tokenize
(
question
,
normalize_ascii
=
False
):
counter
[
normalize
(
word
)]
+=
1
for
word
in
ciseau
.
tokenize
(
context
,
normalize_ascii
=
False
):
counter
[
normalize
(
word
)]
+=
1
common_words
=
[
UNK
,
SOS
,
EOS
,
PAD
]
+
[
w
for
w
,
_
in
counter
.
most_common
()]
vocab_path
=
os
.
path
.
join
(
outdir
,
"vocab.txt"
)
with
io
.
open
(
vocab_path
,
"w"
,
encoding
=
"utf8"
)
as
handle
:
handle
.
write
(
"
\n
"
.
join
(
common_words
))
return
Vocab
(
outdir
)
def
normalize_answer_tokens
(
tokens
):
start
=
0
end
=
len
(
tokens
)
while
end
-
start
>
1
:
first_token
=
tokens
[
start
].
rstrip
().
lower
()
if
first_token
in
string
.
punctuation
or
first_token
in
ARTICLES
:
start
+=
1
else
:
break
while
end
-
start
>
1
:
last_token
=
tokens
[
end
-
1
].
rstrip
().
lower
()
if
last_token
in
string
.
punctuation
:
end
-=
1
else
:
break
return
start
,
end
def
tokenize_example
(
question
,
context
,
answers
,
strip_labels
=
True
):
# Q: How should we choose the right answer
answer
=
answers
[
0
][
"text"
]
answer_start
=
answers
[
0
][
"answer_start"
]
if
strip_labels
:
answer_tokens
=
ciseau
.
tokenize
(
answer
,
normalize_ascii
=
False
)
start_offset
,
end_offset
=
normalize_answer_tokens
(
answer_tokens
)
answer
=
""
.
join
(
answer_tokens
[
start_offset
:
end_offset
])
# add back the piece that was stripped off:
answer_start
=
answer_start
+
len
(
""
.
join
(
answer_tokens
[:
start_offset
]))
# replace answer string with placeholder
placeholder
=
"XXXX"
new_context
=
context
[:
answer_start
]
+
placeholder
+
context
[
answer_start
+
len
(
answer
):]
token_context
=
ciseau
.
sent_tokenize
(
new_context
,
keep_whitespace
=
True
)
token_question
=
ciseau
.
tokenize
(
question
)
sentence_label
=
None
for
sent_idx
,
sent
in
enumerate
(
token_context
):
answer_start
=
None
for
idx
,
word
in
enumerate
(
sent
):
if
placeholder
in
word
:
answer_start
=
idx
break
if
answer_start
is
None
:
continue
sentence_label
=
sent_idx
# deal with cases where the answer is in the middle
# of the word
answer
=
word
.
replace
(
placeholder
,
answer
)
token_answer
=
ciseau
.
tokenize
(
answer
)
answer_end
=
answer_start
+
len
(
token_answer
)
-
1
answer_sent
=
sent
[:
answer_start
]
+
token_answer
+
sent
[
answer_start
+
1
:]
break
token_context
[
sentence_label
]
=
answer_sent
return
token_question
,
token_context
,
sentence_label
,
answer_start
,
answer_end
def
normalize
(
word
):
return
word
.
strip
()
def
same_as_question_feature
(
question_idxs
,
context_idxs
,
vocab
):
question_words
=
[
vocab
.
idx_to_word
(
idx
)
for
idx
in
question_idxs
]
# remove stop word and puncutation
question_words
=
set
([
w
.
strip
().
lower
()
for
w
in
question_words
if
w
not
in
ARTICLES
and
w
not
in
string
.
punctuation
])
features
=
[]
for
word_idx
in
context_idxs
:
word
=
vocab
.
idx_to_word
(
word_idx
)
features
.
append
(
int
(
word
.
strip
().
lower
()
in
question_words
))
return
features
def
repeated_word_features
(
context_idxs
,
vocab
):
context_words
=
[
vocab
.
idx_to_word
(
idx
)
for
idx
in
context_idxs
]
word_counter
=
{}
for
word
in
context_words
:
canon
=
word
.
strip
().
lower
()
if
canon
in
word_counter
:
word_counter
[
canon
]
+=
1
else
:
word_counter
[
canon
]
=
1
max_occur
=
max
(
word_counter
.
values
())
min_occur
=
min
(
word_counter
.
values
())
occur_range
=
max
(
1.0
,
max_occur
-
min_occur
)
repeated_words
=
[]
repeated_word_intensity
=
[]
for
word
in
context_words
:
canon
=
word
.
strip
().
lower
()
count
=
word_counter
[
canon
]
repeated
=
float
(
count
>
1
and
canon
not
in
ARTICLES
and
canon
not
in
string
.
punctuation
)
intensity
=
float
((
count
-
min_occur
)
/
occur_range
)
repeated_words
.
append
(
repeated
)
repeated_word_intensity
.
append
(
intensity
)
return
repeated_words
,
repeated_word_intensity
def
convert_example_to_indices
(
example
,
outfile
,
vocab
):
print
(
"Processing {}"
.
format
(
outfile
))
question
,
context
,
answers
,
qa_id
=
example
tokenized
=
tokenize_example
(
question
,
context
,
answers
,
strip_labels
=
True
)
token_question
,
token_context
,
ans_sent
,
ans_start
,
ans_end
=
tokenized
# Convert to indices
question_idxs
=
[
vocab
.
word_to_idx
(
normalize
(
w
))
for
w
in
token_question
]
# + 1 for end of sentence
sent_lengths
=
[
len
(
sent
)
+
1
for
sent
in
token_context
]
context_idxs
=
[]
for
sent
in
token_context
:
for
w
in
sent
:
context_idxs
.
append
(
vocab
.
word_to_idx
(
normalize
(
w
)))
context_idxs
.
append
(
vocab
.
eos
)
same_as_question
=
same_as_question_feature
(
question_idxs
,
context_idxs
,
vocab
)
repeated_words
,
repeated_intensity
=
repeated_word_features
(
context_idxs
,
vocab
)
features
=
{
"question"
:
question_idxs
,
"context"
:
context_idxs
,
"ans_sentence"
:
ans_sent
,
"ans_start"
:
ans_start
,
"ans_end"
:
ans_end
,
"sent_lengths"
:
sent_lengths
,
"same_as_question_word"
:
same_as_question
,
"repeated_words"
:
repeated_words
,
"repeated_intensity"
:
repeated_intensity
,
"qa_id"
:
qa_id
}
# Hack!: This is not a great way to save indices...
with
io
.
open
(
outfile
,
"w"
,
encoding
=
"utf8"
)
as
handle
:
handle
.
write
(
unicode
(
json
.
dumps
(
features
,
ensure_ascii
=
False
)))
def
featurize_example
(
question
,
context
,
vocab
):
# Convert to indices
question_idxs
=
[
vocab
.
word_to_idx
(
normalize
(
w
))
for
w
in
ciseau
.
tokenize
(
question
,
normalize_ascii
=
False
)
]
context_sents
=
ciseau
.
sent_tokenize
(
context
,
keep_whitespace
=
True
,
normalize_ascii
=
False
)
# + 1 for end of sentence
sent_lengths
=
[
len
(
sent
)
+
1
for
sent
in
context_sents
]
context_idxs
=
[]
for
sent
in
context_sents
:
for
w
in
sent
:
context_idxs
.
append
(
vocab
.
word_to_idx
(
normalize
(
w
)))
context_idxs
.
append
(
vocab
.
eos
)
same_as_question
=
same_as_question_feature
(
question_idxs
,
context_idxs
,
vocab
)
repeated_words
,
repeated_intensity
=
repeated_word_features
(
context_idxs
,
vocab
)
return
(
question_idxs
,
context_idxs
,
same_as_question
,
repeated_words
,
repeated_intensity
,
sent_lengths
),
context_sents
def
random_sample
(
data
,
k
,
replace
=
False
):
indices
=
np
.
arange
(
len
(
data
))
chosen_indices
=
np
.
random
.
choice
(
indices
,
k
,
replace
=
replace
)
return
[
data
[
idx
]
for
idx
in
chosen_indices
]
@
click
.
command
()
@
click
.
option
(
"--datadir"
,
type
=
str
,
help
=
"Path to raw data"
)
@
click
.
option
(
"--outdir"
,
type
=
str
,
help
=
"Path to save the result"
)
@
click
.
option
(
"--glove-path"
,
default
=
"/mnt/data/jmiller/glove.840B.300d.txt"
)
def
preprocess
(
datadir
,
outdir
,
glove_path
):
if
not
os
.
path
.
exists
(
outdir
):
os
.
makedirs
(
outdir
)
print
(
"Constructing vocabularies..."
)
vocab
=
build_vocabulary
(
datadir
,
outdir
,
glove_path
)
print
(
"Finished..."
)
print
(
"Building word embedding matrix..."
)
vocab
.
construct_embedding_matrix
(
glove_path
)
print
(
"Finished..."
)
# Create training featurizations
for
split
in
splits
:
results_path
=
os
.
path
.
join
(
outdir
,
split
)
os
.
makedirs
(
results_path
)
# process each example
examples
=
list
(
data_stream
(
os
.
path
.
join
(
datadir
,
split
+
".json"
)))
for
idx
,
example
in
enumerate
(
examples
):
outfile
=
os
.
path
.
join
(
results_path
,
str
(
idx
)
+
".json"
)
convert_example_to_indices
(
example
,
outfile
,
vocab
)
print
(
"Building evaluation featurization..."
)
eval_feats
=
[]
for
question
,
context
,
_
,
qa_id
in
data_stream
(
os
.
path
.
join
(
datadir
,
"dev.json"
)):
features
,
tokenized_context
=
featurize_example
(
question
,
context
,
vocab
)
eval_feats
.
append
((
qa_id
,
tokenized_context
,
features
))
with
io
.
open
(
os
.
path
.
join
(
outdir
,
"eval.json"
),
"w"
,
encoding
=
"utf8"
)
as
handle
:
handle
.
write
(
unicode
(
json
.
dumps
(
eval_feats
,
ensure_ascii
=
False
)))
if
__name__
==
"__main__"
:
preprocess
()
globally_normalized_reader/index.html
浏览文件 @
6ad2e797
...
...
@@ -67,10 +67,13 @@ You can also visit https://github.com/baidu-research/GloballyNormalizedReader to
docker pull paddledev/paddle
```
2. Download all necessary data by running:
```bash
cd data
&&
./download.sh
```
3. **(TODO) add the preprocess and featurizer scripts.**
```bash
cd data
&&
./download.sh
&&
cd ..
```
3. Preprocess and featurizer data:
```bash
python featurize.py --datadir data --outdir data/featurized --glove-path data/glove.840B.300d.txt
```
# Training a Model
...
...
globally_normalized_reader/vocab.py
0 → 100644
浏览文件 @
6ad2e797
# -*- coding: utf-8 -*-
import
os
import
io
import
numpy
as
np
# Constants
UNK
=
"<UNK>"
SOS
=
"<SOS>"
EOS
=
"<EOS>"
PAD
=
"<PAD>"
VOCAB_DIM
=
2196017
EMBEDDING_DIM
=
300
WORD2VEC
=
None
class
Vocab
(
object
):
"""Class to hold the vocabulary for the SquadDataset."""
def
__init__
(
self
,
path
):
self
.
_id_to_word
=
[]
self
.
_word_to_id
=
{}
self
.
_word_ending_tables
=
{}
self
.
_path
=
path
self
.
_pad
=
-
1
self
.
_unk
=
None
self
.
_sos
=
None
self
.
_eos
=
None
# first read in the base vocab
with
io
.
open
(
os
.
path
.
join
(
path
,
"vocab.txt"
),
"r"
)
as
f
:
for
idx
,
line
in
enumerate
(
f
):
word_name
=
line
.
strip
()
if
word_name
==
UNK
:
self
.
_unk
=
idx
elif
word_name
==
SOS
:
self
.
_sos
=
idx
elif
word_name
==
EOS
:
self
.
_eos
=
idx
self
.
_id_to_word
.
append
(
word_name
)
self
.
_word_to_id
[
word_name
]
=
idx
@
property
def
unk
(
self
):
return
self
.
_unk
@
property
def
sos
(
self
):
return
self
.
_sos
@
property
def
eos
(
self
):
return
self
.
_eos
@
property
def
size
(
self
):
return
len
(
self
.
_id_to_word
)
def
word_to_idx
(
self
,
word
):
if
word
in
self
.
_word_to_id
:
return
self
.
_word_to_id
[
word
]
return
self
.
unk
def
idx_to_word
(
self
,
idx
):
if
idx
==
self
.
_pad
:
return
PAD
if
idx
<
self
.
size
:
return
self
.
_id_to_word
[
idx
]
return
"ERROR"
def
decode
(
self
,
idxs
):
return
" "
.
join
([
self
.
idx_to_word
(
idx
)
for
idx
in
idxs
])
def
encode
(
self
,
sentence
):
return
[
self
.
word_to_idx
(
word
)
for
word
in
sentence
]
@
property
def
word_embeddings
(
self
):
embedding_path
=
os
.
path
.
join
(
self
.
_path
,
"embeddings.npy"
)
embeddings
=
np
.
load
(
embedding_path
)
return
embeddings
def
construct_embedding_matrix
(
self
,
glove_path
):
# Randomly initialize word embeddings
embeddings
=
np
.
random
.
randn
(
self
.
size
,
EMBEDDING_DIM
).
astype
(
np
.
float32
)
load_word_vectors
(
param
=
embeddings
,
vocab
=
self
.
_id_to_word
,
path
=
glove_path
,
missing_word_alternative
=
missing_word_heuristic
,
missing_word_value
=
lambda
:
0.0
)
embedding_path
=
os
.
path
.
join
(
self
.
_path
,
"embeddings.npy"
)
np
.
save
(
embedding_path
,
embeddings
)
def
missing_word_heuristic
(
word
,
word2vec
):
"""
propose alternate spellings of a word to match against
pretrained word vectors (so that if the original spelling
has no pretrained vector, but alternate spelling does,
a vector can be retrieved anyways.)
"""
if
len
(
word
)
>
5
:
# try to find similar words that share
# the same 5 character ending:
most_sim
=
word2vec
.
words_ending_in
(
word
[
-
5
:])
if
len
(
most_sim
)
>
0
:
most_sim
=
sorted
(
most_sim
,
reverse
=
True
,
key
=
lambda
x
:
(
(
word
[
0
].
isupper
()
==
x
[
0
].
isupper
())
+
(
word
.
lower
()[:
3
]
==
x
.
lower
()[:
3
])
+
(
word
.
lower
()[:
4
]
==
x
.
lower
()[:
4
])
+
(
abs
(
len
(
word
)
-
len
(
x
))
<
5
)
)
)
return
most_sim
[:
1
]
if
all
(
not
c
.
isalpha
()
for
c
in
word
):
# this is a fully numerical answer (and non alpha)
return
[
'13'
,
'9'
,
'100'
,
'2.0'
]
return
[
# add a capital letter
word
.
capitalize
(),
# see if word has spurious period
word
.
split
(
"."
)[
0
],
# see if word has spurious backslash
word
.
split
(
"/"
)[
0
],
# see if word has spurious parenthesis
word
.
split
(
")"
)[
0
],
word
.
split
(
"("
)[
0
]
]
class
Word2Vec
(
object
):
"""
Load word2vec result from file
"""
def
__init__
(
self
,
vocab_size
,
vector_size
):
self
.
syn0
=
np
.
zeros
((
vocab_size
,
vector_size
),
dtype
=
np
.
float32
)
self
.
index2word
=
[]
self
.
vocab_size
=
vocab_size
self
.
vector_size
=
vector_size
def
load_word2vec_format
(
self
,
path
):
with
io
.
open
(
path
,
"r"
)
as
fin
:
for
word_id
in
range
(
self
.
vocab_size
):
line
=
fin
.
readline
()
parts
=
line
.
rstrip
(
"
\n
"
).
rstrip
().
split
(
" "
)
if
len
(
parts
)
!=
self
.
vector_size
+
1
:
raise
ValueError
(
"invalid vector on line {}"
.
format
(
word_id
))
word
,
weights
=
parts
[
0
],
[
np
.
float32
(
x
)
for
x
in
parts
[
1
:]]
self
.
syn0
[
word_id
]
=
weights
self
.
index2word
.
append
(
word
)
return
self
class
FastWord2vec
(
object
):
"""
Load word2vec model, cache the embedding matrix using numpy
and memory-map it so that future loads are fast.
"""
def
__init__
(
self
,
path
):
if
not
os
.
path
.
exists
(
path
+
".npy"
):
word2vec
=
Word2Vec
(
VOCAB_DIM
,
EMBEDDING_DIM
).
load_word2vec_format
(
path
)
# save as numpy
np
.
save
(
path
+
".npy"
,
word2vec
.
syn0
)
# also save the vocab
with
io
.
open
(
path
+
".vocab"
,
"w"
,
encoding
=
"utf8"
)
as
fout
:
for
word
in
word2vec
.
index2word
:
fout
.
write
(
word
+
"
\n
"
)
self
.
syn0
=
np
.
load
(
path
+
".npy"
,
mmap_mode
=
"r"
)
self
.
index2word
=
[
l
.
strip
(
"
\n
"
)
for
l
in
io
.
open
(
path
+
".vocab"
,
"r"
)]
self
.
word2index
=
{
word
:
k
for
k
,
word
in
enumerate
(
self
.
index2word
)}
self
.
_word_ending_tables
=
{}
self
.
_word_beginning_tables
=
{}
def
__getitem__
(
self
,
key
):
return
np
.
array
(
self
.
syn0
[
self
.
word2index
[
key
]])
def
__contains__
(
self
,
key
):
return
key
in
self
.
word2index
def
words_ending_in
(
self
,
word_ending
):
if
len
(
word_ending
)
==
0
:
return
self
.
index2word
self
.
_build_word_ending_table
(
len
(
word_ending
))
return
self
.
_word_ending_tables
[
len
(
word_ending
)].
get
(
word_ending
,
[])
def
_build_word_ending_table
(
self
,
length
):
if
length
not
in
self
.
_word_ending_tables
:
table
=
{}
for
word
in
self
.
index2word
:
if
len
(
word
)
>=
length
:
ending
=
word
[
-
length
:]
if
ending
not
in
table
:
table
[
ending
]
=
[
word
]
else
:
table
[
ending
].
append
(
word
)
self
.
_word_ending_tables
[
length
]
=
table
def
words_starting_in
(
self
,
word_beginning
):
if
len
(
word_beginning
)
==
0
:
return
self
.
index2word
self
.
_build_word_beginning_table
(
len
(
word_beginning
))
return
self
.
_word_beginning_tables
[
len
(
word_beginning
)].
get
(
word_beginning
,
[])
def
_build_word_beginning_table
(
self
,
length
):
if
length
not
in
self
.
_word_beginning_tables
:
table
=
{}
for
word
in
get_progress_bar
(
'building prefix lookup '
)(
self
.
index2word
):
if
len
(
word
)
>=
length
:
ending
=
word
[:
length
]
if
ending
not
in
table
:
table
[
ending
]
=
[
word
]
else
:
table
[
ending
].
append
(
word
)
self
.
_word_beginning_tables
[
length
]
=
table
@
staticmethod
def
get
(
path
):
global
WORD2VEC
if
WORD2VEC
is
None
:
WORD2VEC
=
FastWord2vec
(
path
)
return
WORD2VEC
def
load_word_vectors
(
param
,
vocab
,
path
,
verbose
=
True
,
missing_word_alternative
=
None
,
missing_word_value
=
None
):
"""
Add the pre-trained word embeddings stored under path to the parameter
matrix `param` that has size `vocab x embedding_dim`.
Arguments:
param : np.array
vocab : list<str>
path : str, location of the pretrained word embeddings
verbose : (optional) bool, whether to print how
many words were recovered
"""
word2vec
=
FastWord2vec
.
get
(
path
)
missing
=
0
for
idx
,
word
in
enumerate
(
vocab
):
try
:
param
[
idx
,
:]
=
word2vec
[
word
]
except
KeyError
:
try
:
param
[
idx
,
:]
=
word2vec
[
word
.
lower
()]
except
KeyError
:
found
=
False
if
missing_word_alternative
is
not
None
:
alternatives
=
missing_word_alternative
(
word
,
word2vec
)
if
isinstance
(
alternatives
,
str
):
alternatives
=
[
alternatives
]
assert
(
isinstance
(
alternatives
,
list
)),
(
"missing_word_alternative should return a list of strings."
)
for
alternative
in
alternatives
:
if
alternative
in
word2vec
:
param
[
idx
,
:]
=
word2vec
[
alternative
]
found
=
True
break
if
not
found
:
if
missing_word_value
is
not
None
:
param
[
idx
,
:]
=
missing_word_value
()
missing
+=
1
if
verbose
:
print
(
"Loaded {} words, {} missing"
.
format
(
len
(
vocab
)
-
missing
,
missing
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录