README.md 4.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# 基于skip-gram的word2vector模型

以下是本例的简要目录结构及说明:

```text
.
├── train.py            # 训练函数
├── infer.py            # 预测脚本
├── net.py              # 网络结构
├── preprocess.py       # 预处理脚本,包括构建词典和预处理文本
├── reader.py           # 训练阶段的文本读写
├── README.md           # 使用说明
├── train.py            # 训练函数
└── utils.py            # 通用函数

```

## 介绍
本例实现了skip-gram模式的word2vector模型。

1
123malin 已提交
21
**目前模型库下模型均要求使用PaddlePaddle 1.6及以上版本或适当的develop版本。若要使用shuffle_batch功能,则需使用PaddlePaddle 1.7及以上版本。**
Z
zhang wenhui 已提交
22

L
Li Fuchen 已提交
23
同时推荐用户参考[ IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124377)
24 25 26 27 28

## 数据下载
全量数据集使用的是来自1 Billion Word Language Model Benchmark的(http://www.statmt.org/lm-benchmark) 的数据集.

```bash
Z
zhangwenhui03 已提交
29
mkdir data
30
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
Z
zhang wenhui 已提交
31
tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
32 33 34 35 36 37
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
```

备用数据地址下载命令如下

```bash
Z
zhangwenhui03 已提交
38
mkdir data
Z
zhang wenhui 已提交
39
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
40 41 42 43 44 45 46
tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
```

为了方便快速验证,我们也提供了经典的text8样例数据集,包含1700w个词。 下载命令如下

```bash
Z
zhangwenhui03 已提交
47
mkdir data
Z
zhang wenhui 已提交
48
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/text.tar
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
tar xvf text.tar
mv text data/
```


## 数据预处理
以样例数据集为例进行预处理。全量数据集注意解压后以training-monolingual.tokenized.shuffled 目录为预处理目录,和样例数据集的text目录并列。

词典格式: 词<空格>词频。注意低频词用'UNK'表示

可以按格式自建词典,如果自建词典跳过第一步。
```
the 1061396
of 593677
and 416629
one 411764
in 372201
a 325873
<UNK> 324608
to 316376
zero 264975
nine 250430
```

第一步根据英文语料生成词典,中文语料可以通过修改text_strip方法自定义处理方法。

```bash
python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
```

Z
zhangwenhui03 已提交
79
第二步根据词典将文本转成id, 同时进行downsample,按照概率过滤常见词, 同时生成word和id映射的文件,文件名为词典+"_word_to_id_"。
80 81

```bash
Z
zhangwenhui03 已提交
82
python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count 5 --downsample 0.001
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
```

## 训练
具体的参数配置可运行


```bash
python train.py -h
```

单机多线程训练
```bash
OPENBLAS_NUM_THREADS=1 CPU_NUM=5 python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes 10 --batch_size 100 --model_output_dir v1_cpu5_b100_lr1dir --base_lr 1.0 --print_batch 1000 --with_speed --is_sparse
```

1
123malin 已提交
98
若需要开启shuffle_batch功能,需在命令中加入`--with_shuffle_batch`。单机模拟分布式多机训练,需更改`cluster_train.sh`文件,在各个节点的启动命令中加入`--with_shuffle_batch`
99 100 101 102 103 104

## 预测
测试集下载命令如下

```bash
#全量数据集测试集
Z
zhang wenhui 已提交
105
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
106
#样本数据集测试集
Z
zhang wenhui 已提交
107
wget --no-check-certificate https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
108 109
```

Z
zhangwenhui03 已提交
110
预测命令,注意词典名称需要加后缀"_word_to_id_", 此文件是预处理阶段生成的。
111
```bash
Z
zhangwenhui03 已提交
112
python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size 20000 --model_dir v1_cpu5_b100_lr1dir/  --start_index 0 --last_index 10
113
```