提交 be4ad5f4 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #60 from zhaopu7/develop

add neural language model.
TBD
# 语言模型
## 简介
语言模型即 Language Model,简称LM,它是一个概率分布模型,简单来说,就是用来计算一个句子的概率的模型。利用它可以确定哪个词序列的可能性更大,或者给定若干个词,可以预测下一个最可能出现的词语。它是自然语言处理领域的一个重要的基础模型。
## 应用场景
**语言模型被应用在很多领域**,如:
* **自动写作**:语言模型可以根据上文生成下一个词,递归下去可以生成整个句子、段落、篇章。
* **QA**:语言模型可以根据Question生成Answer。
* **机器翻译**:当前主流的机器翻译模型大多基于Encoder-Decoder模式,其中Decoder就是一个语言模型,用来生成目标语言。
* **拼写检查**:语言模型可以计算出词语序列的概率,一般在拼写错误处序列的概率会骤减,可以用来识别拼写错误并提供改正候选集。
* **词性标注、句法分析、语音识别......**
## 关于本例
Language Model 常见的实现方式有 N-Gram、RNN、seq2seq。本例中实现了基于N-Gram、RNN的语言模型。**本例的文件结构如下**`images` 文件夹与使用无关可不关心):
```text
.
├── data # toy、demo数据,用户可据此格式化自己的数据
│ ├── chinese.test.txt # test用的数据demo
| ├── chinese.train.txt # train用的数据demo
│ └── input.txt # infer用的输入数据demo
├── config.py # 配置文件,包括data、train、infer相关配置
├── infer.py # 预测任务脚本,即生成文本
├── network_conf.py # 本例中涉及的各种网络结构均定义在此文件中,希望进一步修改模型结构,请修改此文件
├── reader.py # 读取数据接口
├── README.md # 文档
├── train.py # 训练任务脚本
└── utils.py # 定义通用的函数,例如:构建字典、加载字典等
```
**注:一般情况下基于N-Gram的语言模型不如基于RNN的语言模型效果好,所以实际使用时建议使用基于RNN的语言模型,本例中也将着重介绍基于RNN的模型,简略介绍基于N-Gram的模型。**
## RNN 语言模型
### 简介
RNN是一个序列模型,基本思路是:在时刻t,将前一时刻t-1的隐藏层输出和t时刻的词向量一起输入到隐藏层从而得到时刻t的特征表示,然后用这个特征表示得到t时刻的预测输出,如此在时间维上递归下去,如下图所示:
<p align=center><img src='images/rnn_str.png' width='500px'/></p>
可以看出RNN善于使用上文信息、历史知识,具有“记忆”功能。理论上RNN能实现“长依赖”(即利用很久之前的知识),但在实际应用中发现效果并不理想,于是出现了很多RNN的变种,如常用的LSTM和GRU,它们对传统RNN的cell进行了改进,弥补了RNN的不足,下图是LSTM的示意图:
<p align=center><img src='images/lstm.png' width='500px'/></p>
本例中即使用了LSTM、GRU。
### 模型实现
本例中RNN语言模型的实现简介如下:
* **定义模型参数**`config.py`中的`Config_rnn`**类**中定义了模型的参数变量。
* **定义模型结构**`network_conf.py`中的`rnn_lm`**函数**中定义了模型的**结构**,如下:
* 输入层:将输入的词(或字)序列映射成向量,即embedding。
* 中间层:根据配置实现RNN层,将上一步得到的embedding向量序列作为输入。
* 输出层:使用softmax归一化计算单词的概率,将output结果返回
* loss:定义模型的cost为多类交叉熵损失函数。
* **训练模型**`train.py`中的`main`方法实现了模型的训练,实现流程如下:
* 准备输入数据:建立并保存词典、构建train和test数据的reader。
* 初始化模型:包括模型的结构、参数。
* 构建训练器:demo中使用的是Adam优化算法。
* 定义回调函数:构建`event_handler`来跟踪训练过程中loss的变化,并在每轮时结束保存模型的参数。
* 训练:使用trainer训练模型。
* **生成文本**`infer.py`中的`main`方法实现了文本的生成,实现流程如下:
* 根据配置选择生成方法:RNN模型 or N-Gram模型。
* 加载train好的模型和词典文件。
* 读取`input_file`文件(每行为一个sentence的前缀),用启发式图搜索算法`beam_search`根据各sentence的前缀生成文本。
* 将生成的文本及其前缀保存到文件`output_file`
## N-Gram 语言模型
### 简介
N-Gram模型也称为N-1阶马尔科夫模型,它有一个有限历史假设:当前词的出现概率仅仅与前面N-1个词相关。一般采用最大似然估计(Maximum Likelihood Estimation,MLE)的方法对模型的参数进行估计。当N取1、2、3时,N-Gram模型分别称为unigram、bigram和trigram语言模型。一般情况下,N越大、训练语料的规模越大,参数估计的结果越可靠,但由于模型较简单、表达能力不强以及数据稀疏等问题。一般情况下用N-Gram实现的语言模型不如RNN、seq2seq效果好。
### 模型实现
本例中N-Gram语言模型的实现简介如下:
* **定义模型参数**`config.py`中的`Config_ngram`**类**中定义了模型的参数变量。
* **定义模型结构**`network_conf.py`中的`ngram_lm`**函数**中定义了模型的**结构**,如下:
* 输入层:本例中N取5,将前四个词分别做embedding,然后连接起来作为输入。
* 中间层:根据配置实现DNN层,将上一步得到的embedding向量序列作为输入。
* 输出层:使用softmax归一化计算单词的概率,将output结果返回
* loss:定义模型的cost为多类交叉熵损失函数。
* **训练模型**`train.py`中的`main`方法实现了模型的训练,实现流程与上文中RNN语言模型基本一致。
* **生成文本**`infer.py`中的`main`方法实现了文本的生成,实现流程与上文中RNN语言模型基本一致,区别在于构建input时本例会取每个前缀的最后4(N-1)个词作为输入。
## 使用说明
运行本例的方法如下:
* 1,运行`python train.py`命令,开始train模型(默认使用RNN),待训练结束。
* 2,运行`python infer.py`命令做prediction。(输入的文本默认为`data/input.txt`,生成的文本默认保存到`data/output.txt`中。)
**如果用户需要使用自己的语料、定制模型,需要修改的地方主要是`语料`和`config.py`中的配置,需要注意的细节和适配工作详情如下:**
### 语料适配
* 清洗语料:去除原文中空格、tab、乱码,按需去除数字、标点符号、特殊符号等。
* 编码格式:utf-8,本例中已经对中文做了适配。
* 内容格式:每个句子占一行;每行中的各词之间使用一个空格符分开。
* 按需要配置`config.py`中对于data的配置:
```python
# -- config : data --
train_file = 'data/chinese.train.txt'
test_file = 'data/chinese.test.txt'
vocab_file = 'data/vocab_cn.txt' # the file to save vocab
build_vocab_method = 'fixed_size' # 'frequency' or 'fixed_size'
vocab_max_size = 3000 # when build_vocab_method = 'fixed_size'
unk_threshold = 1 # # when build_vocab_method = 'frequency'
min_sentence_length = 3
max_sentence_length = 60
```
其中,`build_vocab_method `指定了构建词典的方法:**1,按词频**,即将出现次数小于`unk_threshold `的词视为`<UNK>`;**2,按词典长度**,`vocab_max_size`定义了词典的最大长度,如果语料中出现的不同词的个数大于这个值,则根据各词的词频倒序排,取`top(vocab_max_size)`个词纳入词典。
其中`min_sentence_length`和`max_sentence_length `分别指定了句子的最小和最大长度,小于最小长度的和大于最大长度的句子将被过滤掉、不参与训练。
*注:需要注意的是词典越大生成的内容越丰富但训练耗时越久,一般中文分词之后,语料中不同的词能有几万乃至几十万,如果vocab\_max\_size取值过小则导致\<UNK\>占比过高,如果vocab\_max\_size取值较大则严重影响训练速度(对精度也有影响),所以也有“按字”训练模型的方式,即:把每个汉字当做一个词,常用汉字也就几千个,使得字典的大小不会太大、不会丢失太多信息,但汉语中同一个字在不同词中语义相差很大,有时导致模型效果不理想。建议用户多试试、根据实际情况选择是“按词训练”还是“按字训练”。*
### 模型适配、训练
* 按需调整`config.py`中对于模型的配置,详解如下:
```python
# -- config : train --
use_which_model = 'rnn' # must be: 'rnn' or 'ngram'
use_gpu = False # whether to use gpu
trainer_count = 1 # number of trainer
class Config_rnn(object):
"""
config for RNN language model
"""
rnn_type = 'gru' # or 'lstm'
emb_dim = 200
hidden_size = 200
num_layer = 2
num_passs = 2
batch_size = 32
model_file_name_prefix = 'lm_' + rnn_type + '_params_pass_'
class Config_ngram(object):
"""
config for N-Gram language model
"""
emb_dim = 200
hidden_size = 200
num_layer = 2
N = 5
num_passs = 2
batch_size = 32
model_file_name_prefix = 'lm_ngram_pass_'
```
其中,`use_which_model`指定了要train的模型,如果使用RNN语言模型则设置为'rnn',如果使用N-Gram语言模型则设置为'ngram';`use_gpu`指定了train的时候是否使用gpu;`trainer_count`指定了并行度、用几个trainer去train模型;`rnn_type` 用于配置rnn cell类型,可以取‘lstm’或‘gru’;`hidden_size`配置unit个数;`num_layer`配置RNN的层数;`num_passs`配置训练的轮数;`emb_dim`配置embedding的dimension;`batch_size `配置了train model时每个batch的大小;`model_file_name_prefix `配置了要保存的模型的名字前缀。
* 运行`python train.py`命令训练模型,模型将被保存到当前目录。
### 按需生成文本
* 按需调整`config.py`中对于infer的配置,详解如下:
```python
# -- config : infer --
input_file = 'data/input.txt' # input file contains sentence prefix each line
output_file = 'data/output.txt' # the file to save results
num_words = 10 # the max number of words need to generate
beam_size = 5 # beam_width, the number of the prediction sentence for each prefix
```
其中,`input_file`中保存的是带生成的文本前缀,utf-8编码,每个前缀占一行,形如:
```text
我 是
```
用户将需要生成的文本前缀按此格式存入文件即可;
`num_words`指定了要生成多少个单词(实际生成过程中遇到结束符会停止生成,所以实际生成的词个数可能会比此值小);`beam_size`指定了beam search方法的width,即每个前缀生成多少个候选词序列;`output_file`指定了生成结果的存放位置。
* 运行`python infer.py`命令生成文本,生成的结果格式如下:
```text
我 <EOS> 0.107702672482
我 爱 。我 中国 中国 <EOS> 0.000177299271939
我 爱 中国 。我 是 中国 <EOS> 4.51695544709e-05
我 爱 中国 中国 <EOS> 0.000910127729821
我 爱 中国 。我 是 <EOS> 0.00015957862922
```
其中,‘我’是前缀,其下方的五个句子时补全的结果,每个句子末尾的浮点数表示此句子的生成概率。
# coding=utf-8
# -- config : data --
train_file = 'data/chinese.train.txt'
test_file = 'data/chinese.test.txt'
vocab_file = 'data/vocab_cn.txt' # the file to save vocab
build_vocab_method = 'fixed_size' # 'frequency' or 'fixed_size'
vocab_max_size = 3000 # when build_vocab_method = 'fixed_size'
unk_threshold = 1 # # when build_vocab_method = 'frequency'
min_sentence_length = 3
max_sentence_length = 60
# -- config : train --
use_which_model = 'ngram' # must be: 'rnn' or 'ngram'
use_gpu = False # whether to use gpu
trainer_count = 1 # number of trainer
class Config_rnn(object):
"""
config for RNN language model
"""
rnn_type = 'gru' # or 'lstm'
emb_dim = 200
hidden_size = 200
num_layer = 2
num_passs = 2
batch_size = 32
model_file_name_prefix = 'lm_' + rnn_type + '_params_pass_'
class Config_ngram(object):
"""
config for N-Gram language model
"""
emb_dim = 200
hidden_size = 200
num_layer = 2
N = 5
num_passs = 2
batch_size = 32
model_file_name_prefix = 'lm_ngram_pass_'
# -- config : infer --
input_file = 'data/input.txt' # input file contains sentence prefix each line
output_file = 'data/output.txt' # the file to save results
num_words = 10 # the max number of words need to generate
beam_size = 5 # beam_width, the number of the prediction sentence for each prefix
我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。
\ No newline at end of file
我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。我 是 中国 人 。
我 爱 中国 。
\ No newline at end of file
我 是
我 是 中国
我 爱
我 是 中国 人。
我 爱 中国
我 爱 中国 。我
我 爱 中国 。我 爱
我 爱 中国 。我 是
我 爱 中国 。我 是 中国
\ No newline at end of file
# coding=utf-8
import paddle.v2 as paddle
import gzip
import numpy as np
from utils import *
import network_conf
from config import *
def generate_using_rnn(word_id_dict, num_words, beam_size):
"""
Demo: use RNN model to do prediction.
:param word_id_dict: vocab.
:type word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
:param num_words: the number of the words to generate.
:type num_words: int
:param beam_size: beam width.
:type beam_size: int
:return: save prediction results to output_file
"""
# prepare and cache model
config = Config_rnn()
_, output_layer = network_conf.rnn_lm(
vocab_size=len(word_id_dict),
emb_dim=config.emb_dim,
rnn_type=config.rnn_type,
hidden_size=config.hidden_size,
num_layer=config.num_layer) # network config
model_file_name = config.model_file_name_prefix + str(config.num_passs -
1) + '.tar.gz'
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(model_file_name)) # load parameters
inferer = paddle.inference.Inference(
output_layer=output_layer, parameters=parameters)
# tools, different from generate_using_ngram's tools
id_word_dict = dict(
[(v, k) for k, v in word_id_dict.items()]) # {id : word}
def str2ids(str):
return [[[
word_id_dict.get(w, word_id_dict['<UNK>']) for w in str.split()
]]]
def ids2str(ids):
return [[[id_word_dict.get(id, ' ') for id in ids]]]
# generate text
with open(input_file) as file:
output_f = open(output_file, 'w')
for line in file:
line = line.decode('utf-8').strip()
# generate
texts = {} # type: {text : probability}
texts[line] = 1
for _ in range(num_words):
texts_new = {}
for (text, prob) in texts.items():
if '<EOS>' in text: # stop prediction when <EOS> appear
texts_new[text] = prob
continue
# next word's probability distribution
predictions = inferer.infer(input=str2ids(text))
predictions[-1][word_id_dict['<UNK>']] = -1 # filter <UNK>
# find next beam_size words
for _ in range(beam_size):
cur_maxProb_index = np.argmax(
predictions[-1]) # next word's id
text_new = text + ' ' + id_word_dict[
cur_maxProb_index] # text append next word
texts_new[text_new] = texts[text] * predictions[-1][
cur_maxProb_index]
predictions[-1][cur_maxProb_index] = -1
texts.clear()
if len(texts_new) <= beam_size:
texts = texts_new
else: # cutting
texts = dict(
sorted(
texts_new.items(), key=lambda d: d[1], reverse=True)
[:beam_size])
# save results to output file
output_f.write(line.encode('utf-8') + '\n')
for (sentence, prob) in texts.items():
output_f.write('\t' + sentence.encode('utf-8', 'replace') + '\t'
+ str(prob) + '\n')
output_f.write('\n')
output_f.close()
print('already saved results to ' + output_file)
def generate_using_ngram(word_id_dict, num_words, beam_size):
"""
Demo: use N-Gram model to do prediction.
:param word_id_dict: vocab.
:type word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
:param num_words: the number of the words to generate.
:type num_words: int
:param beam_size: beam width.
:type beam_size: int
:return: save prediction results to output_file
"""
# prepare and cache model
config = Config_ngram()
_, output_layer = network_conf.ngram_lm(
vocab_size=len(word_id_dict),
emb_dim=config.emb_dim,
hidden_size=config.hidden_size,
num_layer=config.num_layer) # network config
model_file_name = config.model_file_name_prefix + str(config.num_passs -
1) + '.tar.gz'
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(model_file_name)) # load parameters
inferer = paddle.inference.Inference(
output_layer=output_layer, parameters=parameters)
# tools, different from generate_using_rnn's tools
id_word_dict = dict(
[(v, k) for k, v in word_id_dict.items()]) # {id : word}
def str2ids(str):
return [[
word_id_dict.get(w, word_id_dict['<UNK>']) for w in str.split()
]]
def ids2str(ids):
return [[id_word_dict.get(id, ' ') for id in ids]]
# generate text
with open(input_file) as file:
output_f = open(output_file, 'w')
for line in file:
line = line.decode('utf-8').strip()
words = line.split()
if len(words) < config.N:
output_f.write(line.encode('utf-8') + "\n\tnone\n")
continue
# generate
texts = {} # type: {text : probability}
texts[line] = 1
for _ in range(num_words):
texts_new = {}
for (text, prob) in texts.items():
if '<EOS>' in text: # stop prediction when <EOS> appear
texts_new[text] = prob
continue
# next word's probability distribution
predictions = inferer.infer(
input=str2ids(' '.join(text.split()[-config.N:])))
predictions[-1][word_id_dict['<UNK>']] = -1 # filter <UNK>
# find next beam_size words
for _ in range(beam_size):
cur_maxProb_index = np.argmax(
predictions[-1]) # next word's id
text_new = text + ' ' + id_word_dict[
cur_maxProb_index] # text append nextWord
texts_new[text_new] = texts[text] * predictions[-1][
cur_maxProb_index]
predictions[-1][cur_maxProb_index] = -1
texts.clear()
if len(texts_new) <= beam_size:
texts = texts_new
else: # cutting
texts = dict(
sorted(
texts_new.items(), key=lambda d: d[1], reverse=True)
[:beam_size])
# save results to output file
output_f.write(line.encode('utf-8') + '\n')
for (sentence, prob) in texts.items():
output_f.write('\t' + sentence.encode('utf-8', 'replace') + '\t'
+ str(prob) + '\n')
output_f.write('\n')
output_f.close()
print('already saved results to ' + output_file)
def main():
# init paddle
paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
# prepare and cache vocab
if os.path.isfile(vocab_file):
word_id_dict = load_vocab(vocab_file) # load word dictionary
else:
if build_vocab_method == 'fixed_size':
word_id_dict = build_vocab_with_fixed_size(
train_file, vocab_max_size) # build vocab
else:
word_id_dict = build_vocab_using_threshhold(
train_file, unk_threshold) # build vocab
save_vocab(word_id_dict, vocab_file) # save vocab
# generate
if use_which_model == 'rnn':
generate_using_rnn(
word_id_dict=word_id_dict, num_words=num_words, beam_size=beam_size)
elif use_which_model == 'ngram':
generate_using_ngram(
word_id_dict=word_id_dict, num_words=num_words, beam_size=beam_size)
else:
raise Exception('use_which_model must be rnn or ngram!')
if __name__ == "__main__":
main()
# coding=utf-8
import paddle.v2 as paddle
def rnn_lm(vocab_size, emb_dim, rnn_type, hidden_size, num_layer):
"""
RNN language model definition.
:param vocab_size: size of vocab.
:param emb_dim: embedding vector's dimension.
:param rnn_type: the type of RNN cell.
:param hidden_size: number of unit.
:param num_layer: layer number.
:return: cost and output layer of model.
"""
assert emb_dim > 0 and hidden_size > 0 and vocab_size > 0 and num_layer > 0
# input layers
input = paddle.layer.data(
name="input", type=paddle.data_type.integer_value_sequence(vocab_size))
target = paddle.layer.data(
name="target", type=paddle.data_type.integer_value_sequence(vocab_size))
# embedding layer
input_emb = paddle.layer.embedding(input=input, size=emb_dim)
# rnn layer
if rnn_type == 'lstm':
rnn_cell = paddle.networks.simple_lstm(
input=input_emb, size=hidden_size)
for _ in range(num_layer - 1):
rnn_cell = paddle.networks.simple_lstm(
input=rnn_cell, size=hidden_size)
elif rnn_type == 'gru':
rnn_cell = paddle.networks.simple_gru(input=input_emb, size=hidden_size)
for _ in range(num_layer - 1):
rnn_cell = paddle.networks.simple_gru(
input=rnn_cell, size=hidden_size)
else:
raise Exception('rnn_type error!')
# fc(full connected) and output layer
output = paddle.layer.fc(
input=[rnn_cell], size=vocab_size, act=paddle.activation.Softmax())
# loss
cost = paddle.layer.classification_cost(input=output, label=target)
return cost, output
def ngram_lm(vocab_size, emb_dim, hidden_size, num_layer):
"""
N-Gram language model definition.
:param vocab_size: size of vocab.
:param emb_dim: embedding vector's dimension.
:param hidden_size: size of unit.
:param num_layer: layer number.
:return: cost and output layer of model.
"""
assert emb_dim > 0 and hidden_size > 0 and vocab_size > 0 and num_layer > 0
def wordemb(inlayer):
wordemb = paddle.layer.table_projection(
input=inlayer,
size=emb_dim,
param_attr=paddle.attr.Param(
name="_proj", initial_std=0.001, learning_rate=1, l2_rate=0))
return wordemb
# input layers
first_word = paddle.layer.data(
name="first_word", type=paddle.data_type.integer_value(vocab_size))
second_word = paddle.layer.data(
name="second_word", type=paddle.data_type.integer_value(vocab_size))
third_word = paddle.layer.data(
name="third_word", type=paddle.data_type.integer_value(vocab_size))
fourth_word = paddle.layer.data(
name="fourth_word", type=paddle.data_type.integer_value(vocab_size))
next_word = paddle.layer.data(
name="next_word", type=paddle.data_type.integer_value(vocab_size))
# embedding layer
first_emb = wordemb(first_word)
second_emb = wordemb(second_word)
third_emb = wordemb(third_word)
fourth_emb = wordemb(fourth_word)
context_emb = paddle.layer.concat(
input=[first_emb, second_emb, third_emb, fourth_emb])
# hidden layer
hidden = paddle.layer.fc(
input=context_emb, size=hidden_size, act=paddle.activation.Relu())
for _ in range(num_layer - 1):
hidden = paddle.layer.fc(
input=hidden, size=hidden_size, act=paddle.activation.Relu())
# fc(full connected) and output layer
predict_word = paddle.layer.fc(
input=[hidden], size=vocab_size, act=paddle.activation.Softmax())
# loss
cost = paddle.layer.classification_cost(input=predict_word, label=next_word)
return cost, predict_word
# coding=utf-8
import collections
import os
def rnn_reader(file_name, min_sentence_length, max_sentence_length,
word_id_dict):
"""
create reader for RNN, each line is a sample.
:param file_name: file name.
:param min_sentence_length: sentence's min length.
:param max_sentence_length: sentence's max length.
:param word_id_dict: vocab with content of '{word, id}', 'word' is string type , 'id' is int type.
:return: data reader.
"""
def reader():
UNK = word_id_dict['<UNK>']
with open(file_name) as file:
for line in file:
words = line.decode('utf-8', 'ignore').strip().split()
if len(words) < min_sentence_length or len(
words) > max_sentence_length:
continue
ids = [word_id_dict.get(w, UNK) for w in words]
ids.append(word_id_dict['<EOS>'])
target = ids[1:]
target.append(word_id_dict['<EOS>'])
yield ids[:], target[:]
return reader
def ngram_reader(file_name, N, word_id_dict):
"""
create reader for N-Gram.
:param file_name: file name.
:param N: N-Gram's N.
:param word_id_dict: vocab with content of '{word, id}', 'word' is string type , 'id' is int type.
:return: data reader.
"""
assert N >= 2
def reader():
ids = []
UNK_ID = word_id_dict['<UNK>']
cache_size = 10000000
with open(file_name) as file:
for line in file:
words = line.decode('utf-8', 'ignore').strip().split()
ids += [word_id_dict.get(w, UNK_ID) for w in words]
ids_len = len(ids)
if ids_len > cache_size: # output
for i in range(ids_len - N - 1):
yield tuple(ids[i:i + N])
ids = []
ids_len = len(ids)
for i in range(ids_len - N - 1):
yield tuple(ids[i:i + N])
return reader
# coding=utf-8
import sys
import paddle.v2 as paddle
import reader
from utils import *
import network_conf
import gzip
from config import *
def train(model_cost, train_reader, test_reader, model_file_name_prefix,
num_passes):
"""
train model.
:param model_cost: cost layer of the model to train.
:param train_reader: train data reader.
:param test_reader: test data reader.
:param model_file_name_prefix: model's prefix name.
:param num_passes: epoch.
:return:
"""
# init paddle
paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
# create parameters
parameters = paddle.parameters.create(model_cost)
# create optimizer
adam_optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
regularization=paddle.optimizer.L2Regularization(rate=1e-3),
model_average=paddle.optimizer.ModelAverage(
average_window=0.5, max_average_window=10000))
# create trainer
trainer = paddle.trainer.SGD(
cost=model_cost, parameters=parameters, update_equation=adam_optimizer)
# define event_handler callback
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print("\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics))
else:
sys.stdout.write('.')
sys.stdout.flush()
# save model each pass
if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=test_reader)
print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
with gzip.open(
model_file_name_prefix + str(event.pass_id) + '.tar.gz',
'w') as f:
parameters.to_tar(f)
# start to train
print('start training...')
trainer.train(
reader=train_reader, event_handler=event_handler, num_passes=num_passes)
print("Training finished.")
def main():
# prepare vocab
print('prepare vocab...')
if build_vocab_method == 'fixed_size':
word_id_dict = build_vocab_with_fixed_size(
train_file, vocab_max_size) # build vocab
else:
word_id_dict = build_vocab_using_threshhold(
train_file, unk_threshold) # build vocab
save_vocab(word_id_dict, vocab_file) # save vocab
# init model and data reader
if use_which_model == 'rnn':
# init RNN model
print('prepare rnn model...')
config = Config_rnn()
cost, _ = network_conf.rnn_lm(
len(word_id_dict), config.emb_dim, config.rnn_type,
config.hidden_size, config.num_layer)
# init RNN data reader
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.rnn_reader(train_file, min_sentence_length,
max_sentence_length, word_id_dict),
buf_size=65536),
batch_size=config.batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.rnn_reader(test_file, min_sentence_length,
max_sentence_length, word_id_dict),
buf_size=65536),
batch_size=config.batch_size)
elif use_which_model == 'ngram':
# init N-Gram model
print('prepare ngram model...')
config = Config_ngram()
assert config.N == 5
cost, _ = network_conf.ngram_lm(
vocab_size=len(word_id_dict),
emb_dim=config.emb_dim,
hidden_size=config.hidden_size,
num_layer=config.num_layer)
# init N-Gram data reader
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.ngram_reader(train_file, config.N, word_id_dict),
buf_size=65536),
batch_size=config.batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.ngram_reader(test_file, config.N, word_id_dict),
buf_size=65536),
batch_size=config.batch_size)
else:
raise Exception('use_which_model must be rnn or ngram!')
# train model
train(
model_cost=cost,
train_reader=train_reader,
test_reader=test_reader,
model_file_name_prefix=config.model_file_name_prefix,
num_passes=config.num_passs)
if __name__ == "__main__":
main()
# coding=utf-8
import os
import collections
def save_vocab(word_id_dict, vocab_file_name):
"""
save vocab.
:param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
:param vocab_file_name: vocab file name.
"""
f = open(vocab_file_name, 'w')
for (k, v) in word_id_dict.items():
f.write(k.encode('utf-8') + '\t' + str(v) + '\n')
print('save vocab to ' + vocab_file_name)
f.close()
def load_vocab(vocab_file_name):
"""
load vocab from file.
:param vocab_file_name: vocab file name.
:return: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
"""
assert os.path.isfile(vocab_file_name)
dict = {}
with open(vocab_file_name) as file:
for line in file:
if len(line) < 2:
continue
kv = line.decode('utf-8').strip().split('\t')
dict[kv[0]] = int(kv[1])
return dict
def build_vocab_using_threshhold(file_name, unk_threshold):
"""
build vacab using_<UNK> threshhold.
:param file_name:
:param unk_threshold: <UNK> threshhold.
:type unk_threshold: int.
:return: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
"""
counter = {}
with open(file_name) as file:
for line in file:
words = line.decode('utf-8', 'ignore').strip().split()
for word in words:
if word in counter:
counter[word] += 1
else:
counter[word] = 1
counter_new = {}
for (word, frequency) in counter.items():
if frequency >= unk_threshold:
counter_new[word] = frequency
counter.clear()
counter_new = sorted(counter_new.items(), key=lambda d: -d[1])
words = [word_frequency[0] for word_frequency in counter_new]
word_id_dict = dict(zip(words, range(2, len(words) + 2)))
word_id_dict['<UNK>'] = 0
word_id_dict['<EOS>'] = 1
return word_id_dict
def build_vocab_with_fixed_size(file_name, vocab_max_size):
"""
build vacab with assigned max size.
:param vocab_max_size: vocab's max size.
:return: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
"""
words = []
for line in open(file_name):
words += line.decode('utf-8', 'ignore').strip().split()
counter = collections.Counter(words)
counter = sorted(counter.items(), key=lambda x: -x[1])
if len(counter) > vocab_max_size:
counter = counter[:vocab_max_size]
words, counts = zip(*counter)
word_id_dict = dict(zip(words, range(2, len(words) + 2)))
word_id_dict['<UNK>'] = 0
word_id_dict['<EOS>'] = 1
return word_id_dict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册