Commit e858dff8
Authored Aug 17, 2020 by malin10

bug fix for w2v

Parent: 8ef34ea6
Showing 10 changed files with 350 additions and 10 deletions (+350 −10)
core/trainers/framework/dataset.py (+2 −0)
core/utils/dataloader_instance.py (+4 −0)
models/recall/word2vec/README.md (+10 −2)
models/recall/word2vec/config.yaml (+2 −2)
models/recall/word2vec/data_prepare.sh (+1 −1)
models/recall/word2vec/infer.py (+155 −0)
models/recall/word2vec/model.py (+2 −2)
models/recall/word2vec/utils.py (+131 −0)
models/recall/word2vec/w2v_evaluate_reader.py (+1 −1)
models/recall/word2vec/w2v_reader.py (+42 −2)
core/trainers/framework/dataset.py

@@ -68,6 +68,8 @@ class DataLoader(DatasetBase):
         reader_ins = SlotReader(context["config_yaml"])
         if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
             dataloader.set_sample_list_generator(reader)
+        elif hasattr(reader_ins, 'batch_tensor_creator'):
+            dataloader.set_batch_generator(reader)
         else:
             dataloader.set_sample_generator(reader, batch_size)
         return dataloader
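For context on the branch added above: the three bindings expect generators of different granularity. The sketch below is only an illustration, assuming the Paddle 1.x fluid.io.DataLoader API; the data layer and generator names are made up and this is not the PaddleRec wiring itself.

```python
# A minimal sketch of the two binding styles, assuming the Paddle 1.x
# fluid.io.DataLoader API; all names here are illustrative only.
import numpy as np
import paddle.fluid as fluid

x = fluid.data(name="x", shape=[None, 1], dtype="int64")

def sample_gen():
    # Yields one sample per call; the DataLoader assembles batches itself.
    for i in range(10):
        yield [np.array([i], dtype="int64")]

def batch_gen():
    # Yields a whole batch per call, which is what batch_tensor_creator produces.
    for start in range(0, 10, 5):
        yield [np.arange(start, start + 5, dtype="int64").reshape((5, 1))]

sample_loader = fluid.io.DataLoader.from_generator(feed_list=[x], capacity=4)
sample_loader.set_sample_generator(
    sample_gen, batch_size=5, places=fluid.CPUPlace())

batch_loader = fluid.io.DataLoader.from_generator(feed_list=[x], capacity=4)
batch_loader.set_batch_generator(batch_gen, places=fluid.CPUPlace())
```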
core/utils/dataloader_instance.py

@@ -67,6 +67,10 @@ def dataloader_by_name(readerclass,
     if hasattr(reader, 'generate_batch_from_trainfiles'):
         return gen_batch_reader()
+    if hasattr(reader, "batch_tensor_creator"):
+        return reader.batch_tensor_creator(gen_reader)
     return gen_reader
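Conceptually, reader.batch_tensor_creator(gen_reader) wraps the per-sample generator gen_reader into a callable that yields whole batches. A framework-free sketch of that wrapping, with hypothetical names rather than the PaddleRec implementation, looks like this:

```python
# Framework-free sketch of wrapping a per-sample generator into a per-batch
# generator; names and the plain-list batch format are illustrative only.
def batch_creator(sample_reader, batch_size):
    def __reader__():
        buffer = []
        for sample in sample_reader():
            buffer.append(sample)
            if len(buffer) == batch_size:
                yield buffer          # one whole batch per yield
                buffer = []           # drop the last incomplete batch
    return __reader__

def samples():
    for i in range(7):
        yield (i, i * i)

for batch in batch_creator(samples, 3)():
    print(batch)  # [(0, 0), (1, 1), (2, 4)], then [(3, 9), (4, 16), (5, 25)]
```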
models/recall/word2vec/README.md

@@ -19,6 +19,8 @@
 ├── data_prepare.sh # one-click data preparation script
 ├── w2v_reader.py # reader for training data
 ├── w2v_evaluate_reader.py # reader for evaluation data
+├── infer.py # custom inference script
+├── utils.py # reader and other helpers used by the custom inference
 ```
 Note: before reading this example, we recommend that you first learn about the following:

@@ -154,9 +156,12 @@ runner:
   phases: [phase1]
 ```
 ### Single-machine inference
 We evaluate how well the word2vec model is trained with the word-analogy task: given four words A, B, C, D, assume there is a relation such that relation(A, B) = relation(C, D); D is then predicted from A, B, C via emb(D) = emb(B) - emb(A) + emb(C).

 CPU environment

 PaddleRec inference configuration:
 Set epochs, device, and the other parameters in the config.yaml file.
 ```

@@ -168,6 +173,10 @@ CPU environment
     print_interval: 1
     phases: [phase2]
 ```
+To reproduce the results of the paper, we provide a custom inference script. During custom inference, predicted words that are merely one of the inputs A, B, C are skipped, and the prediction accuracy is then computed. Run the following command:
+```
+python infer.py --test_dir ./data/test --dict_path ./data/dict/word_id_dict.txt --batch_size 20000 --model_dir ./increment_w2v/ --start_index 0 --last_index 5 --emb_size 300
+```
 ### Run
 ```

@@ -212,13 +221,12 @@ Infer phase2 of epoch 3 done, use time: 4.43099021912, global metrics: acc=[1.]
 - batch_size: change the batch_size of the dataset_train dataset in config.yaml to 100.
 - epochs: change the epochs of the runner in config.yaml to 5.
 Training on CPU for 5 epochs, test Recall@20: 0.540
 To run after these modifications: set 'workspace' in config.yaml to the directory containing config.yaml, then execute
 ```
 python -m paddlerec.run -m /home/your/dir/config.yaml # debug mode: point directly to the absolute path of the local config
 ```
 Training on CPU for 5 epochs, the accuracy of the custom evaluation (inputs skipped) is 0.540.
 ## Advanced usage
 ## FAQ
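As an aside on the word-analogy evaluation described in the README hunk above, the following self-contained sketch (toy vocabulary and random embeddings, not values from the trained model) shows the scoring that the custom infer.py implements: form emb(B) - emb(A) + emb(C), rank the whole vocabulary against L2-normalized embeddings, and skip candidates that merely repeat the inputs.

```python
# Toy word-analogy scoring; the embeddings and vocabulary are made up.
import numpy as np

vocab = ["king", "queen", "man", "woman", "apple"]
emb = np.random.RandomState(0).rand(len(vocab), 8).astype("float32")

def predict_analogy(a, b, c):
    # target = emb(B) - emb(A) + emb(C), compared against L2-normalized rows
    target = emb[vocab.index(b)] - emb[vocab.index(a)] + emb[vocab.index(c)]
    normed = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    scores = normed @ target
    # Skip candidates that are one of the inputs, as the custom infer.py does.
    for idx in np.argsort(-scores):
        if vocab[idx] not in (a, b, c):
            return vocab[idx]

print(predict_analogy("man", "king", "woman"))  # with real embeddings: "queen"
```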
models/recall/word2vec/config.yaml

@@ -22,7 +22,7 @@ dataset:
     word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt"
     data_converter: "{workspace}/w2v_reader.py"
   - name: dataset_infer # name
-    batch_size: 50
+    batch_size: 2000
     type: DataLoader # or QueueDataset
     data_path: "{workspace}/data/test"
     word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt"

@@ -59,7 +59,7 @@ runner:
     save_inference_feed_varnames: [] # feed vars of save inference
     save_inference_fetch_varnames: [] # fetch vars of save inference
     init_model_path: "" # load model path
-    print_interval: 1
+    print_interval: 1000
     phases: [phase1]
   - name: single_cpu_infer
     class: infer
models/recall/word2vec/data_prepare.sh

@@ -25,7 +25,7 @@ mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tok
 python preprocess.py --build_dict --build_dict_corpus_dir raw_data/training-monolingual.tokenized.shuffled --dict_path raw_data/word_count_dict.txt
 python preprocess.py --filter_corpus --dict_path raw_data/word_count_dict.txt --input_corpus_dir raw_data/training-monolingual.tokenized.shuffled --output_corpus_dir raw_data/convert_text8 --min_count 5 --downsample 0.001
 mv raw_data/word_count_dict.txt data/dict/
-mv raw_data/word_id_dict.txt data/dict/
+mv raw_data/word_count_dict.txt_word_to_id_ data/dict/word_id_dict.txt
 rm -rf data/train/*
 rm -rf data/test/*
models/recall/word2vec/infer.py (new file, 0 → 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys
import time
import math

import numpy as np
import six
import paddle.fluid as fluid
import paddle

import utils

if six.PY2:
    reload(sys)
    sys.setdefaultencoding('utf-8')


def parse_args():
    parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/data_c/1-billion_dict_word_to_id_',
        help="The path of dic")
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    parser.add_argument(
        '--print_step', type=int, default='500000', help='print step')
    parser.add_argument(
        '--start_index', type=int, default='0', help='start index')
    parser.add_argument(
        '--last_index', type=int, default='100', help='last index')
    parser.add_argument(
        '--model_dir', type=str, default='model', help='model dir')
    parser.add_argument(
        '--use_cuda', type=int, default='0', help='whether use cuda')
    parser.add_argument(
        '--batch_size', type=int, default='5', help='batch_size')
    parser.add_argument(
        '--emb_size', type=int, default='64', help='batch_size')
    args = parser.parse_args()
    return args


def infer_network(vocab_size, emb_size):
    analogy_a = fluid.data(name="analogy_a", shape=[None], dtype='int64')
    analogy_b = fluid.data(name="analogy_b", shape=[None], dtype='int64')
    analogy_c = fluid.data(name="analogy_c", shape=[None], dtype='int64')
    all_label = fluid.data(name="all_label", shape=[vocab_size], dtype='int64')
    emb_all_label = fluid.embedding(
        input=all_label, size=[vocab_size, emb_size], param_attr="emb")

    emb_a = fluid.embedding(
        input=analogy_a, size=[vocab_size, emb_size], param_attr="emb")
    emb_b = fluid.embedding(
        input=analogy_b, size=[vocab_size, emb_size], param_attr="emb")
    emb_c = fluid.embedding(
        input=analogy_c, size=[vocab_size, emb_size], param_attr="emb")
    target = fluid.layers.elementwise_add(
        fluid.layers.elementwise_sub(emb_b, emb_a), emb_c)
    emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
    dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
    values, pred_idx = fluid.layers.topk(input=dist, k=4)
    return values, pred_idx


def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
    """ inference function """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    emb_size = args.emb_size
    batch_size = args.batch_size
    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = infer_network(vocab_size, emb_size)
            for epoch in range(start_index, last_index + 1):
                copy_program = main_program.clone()
                model_path = model_dir + "/" + str(epoch)
                fluid.io.load_persistables(
                    exe, model_path, main_program=copy_program)
                accum_num = 0
                accum_num_sum = 0.0
                t0 = time.time()
                step_id = 0
                for data in test_reader():
                    step_id += 1
                    b_size = len([dat[0] for dat in data])
                    wa = np.array([dat[0] for dat in data]).astype(
                        "int64").reshape(b_size)
                    wb = np.array([dat[1] for dat in data]).astype(
                        "int64").reshape(b_size)
                    wc = np.array([dat[2] for dat in data]).astype(
                        "int64").reshape(b_size)

                    label = [dat[3] for dat in data]
                    input_word = [dat[4] for dat in data]
                    para = exe.run(copy_program,
                                   feed={
                                       "analogy_a": wa,
                                       "analogy_b": wb,
                                       "analogy_c": wc,
                                       "all_label":
                                       np.arange(vocab_size).reshape(
                                           vocab_size).astype("int64"),
                                   },
                                   fetch_list=[pred.name, values],
                                   return_numpy=False)
                    pre = np.array(para[0])
                    val = np.array(para[1])
                    for ii in range(len(label)):
                        top4 = pre[ii]
                        accum_num_sum += 1
                        for idx in top4:
                            if int(idx) in input_word[ii]:
                                continue
                            if int(idx) == int(label[ii][0]):
                                accum_num += 1
                            break
                    if step_id % 1 == 0:
                        print("step:%d %d " % (step_id, accum_num))
                print("epoch:%d \t acc:%.3f " %
                      (epoch, 1.0 * accum_num / accum_num_sum))


if __name__ == "__main__":
    args = parse_args()
    start_index = args.start_index
    last_index = args.last_index
    test_dir = args.test_dir
    model_dir = args.model_dir
    batch_size = args.batch_size
    dict_path = args.dict_path
    use_cuda = True if args.use_cuda else False
    print("start index: ", start_index, " last_index:", last_index)
    vocab_size, test_reader, id2word = utils.prepare_data(
        test_dir, dict_path, batch_size=batch_size)
    print("vocab_size:", vocab_size)
    infer_epoch(
        args,
        vocab_size,
        test_reader=test_reader,
        use_cuda=use_cuda,
        i2w=id2word)
models/recall/word2vec/model.py

@@ -209,10 +209,10 @@ class Model(ModelBase):
         emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
         dist = fluid.layers.matmul(
             x=target, y=emb_all_label_l2, transpose_y=True)
-        values, pred_idx = fluid.layers.topk(input=dist, k=4)
+        values, pred_idx = fluid.layers.topk(input=dist, k=1)
         label = fluid.layers.expand(
             fluid.layers.unsqueeze(
-                inputs[3], axes=[1]), expand_times=[1, 4])
+                inputs[3], axes=[1]), expand_times=[1, 1])
         label_ones = fluid.layers.fill_constant_batch_size_like(
             label, shape=[-1, 1], value=1.0, dtype='float32')
         right_cnt = fluid.layers.reduce_sum(input=fluid.layers.cast(
models/recall/word2vec/utils.py (new file, 0 → 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import collections
import six
import time
import numpy as np
import paddle.fluid as fluid
import paddle
import os
import preprocess
import io


def BuildWord_IdMap(dict_path):
    word_to_id = dict()
    id_to_word = dict()
    with io.open(dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            word_to_id[line.split(' ')[0]] = int(line.split(' ')[1])
            id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
    return word_to_id, id_to_word


def prepare_data(file_dir, dict_path, batch_size):
    w2i, i2w = BuildWord_IdMap(dict_path)
    vocab_size = len(i2w)
    reader = fluid.io.batch(test(file_dir, w2i), batch_size)
    return vocab_size, reader, i2w


def check_version(with_shuffle_batch=False):
    """
    Log error and exit when the installed version of paddlepaddle is
    not satisfied.
    """
    err = "PaddlePaddle version 1.6 or higher is required, " \
          "or a suitable develop version is satisfied as well. \n" \
          "Please make sure the version is good with your code."
    try:
        if with_shuffle_batch:
            fluid.require_version('1.7.0')
        else:
            fluid.require_version('1.6.0')
    except Exception as e:
        logger.error(err)
        sys.exit(1)


def native_to_unicode(s):
    if _is_unicode(s):
        return s
    try:
        return _to_unicode(s)
    except UnicodeDecodeError:
        res = _to_unicode(s, ignore_errors=True)
        return res


def _is_unicode(s):
    if six.PY2:
        if isinstance(s, unicode):
            return True
    else:
        if isinstance(s, str):
            return True
    return False


def _to_unicode(s, ignore_errors=False):
    if _is_unicode(s):
        return s
    error_mode = "ignore" if ignore_errors else "strict"
    return s.decode("utf-8", errors=error_mode)


def strip_lines(line, vocab):
    return _replace_oov(vocab, native_to_unicode(line))


def _replace_oov(original_vocab, line):
    """Replace out-of-vocab words with "<UNK>".
    This maintains compatibility with published results.
    Args:
      original_vocab: a set of strings (The standard vocabulary for the dataset)
      line: a unicode string - a space-delimited sequence of words.
    Returns:
      a unicode string - a space-delimited sequence of words.
    """
    return u" ".join([
        word if word in original_vocab else u"<UNK>" for word in line.split()
    ])


def reader_creator(file_dir, word_to_id):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            with io.open(
                    os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
                for line in f:
                    if ':' in line:
                        pass
                    else:
                        line = strip_lines(line.lower(), word_to_id)
                        line = line.split()
                        yield [word_to_id[line[0]]], [word_to_id[line[1]]], [
                            word_to_id[line[2]]
                        ], [word_to_id[line[3]]], [
                            word_to_id[line[0]], word_to_id[line[1]],
                            word_to_id[line[2]]
                        ]

    return reader


def test(test_dir, w2i):
    return reader_creator(test_dir, w2i)
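As a concrete illustration of the formats these helpers expect, the short self-contained sketch below builds a toy id dictionary and parses one analogy question the way reader_creator does: header lines containing ':' are skipped, and out-of-vocabulary words map to <UNK>. The words and ids are made up.

```python
# Toy illustration of the dict and analogy-question formats handled above;
# all words and ids are made up.
word_to_id = {}
for entry in ["athens 0", "greece 1", "baghdad 2", "iraq 3", "<UNK> 4"]:
    word, idx = entry.split(' ')
    word_to_id[word] = int(idx)

def parse(line):
    if ':' in line:  # section headers such as ": capital-common-countries"
        return None
    words = [w if w in word_to_id else u"<UNK>" for w in line.lower().split()]
    a, b, c, d = (word_to_id[w] for w in words)
    # Same layout reader_creator yields: [A], [B], [C], [label D], [A, B, C]
    return [a], [b], [c], [d], [a, b, c]

print(parse(": capital-common-countries"))  # None (skipped)
print(parse("Athens Greece Baghdad Iraq"))  # ([0], [1], [2], [3], [0, 1, 2])
```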
models/recall/word2vec/w2v_evaluate_reader.py

@@ -76,7 +76,7 @@ class Reader(ReaderBase):
     def generate_sample(self, line):
         def reader():
             if ':' in line:
-                pass
+                return
             features = self.strip_lines(line.lower(), self.word_to_id)
             features = features.split()
             yield [('analogy_a', [self.word_to_id[features[0]]]),
models/recall/word2vec/w2v_reader.py

@@ -15,6 +15,7 @@
 import io
 
 import numpy as np
+import paddle.fluid as fluid
 
 from paddlerec.core.reader import ReaderBase
 from paddlerec.core.utils import envs

@@ -47,6 +48,10 @@
         self.with_shuffle_batch = envs.get_global_env(
             "hyper_parameters.with_shuffle_batch")
         self.random_generator = NumpyRandomInt(1, self.window_size + 1)
+        self.batch_size = envs.get_global_env(
+            "dataset.dataset_train.batch_size")
+        self.is_dataloader = envs.get_global_env(
+            "dataset.dataset_train.type") == "DataLoader"
 
         self.cs = None
         if not self.with_shuffle_batch:

@@ -88,11 +93,46 @@
             for context_id in context_word_ids:
                 output = [('input_word', [int(target_id)]),
                           ('true_label', [int(context_id)])]
-                if not self.with_shuffle_batch:
+                if self.with_shuffle_batch or self.is_dataloader:
+                    yield output
+                else:
                     neg_array = self.cs.searchsorted(
                         np.random.sample(self.neg_num))
                     output += [('neg_label',
                                 [int(str(i)) for i in neg_array])]
-                yield output
+                    yield output
 
         return reader
 
+    def batch_tensor_creator(self, sample_reader):
+        def __reader__():
+            result = [[], []]
+            for sample in sample_reader():
+                for i, fea in enumerate(sample):
+                    result[i].append(fea)
+                if len(result[0]) == self.batch_size:
+                    tensor_result = []
+                    for tensor in result:
+                        t = fluid.Tensor()
+                        dat = np.array(tensor, dtype='int64')
+                        if len(dat.shape) > 2:
+                            dat = dat.reshape((dat.shape[0], dat.shape[2]))
+                        elif len(dat.shape) == 1:
+                            dat = dat.reshape((-1, 1))
+                        t.set(dat, fluid.CPUPlace())
+                        tensor_result.append(t)
+                    if self.with_shuffle_batch:
+                        yield tensor_result
+                    else:
+                        tt = fluid.Tensor()
+                        neg_array = self.cs.searchsorted(
+                            np.random.sample(self.neg_num))
+                        neg_array = np.tile(neg_array, self.batch_size)
+                        tt.set(
+                            neg_array.reshape((self.batch_size, self.neg_num)),
+                            fluid.CPUPlace())
+                        tensor_result.append(tt)
+                        yield tensor_result
+                    result = [[], []]
+
+        return __reader__
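For intuition about the batch that batch_tensor_creator assembles, here is a framework-free sketch with numpy arrays standing in for fluid.Tensor; the sizes are made up and the negative ids are drawn with randint instead of the cumulative-distribution search used above.

```python
# Toy illustration of the batching shapes in batch_tensor_creator;
# numpy arrays stand in for fluid.Tensor and all sizes are made up.
import numpy as np

batch_size, neg_num = 4, 5
samples = [([i], [i + 100]) for i in range(batch_size)]  # (input_word, true_label)

# Column-wise batching: one (batch_size, 1) int64 array per input slot.
input_word = np.array([s[0] for s in samples], dtype="int64").reshape((-1, 1))
true_label = np.array([s[1] for s in samples], dtype="int64").reshape((-1, 1))

# One negative draw is tiled so every row of the batch shares the same negatives.
neg_draw = np.random.randint(0, 1000, size=neg_num)
neg_label = np.tile(neg_draw, batch_size).reshape((batch_size, neg_num))

print(input_word.shape, true_label.shape, neg_label.shape)  # (4, 1) (4, 1) (4, 5)
```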