Commit 03353075 — BaiXuePrincess/PaddleRec (fork of PaddlePaddle/PaddleRec)
Authored by malin10 on Jun 02, 2020
Parent: 1125b796

add fasttext

Showing 7 changed files with 891 additions and 0 deletions (+891 -0)
models/recall/fasttext/__init__.py        +13  -0
models/recall/fasttext/config.yaml        +84  -0
models/recall/fasttext/data_prepare.sh    +40  -0
models/recall/fasttext/evaluate_reader.py +107 -0
models/recall/fasttext/model.py           +226 -0
models/recall/fasttext/preprocess.py      +313 -0
models/recall/fasttext/reader.py          +108 -0
models/recall/fasttext/__init__.py
new file mode 100755
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
models/recall/fasttext/config.yaml
new file mode 100755
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#workspace: "paddlerec.models.recall.fasttext"
workspace: "/home/malin10/code/paddlerec/models/recall/fasttext"

# list of dataset
dataset:
- name: dataset_train # name of dataset to distinguish different datasets
  batch_size: 10
  type: DataLoader # or QueueDataset
  data_path: "{workspace}/data/train"
  word_count_dict_path: "{workspace}/data/dict/word_count_dict.txt"
  word_ngrams_path: "{workspace}/data/dict/word_ngrams_id.txt"
  data_converter: "{workspace}/reader.py"
- name: dataset_infer # name
  batch_size: 10
  type: DataLoader # or QueueDataset
  data_path: "{workspace}/data/test"
  word_id_dict_path: "{workspace}/data/dict/word_id_dict.txt"
  data_converter: "{workspace}/evaluate_reader.py"

hyper_parameters:
  optimizer:
    learning_rate: 1.0
    decay_steps: 100000
    decay_rate: 0.999
    class: sgd
    strategy: async
  sparse_feature_number: 227915
  sparse_feature_dim: 300
  with_shuffle_batch: False
  neg_num: 5
  window_size: 5
  min_n: 3
  max_n: 5

# select runner by name
mode: train_runner
# config of each runner.
# runner is a kind of paddle training class, which wraps the train/infer process.
runner:
- name: train_runner
  class: single_train
  # num of epochs
  epochs: 2
  # device to run training or infer
  device: cpu
  save_checkpoint_interval: 1 # save model interval of epochs
  save_inference_interval: 1 # save inference
  save_checkpoint_path: "increment" # save checkpoint path
  save_inference_path: "inference" # save inference path
  save_inference_feed_varnames: [] # feed vars of save inference
  save_inference_fetch_varnames: [] # fetch vars of save inference
  init_model_path: "" # load model path
  fetch_period: 10
- name: infer_runner
  class: single_infer
  # num of epochs
  epochs: 1
  # device to run training or infer
  device: cpu
  init_model_path: "increment/0" # load model path

# runner will run all the phase in each epoch
phase:
- name: phase1
  model: "{workspace}/model.py" # user-defined model
  dataset_name: dataset_train # select dataset by name
  thread_num: 1
#- name: phase2
#  model: "{workspace}/model.py" # user-defined model
#  dataset_name: dataset_infer # select dataset by name
#  thread_num: 1
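The YAML above is an ordinary PaddleRec config: the framework resolves "{workspace}" and exposes every key to the model and readers below through envs.get_global_env. A minimal, hypothetical inspection snippet, not part of this commit (assumes PyYAML is installed and the relative path is correct), just to show which hyper-parameters the fasttext code consumes:

import yaml

# Load config.yaml and print the values the model/readers below will read.
with open("models/recall/fasttext/config.yaml") as f:
    conf = yaml.safe_load(f)

print(conf["hyper_parameters"]["neg_num"])        # 5 negative samples per positive
print(conf["hyper_parameters"]["min_n"],
      conf["hyper_parameters"]["max_n"])          # character n-gram lengths 3..5
print(conf["dataset"][0]["data_path"])            # "{workspace}/data/train"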
models/recall/fasttext/data_prepare.sh
new file mode 100755
#! /bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# download train_data
mkdir raw_data
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ raw_data/

# preprocess data
python preprocess.py --build_dict --build_dict_corpus_dir raw_data/training-monolingual.tokenized.shuffled --dict_path raw_data/word_count_dict.txt --ngrams_path raw_data/word_ngrams.txt
python preprocess.py --filter_corpus --dict_path raw_data/word_count_dict.txt --word_id_path raw_data/word_id_dict.txt --input_corpus_dir raw_data/training-monolingual.tokenized.shuffled --output_corpus_dir raw_data/convert_text8 --ngrams_id_path raw_data/word_ngrams_id.txt --ngrams_path raw_data/word_ngrams.txt --min_count 5 --downsample 0.001
mv raw_data/word_count_dict.txt data/dict/
mv raw_data/word_id_dict.txt data/dict/
mv raw_data/word_ngrams.txt data/dict/
mv raw_data/word_ngrams_id.txt data/dict/
rm -rf data/train/*
rm -rf data/test/*
python preprocess.py --data_resplit --file_nums 24 --input_corpus_dir=raw_data/convert_text8 --output_corpus_dir=data/train

# download test data
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
tar xzvf test_dir.tar -C raw_data
mv raw_data/data/test_dir/* data/test/
rm -rf raw_data
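After this script finishes, data/dict/ holds the vocabulary files the readers below depend on: word_count_dict.txt stores "token count" lines (words wrapped in < >, followed by their character n-grams), word_id_dict.txt maps tokens to ids, and word_ngrams_id.txt maps each word id to the ids of its n-grams. A small sketch, not part of this commit and assuming the script has completed in the model directory, that just peeks at both dictionaries:

import io

# Assumes data_prepare.sh has run; paths are relative to the fasttext model dir.
with io.open("data/dict/word_count_dict.txt", encoding="utf-8") as f:
    for i, line in enumerate(f):
        token, count = line.split()
        print(token, count)              # lines look like "<word> count"
        if i == 2:
            break

with io.open("data/dict/word_ngrams_id.txt", encoding="utf-8") as f:
    word_id, ngram_ids = f.readline().rstrip().split(" ", 1)
    print(word_id, ngram_ids.split())    # a word id followed by its n-gram ids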
models/recall/fasttext/evaluate_reader.py
new file mode 100755
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io

import six

from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs


class TrainReader(Reader):
    def init(self):
        dict_path = envs.get_global_env(
            "dataset.dataset_infer.word_id_dict_path")
        self.min_n = envs.get_global_env("hyper_parameters.min_n")
        self.max_n = envs.get_global_env("hyper_parameters.max_n")
        self.word_to_id = dict()
        self.id_to_word = dict()
        with io.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                self.word_to_id[line.split(' ')[0]] = int(line.split(' ')[1])
                self.id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
        self.dict_size = len(self.word_to_id)

    def computeSubwords(self, word):
        ngrams = set()
        for i in range(len(word) - self.min_n + 1):
            for j in range(self.min_n, self.max_n + 1):
                end = min(len(word), i + j)
                ngrams.add("".join(word[i:end]))
        return list(ngrams)

    def native_to_unicode(self, s):
        if self._is_unicode(s):
            return s
        try:
            return self._to_unicode(s)
        except UnicodeDecodeError:
            res = self._to_unicode(s, ignore_errors=True)
            return res

    def _is_unicode(self, s):
        if six.PY2:
            if isinstance(s, unicode):
                return True
        else:
            if isinstance(s, str):
                return True
        return False

    def _to_unicode(self, s, ignore_errors=False):
        if self._is_unicode(s):
            return s
        error_mode = "ignore" if ignore_errors else "strict"
        return s.decode("utf-8", errors=error_mode)

    def strip_lines(self, line, vocab):
        return self._replace_oov(vocab, self.native_to_unicode(line))

    def _replace_oov(self, original_vocab, line):
        """Replace out-of-vocab words with "<UNK>".
        This maintains compatibility with published results.
        Args:
            original_vocab: a set of strings (The standard vocabulary for the dataset)
            line: a unicode string - a space-delimited sequence of words.
        Returns:
            a unicode string - a space-delimited sequence of words.
        """
        return u" ".join([
            "<" + word + ">"
            if "<" + word + ">" in original_vocab else u"<UNK>"
            for word in line.split()
        ])

    def generate_sample(self, line):
        def reader():
            if ':' in line:
                pass
            features = self.strip_lines(line.lower(), self.word_to_id)
            features = features.split()
            inputs = []
            for item in features:
                if item == "<UNK>":
                    inputs.append([self.word_to_id[item]])
                else:
                    ngrams = self.computeSubwords(item)
                    res = []
                    res.append(self.word_to_id[item])
                    for _ in ngrams:
                        res.append(self.word_to_id[_])
                    inputs.append(res)
            yield [('analogy_a', inputs[0]), ('analogy_b', inputs[1]),
                   ('analogy_c', inputs[2]), ('analogy_d', inputs[3][0:1])]

        return reader
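For reference, here is a standalone illustration, not part of this commit, of the subword extraction that TrainReader.computeSubwords performs with the min_n=3 / max_n=5 values from config.yaml. Words reach it already wrapped in '<' and '>' (see _replace_oov), so the boundary markers become part of the character n-grams:

def compute_subwords(word, min_n=3, max_n=5):
    # Same sliding-window logic as TrainReader.computeSubwords above.
    ngrams = set()
    for i in range(len(word) - min_n + 1):
        for j in range(min_n, max_n + 1):
            end = min(len(word), i + j)
            ngrams.add(word[i:end])
    return sorted(ngrams)

print(compute_subwords("<where>"))
# ['<wh', '<whe', '<wher', 'ere', 'ere>', 'her', 'here', 'here>', 're>',
#  'whe', 'wher', 'where']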
models/recall/fasttext/model.py
new file mode 100755
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid

from paddlerec.core.utils import envs
from paddlerec.core.model import Model as ModelBase


class Model(ModelBase):
    def __init__(self, config):
        ModelBase.__init__(self, config)

    def _init_hyper_parameters(self):
        self.is_distributed = True if envs.get_trainer(
        ) == "CtrTrainer" else False
        self.sparse_feature_number = envs.get_global_env(
            "hyper_parameters.sparse_feature_number")
        self.sparse_feature_dim = envs.get_global_env(
            "hyper_parameters.sparse_feature_dim")
        self.neg_num = envs.get_global_env("hyper_parameters.neg_num")
        self.with_shuffle_batch = envs.get_global_env(
            "hyper_parameters.with_shuffle_batch")
        self.learning_rate = envs.get_global_env(
            "hyper_parameters.optimizer.learning_rate")
        self.decay_steps = envs.get_global_env(
            "hyper_parameters.optimizer.decay_steps")
        self.decay_rate = envs.get_global_env(
            "hyper_parameters.optimizer.decay_rate")

    def input_data(self, is_infer=False, **kwargs):
        if is_infer:
            analogy_a = fluid.data(
                name="analogy_a", shape=[None, 1], lod_level=1, dtype='int64')
            analogy_b = fluid.data(
                name="analogy_b", shape=[None, 1], lod_level=1, dtype='int64')
            analogy_c = fluid.data(
                name="analogy_c", shape=[None, 1], lod_level=1, dtype='int64')
            analogy_d = fluid.data(
                name="analogy_d", shape=[None, 1], dtype='int64')
            return [analogy_a, analogy_b, analogy_c, analogy_d]

        input_word = fluid.data(
            name="input_word", shape=[None, 1], lod_level=1, dtype='int64')
        true_word = fluid.data(
            name='true_label', shape=[None, 1], lod_level=1, dtype='int64')
        if self.with_shuffle_batch:
            return [input_word, true_word]

        neg_word = fluid.data(
            name="neg_label", shape=[None, self.neg_num], dtype='int64')
        return [input_word, true_word, neg_word]

    def net(self, inputs, is_infer=False):
        if is_infer:
            self.infer_net(inputs)
            return

        def embedding_layer(input,
                            table_name,
                            initializer_instance=None,
                            sequence_pool=False):
            emb = fluid.embedding(
                input=input,
                is_sparse=True,
                is_distributed=self.is_distributed,
                size=[self.sparse_feature_number, self.sparse_feature_dim],
                param_attr=fluid.ParamAttr(
                    name=table_name, initializer=initializer_instance), )
            if sequence_pool:
                emb = fluid.layers.sequence_pool(
                    input=emb, pool_type='average')
            return emb

        init_width = 1.0 / self.sparse_feature_dim
        emb_initializer = fluid.initializer.Uniform(-init_width, init_width)
        emb_w_initializer = fluid.initializer.Constant(value=0.0)

        input_emb = embedding_layer(inputs[0], "emb", emb_initializer, True)
        input_emb = fluid.layers.squeeze(input=input_emb, axes=[1])
        true_emb_w = embedding_layer(inputs[1], "emb_w", emb_w_initializer,
                                     True)
        true_emb_w = fluid.layers.squeeze(input=true_emb_w, axes=[1])

        if self.with_shuffle_batch:
            neg_emb_w_list = []
            for i in range(self.neg_num):
                neg_emb_w_list.append(
                    fluid.contrib.layers.shuffle_batch(
                        true_emb_w))  # shuffle true_word
            neg_emb_w_concat = fluid.layers.concat(neg_emb_w_list, axis=0)
            neg_emb_w = fluid.layers.reshape(
                neg_emb_w_concat,
                shape=[-1, self.neg_num, self.sparse_feature_dim])
        else:
            neg_emb_w = embedding_layer(inputs[2], "emb_w", emb_w_initializer)

        true_logits = fluid.layers.reduce_sum(
            fluid.layers.elementwise_mul(input_emb, true_emb_w),
            dim=1,
            keep_dim=True)
        input_emb_re = fluid.layers.reshape(
            input_emb, shape=[-1, 1, self.sparse_feature_dim])
        neg_matmul = fluid.layers.matmul(
            input_emb_re, neg_emb_w, transpose_y=True)
        neg_logits = fluid.layers.reshape(neg_matmul, shape=[-1, 1])

        logits = fluid.layers.concat([true_logits, neg_logits], axis=0)
        label_ones = fluid.layers.fill_constant(
            shape=[fluid.layers.shape(true_logits)[0], 1],
            value=1.0,
            dtype='float32')
        label_zeros = fluid.layers.fill_constant(
            shape=[fluid.layers.shape(neg_logits)[0], 1],
            value=0.0,
            dtype='float32')
        label = fluid.layers.concat([label_ones, label_zeros], axis=0)

        loss = fluid.layers.log_loss(fluid.layers.sigmoid(logits), label)
        avg_cost = fluid.layers.reduce_sum(loss)

        global_right_cnt = fluid.layers.create_global_var(
            name="global_right_cnt",
            persistable=True,
            dtype='float32',
            shape=[1],
            value=0)
        global_total_cnt = fluid.layers.create_global_var(
            name="global_total_cnt",
            persistable=True,
            dtype='float32',
            shape=[1],
            value=0)
        global_right_cnt.stop_gradient = True
        global_total_cnt.stop_gradient = True

        self._cost = avg_cost
        self._metrics["LOSS"] = avg_cost

    def optimizer(self):
        optimizer = fluid.optimizer.SGD(
            learning_rate=fluid.layers.exponential_decay(
                learning_rate=self.learning_rate,
                decay_steps=self.decay_steps,
                decay_rate=self.decay_rate,
                staircase=True))
        return optimizer

    def infer_net(self, inputs):
        def embedding_layer(input,
                            table_name,
                            initializer_instance=None,
                            sequence_pool=False):
            emb = fluid.embedding(
                input=input,
                size=[self.sparse_feature_number, self.sparse_feature_dim],
                param_attr=table_name)
            if sequence_pool:
                emb = fluid.layers.sequence_pool(
                    input=emb, pool_type='average')
            return emb

        all_label = np.arange(self.sparse_feature_number).reshape(
            self.sparse_feature_number).astype('int32')
        self.all_label = fluid.layers.cast(
            x=fluid.layers.assign(all_label), dtype='int64')
        emb_all_label = embedding_layer(self.all_label, "emb")
        emb_a = embedding_layer(inputs[0], "emb", sequence_pool=True)
        emb_b = embedding_layer(inputs[1], "emb", sequence_pool=True)
        emb_c = embedding_layer(inputs[2], "emb", sequence_pool=True)

        target = fluid.layers.elementwise_add(
            fluid.layers.elementwise_sub(emb_b, emb_a), emb_c)
        emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
        dist = fluid.layers.matmul(
            x=target, y=emb_all_label_l2, transpose_y=True)
        values, pred_idx = fluid.layers.topk(input=dist, k=4)
        label = fluid.layers.expand(inputs[3], expand_times=[1, 4])
        label_ones = fluid.layers.fill_constant_batch_size_like(
            label, shape=[-1, 1], value=1.0, dtype='float32')
        right_cnt = fluid.layers.reduce_sum(input=fluid.layers.cast(
            fluid.layers.equal(pred_idx, label), dtype='float32'))
        total_cnt = fluid.layers.reduce_sum(label_ones)

        global_right_cnt = fluid.layers.create_global_var(
            name="global_right_cnt",
            persistable=True,
            dtype='float32',
            shape=[1],
            value=0)
        global_total_cnt = fluid.layers.create_global_var(
            name="global_total_cnt",
            persistable=True,
            dtype='float32',
            shape=[1],
            value=0)
        global_right_cnt.stop_gradient = True
        global_total_cnt.stop_gradient = True

        global_total_cnt = fluid.layers.Print(global_total_cnt)
        tmp1 = fluid.layers.elementwise_add(right_cnt, global_right_cnt)
        fluid.layers.assign(tmp1, global_right_cnt)
        total_cnt = fluid.layers.Print(total_cnt)
        tmp2 = fluid.layers.elementwise_add(total_cnt, global_total_cnt)
        fluid.layers.assign(tmp2, global_total_cnt)
        global_right_cnt = fluid.layers.Print(global_right_cnt)
        global_total_cnt = fluid.layers.Print(global_total_cnt)

        acc = fluid.layers.elementwise_div(
            global_right_cnt, global_total_cnt, name="total_acc")
        self._infer_results['acc'] = acc
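The infer_net above evaluates word analogies: the target vector is emb_b - emb_a + emb_c, it is scored against the L2-normalized embedding of every word with a matmul, and the top-4 ids are compared against analogy_d. A toy NumPy sketch of the same computation, not part of this commit, with random embeddings standing in for the trained table:

import numpy as np

rng = np.random.RandomState(0)
vocab, dim = 1000, 300                 # stand-ins for sparse_feature_number / dim
emb = rng.uniform(-1.0 / dim, 1.0 / dim, size=(vocab, dim))

a, b, c, d = 10, 20, 30, 40            # hypothetical analogy word ids
target = emb[b] - emb[a] + emb[c]      # elementwise_add(elementwise_sub(b, a), c)
emb_l2 = emb / np.linalg.norm(emb, axis=1, keepdims=True)
dist = emb_l2 @ target                 # matmul against all normalized labels
pred_idx = np.argsort(-dist)[:4]       # analogue of fluid.layers.topk(k=4)
print(pred_idx, d in pred_idx)         # accuracy counts how often d lands in the top 4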
models/recall/fasttext/preprocess.py
new file mode 100755
# -*- coding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import math
import os
import random
import re
import six
import argparse

prog = re.compile("[^a-z ]", flags=0)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Paddle Fluid word2 vector preprocess")
    parser.add_argument(
        '--build_dict_corpus_dir', type=str, help="The dir of corpus")
    parser.add_argument(
        '--input_corpus_dir', type=str, help="The dir of input corpus")
    parser.add_argument(
        '--output_corpus_dir', type=str, help="The dir of output corpus")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./dict',
        help="The path of dictionary ")
    parser.add_argument(
        '--word_id_path',
        type=str,
        default='./word_id',
        help="The path of word_id ")
    parser.add_argument(
        '--ngrams_path',
        type=str,
        default='./word_ngrams',
        help="The path of word_ngrams ")
    parser.add_argument(
        '--ngrams_id_path',
        type=str,
        default='./word_ngrams_id',
        help="The path of word_ngrams_id ")
    parser.add_argument(
        '--min_count',
        type=int,
        default=5,
        help="If the word count is less then min_count, it will be removed from dict"
    )
    parser.add_argument(
        '--min_n', type=int, default=3, help="min_n of ngrams")
    parser.add_argument(
        '--max_n', type=int, default=5, help="max_n of ngrams")
    parser.add_argument(
        '--file_nums',
        type=int,
        default=1024,
        help="re-split input corpus file nums")
    parser.add_argument(
        '--downsample',
        type=float,
        default=0.001,
        help="filter word by downsample")
    parser.add_argument(
        '--filter_corpus',
        action='store_true',
        default=False,
        help='Filter corpus')
    parser.add_argument(
        '--build_dict',
        action='store_true',
        default=False,
        help='Build dict from corpus')
    parser.add_argument(
        '--data_resplit',
        action='store_true',
        default=False,
        help='re-split input corpus files')
    return parser.parse_args()


def text_strip(text):
    # English Preprocess Rule
    return prog.sub("", text.lower())


# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py
# Unicode utility functions that work with Python 2 and 3
def native_to_unicode(s):
    if _is_unicode(s):
        return s
    try:
        return _to_unicode(s)
    except UnicodeDecodeError:
        res = _to_unicode(s, ignore_errors=True)
        return res


def _is_unicode(s):
    if six.PY2:
        if isinstance(s, unicode):
            return True
    else:
        if isinstance(s, str):
            return True
    return False


def _to_unicode(s, ignore_errors=False):
    if _is_unicode(s):
        return s
    error_mode = "ignore" if ignore_errors else "strict"
    return s.decode("utf-8", errors=error_mode)


def filter_corpus(args):
    """
    filter corpus and convert id.
    """
    word_count = dict()
    word_to_id_ = dict()
    word_all_count = 0
    id_counts = []
    word_id = 0
    # read dict
    with io.open(args.dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            word, count = line.split()[0], int(line.split()[1])
            word_count[word] = count
            word_to_id_[word] = word_id
            word_id += 1
            id_counts.append(count)
            word_all_count += count

    word_ngrams = dict()
    with io.open(args.ngrams_path, 'r', encoding='utf-8') as f:
        for line in f:
            word, ngrams = line.rstrip().split(':')
            ngrams = ngrams.split()
            ngrams = [str(word_to_id_[_]) for _ in ngrams]
            word_ngrams[word_to_id_[word]] = ' '.join(ngrams)

    with io.open(args.ngrams_id_path, 'w+', encoding='utf-8') as fid:
        for k, v in word_ngrams.items():
            fid.write(u'{} {}\n'.format(k, v))

    # write word2id file
    print("write word2id file to : " + args.dict_path + "_word_to_id_")
    with io.open(args.word_id_path, 'w+', encoding='utf-8') as fid:
        for k, v in word_to_id_.items():
            fid.write(k + " " + str(v) + '\n')

    # filter corpus and convert id
    if not os.path.exists(args.output_corpus_dir):
        os.makedirs(args.output_corpus_dir)
    for file in os.listdir(args.input_corpus_dir):
        with io.open(args.output_corpus_dir + '/convert_' + file + '.csv',
                     "w") as wf:
            with io.open(
                    args.input_corpus_dir + '/' + file,
                    encoding='utf-8') as rf:
                print(args.input_corpus_dir + '/' + file)
                for line in rf:
                    signal = False
                    line = text_strip(line)
                    words = line.split()
                    write_line = ""
                    for item in words:
                        if item in word_count:
                            idx = word_to_id_[item]
                        else:
                            idx = word_to_id_[native_to_unicode('<UNK>')]
                        count_w = id_counts[idx]
                        corpus_size = word_all_count
                        keep_prob = (math.sqrt(count_w / (
                            args.downsample * corpus_size)) + 1) * (
                                args.downsample * corpus_size) / count_w
                        r_value = random.random()
                        if r_value > keep_prob:
                            continue
                        write_line += str(idx)
                        write_line += ","
                        signal = True
                    if signal:
                        write_line = write_line[:-1] + "\n"
                        wf.write(_to_unicode(write_line))


def computeSubwords(word, min_n, max_n):
    ngrams = set()
    for i in range(len(word) - min_n + 1):
        for j in range(min_n, max_n + 1):
            end = min(len(word), i + j)
            ngrams.add("".join(word[i:end]))
    return list(ngrams)


def build_dict(args):
    """
    proprocess the data, generate dictionary and save into dict_path.
    :param corpus_dir: the input data dir.
    :param dict_path: the generated dict path. the data in dict is "word count"
    :param min_count:
    :return:
    """
    # word to count
    word_count = dict()

    for file in os.listdir(args.build_dict_corpus_dir):
        with io.open(
                args.build_dict_corpus_dir + "/" + file,
                encoding='utf-8') as f:
            print("build dict : ", args.build_dict_corpus_dir + "/" + file)
            for line in f:
                line = text_strip(line)
                words = line.split()
                for item in words:
                    item = '<' + item + '>'
                    if item in word_count:
                        word_count[item] = word_count[item] + 1
                    else:
                        word_count[item] = 1

    item_to_remove = []
    for item in word_count:
        if word_count[item] <= args.min_count:
            item_to_remove.append(item)

    unk_sum = 0
    for item in item_to_remove:
        unk_sum += word_count[item]
        del word_count[item]
    # sort by count
    word_count[native_to_unicode('<UNK>')] = unk_sum

    word_ngrams = dict()
    ngrams_count = dict()
    for item in word_count:
        ngrams = computeSubwords(item, args.min_n, args.max_n)
        word_ngrams[item] = ngrams
        for sub_word in ngrams:
            if sub_word not in ngrams_count:
                ngrams_count[sub_word] = 1
            else:
                ngrams_count[sub_word] = ngrams_count[sub_word] + 1

    ngrams_count = sorted(
        ngrams_count.items(), key=lambda ngrams_count: -ngrams_count[1])
    word_count = sorted(
        word_count.items(), key=lambda word_count: -word_count[1])

    with io.open(args.dict_path, 'w+', encoding='utf-8') as f:
        for k, v in word_count:
            f.write(k + " " + str(v) + '\n')
        for k, v in ngrams_count:
            f.write(k + " " + str(v) + '\n')

    with io.open(args.ngrams_path, 'w+', encoding='utf-8') as f:
        for key in word_ngrams:
            f.write(key + ":")
            f.write(" ".join(word_ngrams[key]))
            f.write(u'\n')


def data_split(args):
    raw_data_dir = args.input_corpus_dir
    new_data_dir = args.output_corpus_dir
    if not os.path.exists(new_data_dir):
        os.mkdir(new_data_dir)
    files = os.listdir(raw_data_dir)
    print(files)

    index = 0
    contents = []
    for file_ in files:
        with open(os.path.join(raw_data_dir, file_), 'r') as f:
            contents.extend(f.readlines())

    num = int(args.file_nums)
    lines_per_file = len(contents) / num
    print("contents: ", str(len(contents)))
    print("lines_per_file: ", str(lines_per_file))

    for i in range(1, num + 1):
        with open(os.path.join(new_data_dir, "part_" + str(i)), 'w') as fout:
            data = contents[(i - 1) * lines_per_file:min(i * lines_per_file,
                                                         len(contents))]
            for line in data:
                fout.write(line)


if __name__ == "__main__":
    args = parse_args()
    if args.build_dict:
        build_dict(args)
    elif args.filter_corpus:
        filter_corpus(args)
    elif args.data_resplit:
        data_split(args)
    else:
        print(
            "error command line, please choose --build_dict or --filter_corpus")
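The keep_prob expression in filter_corpus is the usual word2vec-style subsampling rule: with downsample t and corpus size N, a word occurring count_w times is kept with probability (sqrt(count_w / (t*N)) + 1) * (t*N) / count_w, so very frequent words are dropped often while rare words are effectively always kept. A tiny sketch with made-up numbers, not part of this commit:

import math

def keep_prob(count_w, corpus_size, downsample=0.001):
    # Same formula as in filter_corpus above.
    t = downsample * corpus_size
    return (math.sqrt(count_w / t) + 1) * t / count_w

corpus_size = 100000000                  # assumed total token count
print(keep_prob(5000000, corpus_size))   # frequent word -> ~0.16, mostly dropped
print(keep_prob(500, corpus_size))       # rare word -> > 1, always kept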
models/recall/fasttext/reader.py
new file mode 100755
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io

import numpy as np

from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs


class NumpyRandomInt(object):
    def __init__(self, a, b, buf_size=1000):
        self.idx = 0
        self.buffer = np.random.random_integers(a, b, buf_size)
        self.a = a
        self.b = b

    def __call__(self):
        if self.idx == len(self.buffer):
            self.buffer = np.random.random_integers(self.a, self.b,
                                                    len(self.buffer))
            self.idx = 0

        result = self.buffer[self.idx]
        self.idx += 1
        return result


class TrainReader(Reader):
    def init(self):
        dict_path = envs.get_global_env(
            "dataset.dataset_train.word_count_dict_path")
        word_ngrams_path = envs.get_global_env(
            "dataset.dataset_train.word_ngrams_path")
        self.window_size = envs.get_global_env("hyper_parameters.window_size")
        self.neg_num = envs.get_global_env("hyper_parameters.neg_num")
        self.with_shuffle_batch = envs.get_global_env(
            "hyper_parameters.with_shuffle_batch")
        self.random_generator = NumpyRandomInt(1, self.window_size + 1)

        self.word_ngrams = dict()
        with io.open(word_ngrams_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip().split()
                self.word_ngrams[str(line[0])] = map(int, line[1:])

        self.cs = None
        if not self.with_shuffle_batch:
            id_counts = []
            word_all_count = 0
            with io.open(dict_path, 'r', encoding='utf-8') as f:
                for line in f:
                    word, count = line.split()[0], int(line.split()[1])
                    id_counts.append(count)
                    word_all_count += count
            id_frequencys = [
                float(count) / word_all_count for count in id_counts
            ]
            np_power = np.power(np.array(id_frequencys), 0.75)
            id_frequencys_pow = np_power / np_power.sum()
            self.cs = np.array(id_frequencys_pow).cumsum()

    def get_context_words(self, words, idx):
        """
        Get the context word list of target word.
        words: the words of the current line
        idx: input word index
        window_size: window size
        """
        target_window = self.random_generator()
        start_point = idx - target_window  # if (idx - target_window) > 0 else 0
        if start_point < 0:
            start_point = 0
        end_point = idx + target_window
        targets = words[start_point:idx] + words[idx + 1:end_point + 1]
        return targets

    def generate_sample(self, line):
        def reader():
            word_ids = [w for w in line.split()]
            for idx, target_id in enumerate(word_ids):
                input_word = [int(target_id)]
                if target_id in self.word_ngrams:
                    input_word += self.word_ngrams[target_id]
                context_word_ids = self.get_context_words(word_ids, idx)
                for context_id in context_word_ids:
                    output = [('input_word', input_word),
                              ('true_label', [int(context_id)])]
                    if not self.with_shuffle_batch:
                        neg_array = self.cs.searchsorted(
                            np.random.sample(self.neg_num))
                        output += [('neg_label',
                                    [int(str(i)) for i in neg_array])]
                    yield output

        return reader
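When with_shuffle_batch is False, negatives are drawn from the unigram distribution raised to the 0.75 power: init builds a cumulative distribution self.cs, and generate_sample inverts it with searchsorted on uniform samples. A standalone sketch with made-up counts, not part of this commit:

import numpy as np

id_counts = np.array([1000, 400, 50, 5, 1], dtype=np.float64)  # assumed word counts
freq = id_counts / id_counts.sum()
weights = np.power(freq, 0.75)               # flatten the distribution a bit
cs = (weights / weights.sum()).cumsum()      # same role as self.cs in TrainReader.init

neg_num = 5
neg_ids = cs.searchsorted(np.random.sample(neg_num))
print(neg_ids)                               # ids of 5 negatives, biased toward frequent words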