PaddlePaddle / PaddleSlim, commit 05b7d07d

add quant_embedding demo

Authored on Nov 22, 2019 by slf12
Parent commit: 9fb9b6d2
Showing 11 changed files with 1,544 additions and 0 deletions
demo/quant/quant_embedding/README.md          +238  -0
demo/quant/quant_embedding/cluster_train.py   +250  -0
demo/quant/quant_embedding/cluster_train.sh   +68   -0
demo/quant/quant_embedding/image/after.png    +0    -0
demo/quant/quant_embedding/image/before.png   +0    -0
demo/quant/quant_embedding/infer.py           +227  -0
demo/quant/quant_embedding/net.py             +136  -0
demo/quant/quant_embedding/preprocess.py      +195  -0
demo/quant/quant_embedding/reader.py          +106  -0
demo/quant/quant_embedding/train.py           +228  -0
demo/quant/quant_embedding/utils.py           +96   -0
demo/quant/quant_embedding/README.md  (new file, mode 100755)
# Embedding Quantization Demo

This demo shows how to use the embedding-quantization interface [paddleslim.quant.quant_embedding]().
The ``quant_embedding`` interface quantizes the embedding parameters of a network from ``float32`` to ``8-bit`` integers, reducing model storage and memory footprint with almost no loss of accuracy.

The interface is:
```
quant_embedding(program, place, config, scope=None)
```
Parameters:

- program (fluid.Program): the program to quantize.
- scope (fluid.Scope, optional): the scope used to read and write ``Variable``s; if set to ``None``, ``fluid.global_scope()`` is used.
- place (fluid.CPUPlace or fluid.CUDAPlace): the device on which the program runs.
- config (dict): the quantization configuration. Supported keys are:
    - ``'params_name'`` (str, required): name of the parameter to quantize. This key must be set.
    - ``'quantize_type'`` (str, optional): quantization type. Currently only ``'abs_max'`` is supported; ``'log'`` and ``'product_quantization'`` are planned. Default: ``'abs_max'``.
    - ``'quantize_bits'`` (int, optional): number of quantization bits. Currently only 8 bits are supported. Default: 8.
    - ``'dtype'`` (str, optional): data type after quantization. Currently only ``'int8'`` is supported. Default: ``'int8'``.
    - ``'threshold'`` (float, optional): if set, parameter values are clipped to this threshold before quantization. If unset, clipping is skipped and the values are quantized directly.
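Below is a minimal usage sketch. It follows the pattern used by ``infer.py`` later in this commit; the model path and the ``vocab_size``/``emb_size`` values (taken from the sample run further down) are illustrative only:

```python
import paddle.fluid as fluid
import net  # net.py from this demo
from paddleslim.quant import quant_embedding

place = fluid.CPUPlace()
exe = fluid.Executor(place)

main_program = fluid.Program()
with fluid.program_guard(main_program):
    # Build the evaluation network; its embedding parameter is named 'emb'.
    values, pred = net.infer_network(vocab_size=63642, emb_size=64)

# Load trained float32 parameters (the model path is illustrative).
fluid.io.load_params(
    executor=exe, dirname='v1_cpu5_b100_lr1dir/pass-9',
    main_program=main_program)

# Quantize the 'emb' parameter to 8-bit integers; the returned program is
# then run with exe.run(...) exactly like the un-quantized one.
config = {'params_name': 'emb', 'quantize_type': 'abs_max'}
main_program = quant_embedding(main_program, place, config)
```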
Changes this interface makes to the program:

Before quantization:

<p align="center">
<img src="./image/before.png" height=200 width=100 hspace='10' />
<br />
<strong>Figure 1: program before quantization</strong>
</p>

After quantization:

<p align="center">
<img src="./image/after.png" height=300 width=300 hspace='10' />
<br />
<strong>Figure 2: program after quantization</strong>
</p>

The rest of this document uses the ``skip-gram word2vec model`` as an example of how to use the ``quant_embedding`` interface. We first describe the normal training and evaluation workflow of the ``skip-gram word2vec model``.

## Skip-gram word2vec model

The directory layout of this example is:

```text
.
├── cluster_train.py # distributed training entry point
├── cluster_train.sh # script that simulates multi-machine training locally
├── train.py         # single-machine training entry point
├── infer.py         # inference/evaluation script
├── net.py           # network definition
├── preprocess.py    # preprocessing script: builds the dictionary and converts the corpus
├── reader.py        # data reader used during training
└── utils.py         # common utilities
```

### Introduction

This example implements the skip-gram variant of the word2vec model.
Users may also refer to the [IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124377).

### Downloading the data

The full dataset comes from the [1 Billion Word Language Model Benchmark](http://www.statmt.org/lm-benchmark).

```bash
mkdir data
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
```

A mirror of the dataset can be downloaded as follows:

```bash
mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
```

For quick verification we also provide the classic text8 sample dataset, which contains about 17 million words. Download it with:

```bash
mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/text.tar
tar xvf text.tar
mv text data/
```

### Data preprocessing

The steps below preprocess the sample dataset. For the full dataset, note that after extraction the preprocessing directory is training-monolingual.tokenized.shuffled, which sits next to the sample dataset's text directory.

Dictionary format: word <space> word count. Low-frequency words are represented by '<UNK>'.
You can also build your own dictionary in this format; if you do, skip the first step.

```
the 1061396
of 593677
and 416629
one 411764
in 372201
a 325873
<UNK> 324608
to 316376
zero 264975
nine 250430
```

Step 1: build the dictionary from the English corpus. For a Chinese corpus, customize the processing by modifying the text_strip method.

```bash
python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
```

Step 2: convert the text to ids using the dictionary, downsampling frequent words by probability, and generate the word-to-id mapping file, whose name is the dictionary path plus the suffix "_word_to_id_".

```bash
python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count 5 --downsample 0.001
```
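For reference, the keep probability used for downsampling is the one implemented in ``preprocess.py`` (included later in this commit), where ``count_w`` is a word's count from the dictionary, ``corpus_size`` is the total token count, and ``t`` is the ``--downsample`` rate:

```python
# Keep probability used by filter_corpus in preprocess.py (t = --downsample).
keep_prob = (math.sqrt(count_w / (t * corpus_size)) + 1) * (t * corpus_size) / count_w
```

Each occurrence of a word is kept only when a uniform random draw falls below ``keep_prob``, so very frequent words are dropped more aggressively.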
### Training

To see the available options, run:

```bash
python train.py -h
```

Single-machine, multi-threaded training:

```bash
OPENBLAS_NUM_THREADS=1 CPU_NUM=5 python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes 10 --batch_size 100 --model_output_dir v1_cpu5_b100_lr1dir --base_lr 1.0 --print_batch 1000 --with_speed --is_sparse
```

Simulated multi-machine training on a single machine:

```bash
sh cluster_train.sh
```

This demo trains with the single-machine, multi-threaded command above. After training finishes, the models are saved under ``v1_cpu5_b100_lr1dir`` in the current directory; running ``ls v1_cpu5_b100_lr1dir`` shows the model files of the 10 training epochs:

```
pass-0 pass-1 pass-2 pass-3 pass-4 pass-5 pass-6 pass-7 pass-8 pass-9
```

### Evaluation

Download the test set:

```bash
# test set for the full dataset
wget https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
# test set for the sample dataset
wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
```

Evaluation command. Note that the dictionary name must carry the suffix "_word_to_id_"; this file is generated during preprocessing.

```bash
python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size 20000 --model_dir v1_cpu5_b100_lr1dir/ --start_index 0 --last_index 9
```

Running this command produces output like:

```
('start index: ', 0, ' last_index:', 9)
('vocab_size:', 63642)
step:1 249
epoch:0 acc:0.014
step:1 590
epoch:1 acc:0.033
step:1 982
epoch:2 acc:0.055
step:1 1338
epoch:3 acc:0.075
step:1 1653
epoch:4 acc:0.093
step:1 1914
epoch:5 acc:0.107
step:1 2204
epoch:6 acc:0.124
step:1 2416
epoch:7 acc:0.136
step:1 2606
epoch:8 acc:0.146
step:1 2722
epoch:9 acc:0.153
```

## Quantizing the saved ``skip-gram word2vec model``

The quantization configuration is:

```
config = {
    'params_name': 'emb',
    'quantize_type': 'abs_max'
}
```
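This configuration takes effect in ``infer.py`` (shown in full later in this commit) when ``--emb_quant True`` is passed; the relevant excerpt is:

```python
# Excerpt from infer.py: after loading each epoch's parameters, the cloned
# program is rewritten so that the 'emb' parameter is stored as int8.
if args.emb_quant:
    config = {'params_name': 'emb', 'quantize_type': 'abs_max'}
    copy_program = quant_embedding(copy_program, place, config)
```

Keys left unset fall back to their defaults, which is why the log below reports ``quantize_bits: 8`` and ``dtype: 'int8'``.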
The command to run is:

```bash
python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size 20000 --model_dir v1_cpu5_b100_lr1dir/ --start_index 0 --last_index 9 --emb_quant True
```

The output is:

```
('start index: ', 0, ' last_index:', 9)
('vocab_size:', 63642)
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 253
epoch:0 acc:0.014
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 586
epoch:1 acc:0.033
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 970
epoch:2 acc:0.054
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 1364
epoch:3 acc:0.077
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 1642
epoch:4 acc:0.092
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 1936
epoch:5 acc:0.109
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2216
epoch:6 acc:0.124
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2419
epoch:7 acc:0.136
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2603
epoch:8 acc:0.146
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2719
epoch:9 acc:0.153
```
demo/quant/quant_embedding/cluster_train.py  (new file, mode 100755)

from __future__ import print_function
import argparse
import logging
import os
import time
import math
import random
import numpy as np
import paddle
import paddle.fluid as fluid
import six
import reader
from net import skip_gram_word2vec

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Word2vec example")
    parser.add_argument(
        '--train_data_dir',
        type=str,
        default='./data/text',
        help="The path of training dataset")
    parser.add_argument(
        '--base_lr',
        type=float,
        default=0.01,
        help="The number of learning rate (default: 0.01)")
    parser.add_argument(
        '--save_step',
        type=int,
        default=500000,
        help="The number of step to save (default: 500000)")
    parser.add_argument(
        '--print_batch',
        type=int,
        default=100,
        help="The number of print_batch (default: 100)")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/1-billion_dict',
        help="The path of data dict")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=500,
        help="The size of mini-batch (default:500)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="The number of passes to train (default: 10)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for model to store (default: models)')
    parser.add_argument('--nce_num', type=int, default=5, help='nce_num')
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=64,
        help='sparse feature hashing space for index processing')
    parser.add_argument(
        '--is_sparse',
        action='store_true',
        required=False,
        default=False,
        help='embedding and nce will use sparse or not, (default: False)')
    parser.add_argument(
        '--with_speed',
        action='store_true',
        required=False,
        default=False,
        help='print speed or not , (default: False)')
    parser.add_argument(
        '--role', type=str, default='pserver', help='trainer or pserver')
    parser.add_argument(
        '--endpoints',
        type=str,
        default='127.0.0.1:6000',
        help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
    parser.add_argument(
        '--current_endpoint',
        type=str,
        default='127.0.0.1:6000',
        help='The current_endpoint')
    parser.add_argument(
        '--trainer_id',
        type=int,
        default=0,
        help='trainer id ,only trainer_id=0 save model')
    parser.add_argument(
        '--trainers',
        type=int,
        default=1,
        help='The num of trainers, (default: 1)')
    return parser.parse_args()


def convert_python_to_tensor(weight, batch_size, sample_reader):
    def __reader__():
        cs = np.array(weight).cumsum()
        result = [[], []]
        for sample in sample_reader():
            for i, fea in enumerate(sample):
                result[i].append(fea)
            if len(result[0]) == batch_size:
                tensor_result = []
                for tensor in result:
                    t = fluid.Tensor()
                    dat = np.array(tensor, dtype='int64')
                    if len(dat.shape) > 2:
                        dat = dat.reshape((dat.shape[0], dat.shape[2]))
                    elif len(dat.shape) == 1:
                        dat = dat.reshape((-1, 1))
                    t.set(dat, fluid.CPUPlace())
                    tensor_result.append(t)
                tt = fluid.Tensor()
                neg_array = cs.searchsorted(np.random.sample(args.nce_num))
                neg_array = np.tile(neg_array, batch_size)
                tt.set(
                    neg_array.reshape((batch_size, args.nce_num)),
                    fluid.CPUPlace())
                tensor_result.append(tt)
                yield tensor_result
                result = [[], []]

    return __reader__


def train_loop(args, train_program, reader, py_reader, loss, trainer_id,
               weight):
    py_reader.decorate_tensor_provider(
        convert_python_to_tensor(weight, args.batch_size, reader.train()))

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    print("CPU_NUM:" + str(os.getenv("CPU_NUM")))

    train_exe = exe

    for pass_id in range(args.num_passes):
        py_reader.start()
        time.sleep(10)
        epoch_start = time.time()
        batch_id = 0
        start = time.time()
        try:
            while True:
                loss_val = train_exe.run(fetch_list=[loss.name])
                loss_val = np.mean(loss_val)

                if batch_id % args.print_batch == 0:
                    logger.info(
                        "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}".
                        format(pass_id, batch_id,
                               loss_val.mean(), py_reader.queue.size()))
                if args.with_speed:
                    if batch_id % 500 == 0 and batch_id != 0:
                        elapsed = (time.time() - start)
                        start = time.time()
                        samples = 1001 * args.batch_size * int(
                            os.getenv("CPU_NUM"))
                        logger.info("Time used: {}, Samples/Sec: {}".format(
                            elapsed, samples / elapsed))

                if batch_id % args.save_step == 0 and batch_id != 0:
                    model_dir = args.model_output_dir + '/pass-' + str(
                        pass_id) + ('/batch-' + str(batch_id))
                    if trainer_id == 0:
                        fluid.io.save_params(executor=exe, dirname=model_dir)
                        print("model saved in %s" % model_dir)
                batch_id += 1

        except fluid.core.EOFException:
            py_reader.reset()
            epoch_end = time.time()
            logger.info("Epoch: {0}, Train total expend: {1} ".format(
                pass_id, epoch_end - epoch_start))
            model_dir = args.model_output_dir + '/pass-' + str(pass_id)
            if trainer_id == 0:
                fluid.io.save_params(executor=exe, dirname=model_dir)
                print("model saved in %s" % model_dir)


def GetFileList(data_path):
    return os.listdir(data_path)


def train(args):
    if not os.path.isdir(args.model_output_dir) and args.trainer_id == 0:
        os.mkdir(args.model_output_dir)

    filelist = GetFileList(args.train_data_dir)
    word2vec_reader = reader.Word2VecReader(args.dict_path,
                                            args.train_data_dir, filelist, 0, 1)

    logger.info("dict_size: {}".format(word2vec_reader.dict_size))
    np_power = np.power(np.array(word2vec_reader.id_frequencys), 0.75)
    id_frequencys_pow = np_power / np_power.sum()

    loss, py_reader = skip_gram_word2vec(
        word2vec_reader.dict_size,
        args.embedding_size,
        is_sparse=args.is_sparse,
        neg_num=args.nce_num)

    optimizer = fluid.optimizer.SGD(
        learning_rate=fluid.layers.exponential_decay(
            learning_rate=args.base_lr,
            decay_steps=100000,
            decay_rate=0.999,
            staircase=True))

    optimizer.minimize(loss)

    logger.info("run dist training")
    t = fluid.DistributeTranspiler()
    t.transpile(
        args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
    if args.role == "pserver":
        print("run pserver")
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        pserver_startup = t.get_startup_program(args.current_endpoint,
                                                pserver_prog)
        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif args.role == "trainer":
        print("run trainer")
        train_loop(args, t.get_trainer_program(), word2vec_reader, py_reader,
                   loss, args.trainer_id, id_frequencys_pow)


if __name__ == '__main__':
    args = parse_args()
    train(args)
demo/quant/quant_embedding/cluster_train.sh  (new file, mode 100755)

#!/bin/bash

#export GLOG_v=30
#export GLOG_logtostderr=1

# start pserver0
export CPU_NUM=5
export FLAGS_rpc_deadline=3000000
python cluster_train.py \
    --train_data_dir data/convert_text8 \
    --dict_path data/test_build_dict \
    --batch_size 100 \
    --model_output_dir dis_model \
    --base_lr 1.0 \
    --print_batch 1 \
    --is_sparse \
    --with_speed \
    --role pserver \
    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
    --current_endpoint 127.0.0.1:6000 \
    --trainers 2 \
    > pserver0.log 2>&1 &

python cluster_train.py \
    --train_data_dir data/convert_text8 \
    --dict_path data/test_build_dict \
    --batch_size 100 \
    --model_output_dir dis_model \
    --base_lr 1.0 \
    --print_batch 1 \
    --is_sparse \
    --with_speed \
    --role pserver \
    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
    --current_endpoint 127.0.0.1:6001 \
    --trainers 2 \
    > pserver1.log 2>&1 &

# start trainer0
python cluster_train.py \
    --train_data_dir data/convert_text8 \
    --dict_path data/test_build_dict \
    --batch_size 100 \
    --model_output_dir dis_model \
    --base_lr 1.0 \
    --print_batch 1000 \
    --is_sparse \
    --with_speed \
    --role trainer \
    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
    --trainers 2 \
    --trainer_id 0 \
    > trainer0.log 2>&1 &

# start trainer1
python cluster_train.py \
    --train_data_dir data/convert_text8 \
    --dict_path data/test_build_dict \
    --batch_size 100 \
    --model_output_dir dis_model \
    --base_lr 1.0 \
    --print_batch 1000 \
    --is_sparse \
    --with_speed \
    --role trainer \
    --endpoints 127.0.0.1:6000,127.0.0.1:6001 \
    --trainers 2 \
    --trainer_id 1 \
    > trainer1.log 2>&1 &
demo/quant/quant_embedding/image/after.png  (new binary file, mode 100644, 82.2 KB)
demo/quant/quant_embedding/image/before.png  (new binary file, mode 100644, 31.0 KB)
demo/quant/quant_embedding/infer.py  (new file, mode 100755)

import argparse
import sys
import time
import math
import unittest
import contextlib
import numpy as np
import six
import paddle.fluid as fluid
import paddle
import net
import utils

sys.path.append(sys.path[0] + "/../../../")
from paddleslim.quant import quant_embedding


def parse_args():
    parser = argparse.ArgumentParser("PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/data_c/1-billion_dict_word_to_id_',
        help="The path of dic")
    parser.add_argument(
        '--infer_epoch',
        action='store_true',
        required=False,
        default=False,
        help='infer by epoch')
    parser.add_argument(
        '--infer_step',
        action='store_true',
        required=False,
        default=False,
        help='infer by step')
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    parser.add_argument(
        '--print_step', type=int, default='500000', help='print step')
    parser.add_argument(
        '--start_index', type=int, default='0', help='start index')
    parser.add_argument(
        '--start_batch', type=int, default='1', help='start batch')
    parser.add_argument(
        '--end_batch', type=int, default='13', help='end batch')
    parser.add_argument(
        '--last_index', type=int, default='100', help='last index')
    parser.add_argument(
        '--model_dir', type=str, default='model', help='model dir')
    parser.add_argument(
        '--use_cuda', type=int, default='0', help='whether use cuda')
    parser.add_argument(
        '--batch_size', type=int, default='5', help='batch_size')
    parser.add_argument(
        '--emb_size', type=int, default='64', help='embedding size')
    parser.add_argument(
        '--emb_quant',
        type=bool,
        default=False,
        help='whether to quant embedding parameter')
    args = parser.parse_args()
    return args


def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
    """ inference function """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    emb_size = args.emb_size
    batch_size = args.batch_size
    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, emb_size)
            for epoch in range(start_index, last_index + 1):
                copy_program = main_program.clone()
                model_path = model_dir + "/pass-" + str(epoch)
                fluid.io.load_params(
                    executor=exe,
                    dirname=model_path,
                    main_program=copy_program)
                if args.emb_quant:
                    config = {
                        'params_name': 'emb',
                        'quantize_type': 'abs_max'
                    }
                    copy_program = quant_embedding(copy_program, place,
                                                   config)
                accum_num = 0
                accum_num_sum = 0.0
                t0 = time.time()
                step_id = 0
                for data in test_reader():
                    step_id += 1
                    b_size = len([dat[0] for dat in data])
                    wa = np.array([dat[0] for dat in data]).astype(
                        "int64").reshape(b_size, 1)
                    wb = np.array([dat[1] for dat in data]).astype(
                        "int64").reshape(b_size, 1)
                    wc = np.array([dat[2] for dat in data]).astype(
                        "int64").reshape(b_size, 1)

                    label = [dat[3] for dat in data]
                    input_word = [dat[4] for dat in data]
                    para = exe.run(
                        copy_program,
                        feed={
                            "analogy_a": wa,
                            "analogy_b": wb,
                            "analogy_c": wc,
                            "all_label": np.arange(vocab_size).reshape(
                                vocab_size, 1).astype("int64"),
                        },
                        fetch_list=[pred.name, values],
                        return_numpy=False)
                    pre = np.array(para[0])
                    val = np.array(para[1])
                    for ii in range(len(label)):
                        top4 = pre[ii]
                        accum_num_sum += 1
                        for idx in top4:
                            if int(idx) in input_word[ii]:
                                continue
                            if int(idx) == int(label[ii][0]):
                                accum_num += 1
                            break
                    if step_id % 1 == 0:
                        print("step:%d %d " % (step_id, accum_num))
                print("epoch:%d \t acc:%.3f " %
                      (epoch, 1.0 * accum_num / accum_num_sum))


def infer_step(args, vocab_size, test_reader, use_cuda, i2w):
    """ inference function """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    emb_size = args.emb_size
    batch_size = args.batch_size
    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, emb_size)
            for epoch in range(start_index, last_index + 1):
                for batchid in range(args.start_batch, args.end_batch):
                    copy_program = main_program.clone()
                    model_path = model_dir + "/pass-" + str(epoch) + (
                        '/batch-' + str(batchid * args.print_step))
                    fluid.io.load_params(
                        executor=exe,
                        dirname=model_path,
                        main_program=copy_program)
                    accum_num = 0
                    accum_num_sum = 0.0
                    t0 = time.time()
                    step_id = 0
                    for data in test_reader():
                        step_id += 1
                        b_size = len([dat[0] for dat in data])
                        wa = np.array([dat[0] for dat in data]).astype(
                            "int64").reshape(b_size, 1)
                        wb = np.array([dat[1] for dat in data]).astype(
                            "int64").reshape(b_size, 1)
                        wc = np.array([dat[2] for dat in data]).astype(
                            "int64").reshape(b_size, 1)

                        label = [dat[3] for dat in data]
                        input_word = [dat[4] for dat in data]
                        para = exe.run(
                            copy_program,
                            feed={
                                "analogy_a": wa,
                                "analogy_b": wb,
                                "analogy_c": wc,
                                "all_label": np.arange(vocab_size).reshape(
                                    vocab_size, 1),
                            },
                            fetch_list=[pred.name, values],
                            return_numpy=False)
                        pre = np.array(para[0])
                        val = np.array(para[1])
                        for ii in range(len(label)):
                            top4 = pre[ii]
                            accum_num_sum += 1
                            for idx in top4:
                                if int(idx) in input_word[ii]:
                                    continue
                                if int(idx) == int(label[ii][0]):
                                    accum_num += 1
                                break
                        if step_id % 1 == 0:
                            print("step:%d %d " % (step_id, accum_num))
                    print("epoch:%d \t acc:%.3f " %
                          (epoch, 1.0 * accum_num / accum_num_sum))
                    t1 = time.time()


if __name__ == "__main__":
    args = parse_args()
    start_index = args.start_index
    last_index = args.last_index
    test_dir = args.test_dir
    model_dir = args.model_dir
    batch_size = args.batch_size
    dict_path = args.dict_path
    use_cuda = True if args.use_cuda else False
    print("start index: ", start_index, " last_index:", last_index)
    vocab_size, test_reader, id2word = utils.prepare_data(
        test_dir, dict_path, batch_size=batch_size)
    print("vocab_size:", vocab_size)
    if args.infer_step:
        infer_step(
            args,
            vocab_size,
            test_reader=test_reader,
            use_cuda=use_cuda,
            i2w=id2word)
    else:
        infer_epoch(
            args,
            vocab_size,
            test_reader=test_reader,
            use_cuda=use_cuda,
            i2w=id2word)
demo/quant/quant_embedding/net.py  (new file, mode 100755)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
neural network for word2vec
"""
from __future__ import print_function
import math
import numpy as np
import paddle.fluid as fluid


def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):

    datas = []
    input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
    true_word = fluid.layers.data(name='true_label', shape=[1], dtype='int64')
    neg_word = fluid.layers.data(
        name="neg_label", shape=[neg_num], dtype='int64')

    datas.append(input_word)
    datas.append(true_word)
    datas.append(neg_word)

    py_reader = fluid.layers.create_py_reader_by_data(
        capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True)

    words = fluid.layers.read_file(py_reader)
    init_width = 0.5 / embedding_size
    input_emb = fluid.layers.embedding(
        input=words[0],
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=fluid.ParamAttr(
            name='emb',
            initializer=fluid.initializer.Uniform(-init_width, init_width)))

    true_emb_w = fluid.layers.embedding(
        input=words[1],
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=fluid.ParamAttr(
            name='emb_w', initializer=fluid.initializer.Constant(value=0.0)))

    true_emb_b = fluid.layers.embedding(
        input=words[1],
        is_sparse=is_sparse,
        size=[dict_size, 1],
        param_attr=fluid.ParamAttr(
            name='emb_b', initializer=fluid.initializer.Constant(value=0.0)))

    neg_word_reshape = fluid.layers.reshape(words[2], shape=[-1, 1])
    neg_word_reshape.stop_gradient = True

    neg_emb_w = fluid.layers.embedding(
        input=neg_word_reshape,
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=fluid.ParamAttr(
            name='emb_w', learning_rate=1.0))

    neg_emb_w_re = fluid.layers.reshape(
        neg_emb_w, shape=[-1, neg_num, embedding_size])

    neg_emb_b = fluid.layers.embedding(
        input=neg_word_reshape,
        is_sparse=is_sparse,
        size=[dict_size, 1],
        param_attr=fluid.ParamAttr(
            name='emb_b', learning_rate=1.0))

    neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])

    true_logits = fluid.layers.elementwise_add(
        fluid.layers.reduce_sum(
            fluid.layers.elementwise_mul(input_emb, true_emb_w),
            dim=1,
            keep_dim=True),
        true_emb_b)

    input_emb_re = fluid.layers.reshape(
        input_emb, shape=[-1, 1, embedding_size])

    neg_matmul = fluid.layers.matmul(
        input_emb_re, neg_emb_w_re, transpose_y=True)
    neg_matmul_re = fluid.layers.reshape(neg_matmul, shape=[-1, neg_num])
    neg_logits = fluid.layers.elementwise_add(neg_matmul_re, neg_emb_b_vec)

    #nce loss
    label_ones = fluid.layers.fill_constant_batch_size_like(
        true_logits, shape=[-1, 1], value=1.0, dtype='float32')
    label_zeros = fluid.layers.fill_constant_batch_size_like(
        true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')

    true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits,
                                                               label_ones)
    neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits,
                                                              label_zeros)
    cost = fluid.layers.elementwise_add(
        fluid.layers.reduce_sum(true_xent, dim=1),
        fluid.layers.reduce_sum(neg_xent, dim=1))
    avg_cost = fluid.layers.reduce_mean(cost)

    return avg_cost, py_reader


def infer_network(vocab_size, emb_size):
    analogy_a = fluid.layers.data(name="analogy_a", shape=[1], dtype='int64')
    analogy_b = fluid.layers.data(name="analogy_b", shape=[1], dtype='int64')
    analogy_c = fluid.layers.data(name="analogy_c", shape=[1], dtype='int64')
    all_label = fluid.layers.data(
        name="all_label",
        shape=[vocab_size, 1],
        dtype='int64',
        append_batch_size=False)
    emb_all_label = fluid.layers.embedding(
        input=all_label, size=[vocab_size, emb_size], param_attr="emb")

    emb_a = fluid.layers.embedding(
        input=analogy_a, size=[vocab_size, emb_size], param_attr="emb")
    emb_b = fluid.layers.embedding(
        input=analogy_b, size=[vocab_size, emb_size], param_attr="emb")
    emb_c = fluid.layers.embedding(
        input=analogy_c, size=[vocab_size, emb_size], param_attr="emb")
    target = fluid.layers.elementwise_add(
        fluid.layers.elementwise_sub(emb_b, emb_a), emb_c)
    emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1)
    dist = fluid.layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
    values, pred_idx = fluid.layers.topk(input=dist, k=4)
    return values, pred_idx
demo/quant/quant_embedding/preprocess.py  (new file, mode 100755)

# -*- coding: utf-8 -*
import os
import random
import re
import six
import argparse
import io
import math

prog = re.compile("[^a-z ]", flags=0)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Paddle Fluid word2 vector preprocess")
    parser.add_argument(
        '--build_dict_corpus_dir', type=str, help="The dir of corpus")
    parser.add_argument(
        '--input_corpus_dir', type=str, help="The dir of input corpus")
    parser.add_argument(
        '--output_corpus_dir', type=str, help="The dir of output corpus")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./dict',
        help="The path of dictionary ")
    parser.add_argument(
        '--min_count',
        type=int,
        default=5,
        help="If the word count is less then min_count, it will be removed from dict")
    parser.add_argument(
        '--downsample',
        type=float,
        default=0.001,
        help="filter word by downsample")
    parser.add_argument(
        '--filter_corpus',
        action='store_true',
        default=False,
        help='Filter corpus')
    parser.add_argument(
        '--build_dict',
        action='store_true',
        default=False,
        help='Build dict from corpus')
    return parser.parse_args()


def text_strip(text):
    # English Preprocess Rule
    return prog.sub("", text.lower())


# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py
# Unicode utility functions that work with Python 2 and 3
def native_to_unicode(s):
    if _is_unicode(s):
        return s
    try:
        return _to_unicode(s)
    except UnicodeDecodeError:
        res = _to_unicode(s, ignore_errors=True)
        return res


def _is_unicode(s):
    if six.PY2:
        if isinstance(s, unicode):
            return True
    else:
        if isinstance(s, str):
            return True
    return False


def _to_unicode(s, ignore_errors=False):
    if _is_unicode(s):
        return s
    error_mode = "ignore" if ignore_errors else "strict"
    return s.decode("utf-8", errors=error_mode)


def filter_corpus(args):
    """
    filter corpus and convert id.
    """
    word_count = dict()
    word_to_id_ = dict()
    word_all_count = 0
    id_counts = []
    word_id = 0
    # read dict
    with io.open(args.dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            word, count = line.split()[0], int(line.split()[1])
            word_count[word] = count
            word_to_id_[word] = word_id
            word_id += 1
            id_counts.append(count)
            word_all_count += count
    # write word2id file
    print("write word2id file to : " + args.dict_path + "_word_to_id_")
    with io.open(
            args.dict_path + "_word_to_id_", 'w+', encoding='utf-8') as fid:
        for k, v in word_to_id_.items():
            fid.write(k + " " + str(v) + '\n')
    # filter corpus and convert id
    if not os.path.exists(args.output_corpus_dir):
        os.makedirs(args.output_corpus_dir)
    for file in os.listdir(args.input_corpus_dir):
        with io.open(args.output_corpus_dir + '/convert_' + file, "w") as wf:
            with io.open(
                    args.input_corpus_dir + '/' + file,
                    encoding='utf-8') as rf:
                print(args.input_corpus_dir + '/' + file)
                for line in rf:
                    signal = False
                    line = text_strip(line)
                    words = line.split()
                    for item in words:
                        if item in word_count:
                            idx = word_to_id_[item]
                        else:
                            idx = word_to_id_[native_to_unicode('<UNK>')]
                        count_w = id_counts[idx]
                        corpus_size = word_all_count
                        keep_prob = (
                            math.sqrt(count_w /
                                      (args.downsample * corpus_size)) + 1
                        ) * (args.downsample * corpus_size) / count_w
                        r_value = random.random()
                        if r_value > keep_prob:
                            continue
                        wf.write(_to_unicode(str(idx) + " "))
                        signal = True
                    if signal:
                        wf.write(_to_unicode("\n"))


def build_dict(args):
    """
    preprocess the data, generate dictionary and save into dict_path.
    :param corpus_dir: the input data dir.
    :param dict_path: the generated dict path. the data in dict is "word count"
    :param min_count:
    :return:
    """
    # word to count
    word_count = dict()

    for file in os.listdir(args.build_dict_corpus_dir):
        with io.open(
                args.build_dict_corpus_dir + "/" + file,
                encoding='utf-8') as f:
            print("build dict : ", args.build_dict_corpus_dir + "/" + file)
            for line in f:
                line = text_strip(line)
                words = line.split()
                for item in words:
                    if item in word_count:
                        word_count[item] = word_count[item] + 1
                    else:
                        word_count[item] = 1

    item_to_remove = []
    for item in word_count:
        if word_count[item] <= args.min_count:
            item_to_remove.append(item)

    unk_sum = 0
    for item in item_to_remove:
        unk_sum += word_count[item]
        del word_count[item]
    # sort by count
    word_count[native_to_unicode('<UNK>')] = unk_sum
    word_count = sorted(
        word_count.items(), key=lambda word_count: -word_count[1])

    with io.open(args.dict_path, 'w+', encoding='utf-8') as f:
        for k, v in word_count:
            f.write(k + " " + str(v) + '\n')


if __name__ == "__main__":
    args = parse_args()
    if args.build_dict:
        build_dict(args)
    elif args.filter_corpus:
        filter_corpus(args)
    else:
        print(
            "error command line, please choose --build_dict or --filter_corpus")
demo/quant/quant_embedding/reader.py  (new file, mode 100755)

# -*- coding: utf-8 -*
import numpy as np
import preprocess
import logging
import math
import random
import io

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


class NumpyRandomInt(object):
    def __init__(self, a, b, buf_size=1000):
        self.idx = 0
        self.buffer = np.random.random_integers(a, b, buf_size)
        self.a = a
        self.b = b

    def __call__(self):
        if self.idx == len(self.buffer):
            self.buffer = np.random.random_integers(self.a, self.b,
                                                    len(self.buffer))
            self.idx = 0

        result = self.buffer[self.idx]
        self.idx += 1
        return result


class Word2VecReader(object):
    def __init__(self,
                 dict_path,
                 data_path,
                 filelist,
                 trainer_id,
                 trainer_num,
                 window_size=5):
        self.window_size_ = window_size
        self.data_path_ = data_path
        self.filelist = filelist
        self.trainer_id = trainer_id
        self.trainer_num = trainer_num

        word_all_count = 0
        id_counts = []
        word_id = 0

        with io.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                word, count = line.split()[0], int(line.split()[1])
                word_id += 1
                id_counts.append(count)
                word_all_count += count

        self.word_all_count = word_all_count
        self.corpus_size_ = word_all_count
        self.dict_size = len(id_counts)
        self.id_counts_ = id_counts

        print("corpus_size:", self.corpus_size_)
        self.id_frequencys = [
            float(count) / word_all_count for count in self.id_counts_
        ]
        print("dict_size = " + str(self.dict_size) + " word_all_count = " +
              str(word_all_count))

        self.random_generator = NumpyRandomInt(1, self.window_size_ + 1)

    def get_context_words(self, words, idx):
        """
        Get the context word list of target word.
        words: the words of the current line
        idx: input word index
        window_size: window size
        """
        target_window = self.random_generator()
        start_point = idx - target_window  # if (idx - target_window) > 0 else 0
        if start_point < 0:
            start_point = 0
        end_point = idx + target_window
        targets = words[start_point:idx] + words[idx + 1:end_point + 1]
        return targets

    def train(self):
        def nce_reader():
            for file in self.filelist:
                with io.open(
                        self.data_path_ + "/" + file, 'r',
                        encoding='utf-8') as f:
                    logger.info("running data in {}".format(self.data_path_ +
                                                            "/" + file))
                    count = 1
                    for line in f:
                        if self.trainer_id == count % self.trainer_num:
                            word_ids = [int(w) for w in line.split()]
                            for idx, target_id in enumerate(word_ids):
                                context_word_ids = self.get_context_words(
                                    word_ids, idx)
                                for context_id in context_word_ids:
                                    yield [target_id], [context_id]
                        count += 1

        return nce_reader
demo/quant/quant_embedding/train.py  (new file, mode 100755)

from __future__ import print_function
import argparse
import logging
import os
import time
import math
import random
import numpy as np
import paddle
import paddle.fluid as fluid
import six
import reader
from net import skip_gram_word2vec

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Word2vec example")
    parser.add_argument(
        '--train_data_dir',
        type=str,
        default='./data/text',
        help="The path of training dataset")
    parser.add_argument(
        '--base_lr',
        type=float,
        default=0.01,
        help="The number of learning rate (default: 0.01)")
    parser.add_argument(
        '--save_step',
        type=int,
        default=500000,
        help="The number of step to save (default: 500000)")
    parser.add_argument(
        '--print_batch',
        type=int,
        default=10,
        help="The number of print_batch (default: 10)")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/1-billion_dict',
        help="The path of data dict")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=500,
        help="The size of mini-batch (default:500)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="The number of passes to train (default: 10)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for model to store (default: models)')
    parser.add_argument('--nce_num', type=int, default=5, help='nce_num')
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=64,
        help='sparse feature hashing space for index processing')
    parser.add_argument(
        '--is_sparse',
        action='store_true',
        required=False,
        default=False,
        help='embedding and nce will use sparse or not, (default: False)')
    parser.add_argument(
        '--with_speed',
        action='store_true',
        required=False,
        default=False,
        help='print speed or not , (default: False)')
    return parser.parse_args()


def convert_python_to_tensor(weight, batch_size, sample_reader):
    def __reader__():
        cs = np.array(weight).cumsum()
        result = [[], []]
        for sample in sample_reader():
            for i, fea in enumerate(sample):
                result[i].append(fea)
            if len(result[0]) == batch_size:
                tensor_result = []
                for tensor in result:
                    t = fluid.Tensor()
                    dat = np.array(tensor, dtype='int64')
                    if len(dat.shape) > 2:
                        dat = dat.reshape((dat.shape[0], dat.shape[2]))
                    elif len(dat.shape) == 1:
                        dat = dat.reshape((-1, 1))
                    t.set(dat, fluid.CPUPlace())
                    tensor_result.append(t)
                tt = fluid.Tensor()
                neg_array = cs.searchsorted(np.random.sample(args.nce_num))
                neg_array = np.tile(neg_array, batch_size)
                tt.set(
                    neg_array.reshape((batch_size, args.nce_num)),
                    fluid.CPUPlace())
                tensor_result.append(tt)
                yield tensor_result
                result = [[], []]

    return __reader__


def train_loop(args, train_program, reader, py_reader, loss, trainer_id,
               weight):
    py_reader.decorate_tensor_provider(
        convert_python_to_tensor(weight, args.batch_size, reader.train()))

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.use_experimental_executor = True

    print("CPU_NUM:" + str(os.getenv("CPU_NUM")))
    exec_strategy.num_threads = int(os.getenv("CPU_NUM"))

    build_strategy = fluid.BuildStrategy()
    if int(os.getenv("CPU_NUM")) > 1:
        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce

    train_exe = fluid.ParallelExecutor(
        use_cuda=False,
        loss_name=loss.name,
        main_program=train_program,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    for pass_id in range(args.num_passes):
        py_reader.start()
        time.sleep(10)
        epoch_start = time.time()
        batch_id = 0
        start = time.time()
        try:
            while True:
                loss_val = train_exe.run(fetch_list=[loss.name])
                loss_val = np.mean(loss_val)

                if batch_id % args.print_batch == 0:
                    logger.info(
                        "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}".
                        format(pass_id, batch_id,
                               loss_val.mean(), py_reader.queue.size()))
                if args.with_speed:
                    if batch_id % 500 == 0 and batch_id != 0:
                        elapsed = (time.time() - start)
                        start = time.time()
                        samples = 1001 * args.batch_size * int(
                            os.getenv("CPU_NUM"))
                        logger.info("Time used: {}, Samples/Sec: {}".format(
                            elapsed, samples / elapsed))

                if batch_id % args.save_step == 0 and batch_id != 0:
                    model_dir = args.model_output_dir + '/pass-' + str(
                        pass_id) + ('/batch-' + str(batch_id))
                    if trainer_id == 0:
                        fluid.io.save_params(executor=exe, dirname=model_dir)
                        print("model saved in %s" % model_dir)
                batch_id += 1

        except fluid.core.EOFException:
            py_reader.reset()
            epoch_end = time.time()
            logger.info("Epoch: {0}, Train total expend: {1} ".format(
                pass_id, epoch_end - epoch_start))
            model_dir = args.model_output_dir + '/pass-' + str(pass_id)
            if trainer_id == 0:
                fluid.io.save_params(executor=exe, dirname=model_dir)
                print("model saved in %s" % model_dir)


def GetFileList(data_path):
    return os.listdir(data_path)


def train(args):
    if not os.path.isdir(args.model_output_dir):
        os.mkdir(args.model_output_dir)

    filelist = GetFileList(args.train_data_dir)
    word2vec_reader = reader.Word2VecReader(args.dict_path,
                                            args.train_data_dir, filelist, 0, 1)

    logger.info("dict_size: {}".format(word2vec_reader.dict_size))
    np_power = np.power(np.array(word2vec_reader.id_frequencys), 0.75)
    id_frequencys_pow = np_power / np_power.sum()

    loss, py_reader = skip_gram_word2vec(
        word2vec_reader.dict_size,
        args.embedding_size,
        is_sparse=args.is_sparse,
        neg_num=args.nce_num)

    optimizer = fluid.optimizer.SGD(
        learning_rate=fluid.layers.exponential_decay(
            learning_rate=args.base_lr,
            decay_steps=100000,
            decay_rate=0.999,
            staircase=True))

    optimizer.minimize(loss)

    # do local training
    logger.info("run local training")
    main_program = fluid.default_main_program()
    train_loop(args, main_program, word2vec_reader, py_reader, loss, 0,
               id_frequencys_pow)


if __name__ == '__main__':
    args = parse_args()
    train(args)
demo/quant/quant_embedding/utils.py  (new file, mode 100755)

import sys
import collections
import six
import time
import numpy as np
import paddle.fluid as fluid
import paddle
import os
import preprocess


def BuildWord_IdMap(dict_path):
    word_to_id = dict()
    id_to_word = dict()
    with open(dict_path, 'r') as f:
        for line in f:
            word_to_id[line.split(' ')[0]] = int(line.split(' ')[1])
            id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
    return word_to_id, id_to_word


def prepare_data(file_dir, dict_path, batch_size):
    w2i, i2w = BuildWord_IdMap(dict_path)
    vocab_size = len(i2w)
    reader = paddle.batch(test(file_dir, w2i), batch_size)
    return vocab_size, reader, i2w


def native_to_unicode(s):
    if _is_unicode(s):
        return s
    try:
        return _to_unicode(s)
    except UnicodeDecodeError:
        res = _to_unicode(s, ignore_errors=True)
        return res


def _is_unicode(s):
    if six.PY2:
        if isinstance(s, unicode):
            return True
    else:
        if isinstance(s, str):
            return True
    return False


def _to_unicode(s, ignore_errors=False):
    if _is_unicode(s):
        return s
    error_mode = "ignore" if ignore_errors else "strict"
    return s.decode("utf-8", errors=error_mode)


def strip_lines(line, vocab):
    return _replace_oov(vocab, native_to_unicode(line))


def _replace_oov(original_vocab, line):
    """Replace out-of-vocab words with "<UNK>".
    This maintains compatibility with published results.
    Args:
        original_vocab: a set of strings (The standard vocabulary for the dataset)
        line: a unicode string - a space-delimited sequence of words.
    Returns:
        a unicode string - a space-delimited sequence of words.
    """
    return u" ".join([
        word if word in original_vocab else u"<UNK>" for word in line.split()
    ])


def reader_creator(file_dir, word_to_id):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            with open(file_dir + '/' + fi, "r") as f:
                for line in f:
                    if ':' in line:
                        pass
                    else:
                        line = strip_lines(line.lower(), word_to_id)
                        line = line.split()
                        yield [word_to_id[line[0]]], [word_to_id[line[1]]], [
                            word_to_id[line[2]]
                        ], [word_to_id[line[3]]], [
                            word_to_id[line[0]], word_to_id[line[1]],
                            word_to_id[line[2]]
                        ]

    return reader


def test(test_dir, w2i):
    return reader_creator(test_dir, w2i)