PaddlePaddle / models
Commit e9a0aa86
Authored on June 26, 2017 by caoying03

follow comments and rename the directory.

Parent: 6b0f946d

Showing 13 changed files with 52 additions and 17 deletions (+52 -17)
generate_sequence_by_rnn_lm/.gitignore                      +3  -0
generate_sequence_by_rnn_lm/README.md                       +0  -0
generate_sequence_by_rnn_lm/beam_search.py                  +7  -7
generate_sequence_by_rnn_lm/config.py                       +0  -0
generate_sequence_by_rnn_lm/data/train_data_examples.txt    +0  -0
generate_sequence_by_rnn_lm/generate.py                     +0  -0
generate_sequence_by_rnn_lm/images/ngram.png                +0  -0
generate_sequence_by_rnn_lm/images/rnn.png                  +0  -0
generate_sequence_by_rnn_lm/index.html                      +0  -0
generate_sequence_by_rnn_lm/network_conf.py                 +10 -4
generate_sequence_by_rnn_lm/reader.py                       +0  -0
generate_sequence_by_rnn_lm/train.py                        +8  -4
generate_sequence_by_rnn_lm/utils.py                        +24 -2
generate_sequence_by_rnn_lm/.gitignore (new file, mode 0 → 100644)
+*.pyc
+*.tar.gz
+models
language_model/README.md → generate_sequence_by_rnn_lm/README.md
File moved
language_model/beam_search.py → generate_sequence_by_rnn_lm/beam_search.py

@@ -13,7 +13,7 @@ __all__ = ["BeamSearch"]
 class BeamSearch(object):
     """
-    generating sequence by using beam search
+    Generating sequence by beam search
     NOTE: this class only implements generating one sentence at a time.
     """

@@ -21,14 +21,14 @@ class BeamSearch(object):
         """
         constructor method.
-        :param inferer: object of paddle.Inference that represent the entire
-                        network to forward compute the test batch.
+        :param inferer: object of paddle.Inference that represents the entire
+                        network to forward compute the test batch
         :type inferer: paddle.Inference
         :param word_dict_file: path of word dictionary file
         :type word_dict_file: str
         :param beam_size: expansion width in each iteration
         :type param beam_size: int
-        :param max_gen_len: the maximum number of iterations.
+        :param max_gen_len: the maximum number of iterations
         :type max_gen_len: int
         """
         self.inferer = inferer

@@ -43,7 +43,7 @@ class BeamSearch(object):
             self.unk_id = next(x[0] for x in self.ids_2_word.iteritems()
                                if x[1] == "<unk>")
         except StopIteration:
-            logger.fatal(("the word dictionay must contains an ending mark "
+            logger.fatal(("the word dictionay must contain an ending mark "
                           "in the text generation task."))
         self.candidate_paths = []

@@ -52,7 +52,7 @@ class BeamSearch(object):
     def _top_k(self, softmax_out, k):
         """
         get indices of the words with k highest probablities.
-        NOTE: <unk> will be exclued if it is among the top k words, then word
+        NOTE: <unk> will be excluded if it is among the top k words, then word
        with (k + 1)th highest probability will be returned.
         :param softmax_out: probablity over the dictionary

@@ -71,7 +71,7 @@ class BeamSearch(object):
         :params batch: the input data batch
         :type batch: list
-        :return: probalities of the predicted word
+        :return: probabilities of the predicted word
         :rtype: ndarray
         """
         return self.inferer.infer(input=batch, field=["value"])
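For reference, the _top_k rule documented above can be illustrated with a small standalone sketch (the helper name and toy distribution below are ours, not the repository's code): take the k + 1 most probable indices, drop <unk> if it appears among them, and keep the first k that remain.

    # standalone illustration of the documented _top_k rule, not the repo's code
    import numpy as np

    def top_k_excluding_unk(softmax_out, k, unk_id):
        # k + 1 highest-probability word ids, best first
        best = np.argsort(softmax_out)[::-1][:k + 1]
        # skip <unk> if it made the cut, then keep the first k survivors
        return [int(i) for i in best if i != unk_id][:k]

    probs = np.array([0.05, 0.40, 0.30, 0.20, 0.05])   # toy softmax output
    print(top_k_excluding_unk(probs, k=2, unk_id=1))   # -> [2, 3]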
language_model/config.py → generate_sequence_by_rnn_lm/config.py
File moved

language_model/data/train_data_examples.txt → generate_sequence_by_rnn_lm/data/train_data_examples.txt
File moved

language_model/generate.py → generate_sequence_by_rnn_lm/generate.py
File moved

language_model/images/ngram.png → generate_sequence_by_rnn_lm/images/ngram.png
File moved

language_model/images/rnn.png → generate_sequence_by_rnn_lm/images/rnn.png
File moved

language_model/index.html → generate_sequence_by_rnn_lm/index.html
File moved
language_model/network_conf.py → generate_sequence_by_rnn_lm/network_conf.py

@@ -12,12 +12,18 @@ def rnn_lm(vocab_dim,
     """
     RNN language model definition.
-    :param vocab_dim: size of vocab.
-    :param emb_dim: embedding vector's dimension.
+    :param vocab_dim: size of vocabulary.
+    :type vocab_dim: int
+    :param emb_dim: dimension of the embedding vector
+    :type emb_dim: int
     :param rnn_type: the type of RNN cell.
-    :param hidden_size: number of unit.
-    :param stacked_rnn_num: layer number.
+    :type rnn_type: int
+    :param hidden_size: number of hidden unit.
+    :type hidden_size: int
+    :param stacked_rnn_num: number of stacked rnn cell.
+    :type stacked_rnn_num: int
     :return: cost and output layer of model.
+    :rtype: LayerOutput
     """
     # input layers
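Based only on the parameters documented in this hunk, a call to rnn_lm might look like the sketch below. The argument values are invented for illustration (the project's real values live in config.py), the rnn_type value is assumed to be a cell name such as "lstm", and the signature is assumed to take exactly these five arguments.

    # hypothetical call sketch; values are illustrative, not taken from config.py
    from network_conf import rnn_lm

    model = rnn_lm(
        vocab_dim=10000,     # size of vocabulary
        emb_dim=256,         # dimension of the embedding vector
        rnn_type="lstm",     # type of RNN cell (a name like "lstm" is assumed here)
        hidden_size=256,     # number of hidden units per RNN layer
        stacked_rnn_num=2)   # number of stacked rnn cells
    # documented to return the cost and output layer of the model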
language_model/reader.py → generate_sequence_by_rnn_lm/reader.py
File moved
language_model/train.py → generate_sequence_by_rnn_lm/train.py

@@ -20,12 +20,16 @@ def train(topology,
     """
     train model.
-    :param model_cost: cost layer of the model to train.
+    :param topology: cost layer of the model to train.
+    :type topology: LayerOuput
     :param train_reader: train data reader.
+    :type trainer_reader: collections.Iterable
     :param test_reader: test data reader.
-    :param model_file_name_prefix: model's prefix name.
-    :param num_passes: epoch.
-    :return:
+    :type test_reader: collections.Iterable
+    :param model_save_dir: path to save the trained model
+    :type model_save_dir: str
+    :param num_passes: number of epoch
+    :type num_passes: int
     """
     if not os.path.exists(model_save_dir):
         os.mkdir(model_save_dir)
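The documented semantics of these parameters can be mirrored with a tiny standalone sketch. This is not the repository's train(), whose real entry point is driven by config.py; the stand-in cost callable and the toy readers below are ours.

    # standalone sketch mirroring the documented parameters, not train.py itself
    import os

    def toy_train(topology, train_reader, test_reader, model_save_dir, num_passes):
        if not os.path.exists(model_save_dir):   # same guard as in the hunk above
            os.mkdir(model_save_dir)
        for pass_id in range(num_passes):        # num_passes = number of epochs
            train_cost = sum(topology(batch) for batch in train_reader)
            test_cost = sum(topology(batch) for batch in test_reader)
            print("pass %d: train cost %.2f, test cost %.2f"
                  % (pass_id, train_cost, test_cost))

    toy_train(topology=lambda batch: float(len(batch)),   # stand-in for the cost layer
              train_reader=[[1, 2, 3], [4, 5]],           # iterable of toy batches
              test_reader=[[6]],
              model_save_dir="models",
              num_passes=2)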
language_model/utils.py → generate_sequence_by_rnn_lm/utils.py

@@ -17,14 +17,19 @@ def build_dict(data_file,
                insert_extra_words=["<unk>", "<e>"]):
     """
     :param data_file: path of data file
+    :type data_file: str
     :param save_path: path to save the word dictionary
+    :type save_path: str
     :param vocab_max_size: if vocab_max_size is set, top vocab_max_size words
                            will be added into word vocabulary
+    :type vocab_max_size: int
     :param cutoff_thd: if cutoff_thd is set, words whose frequencies are less
-                       than cutoff_thd will not added into word vocabulary.
+                       than cutoff_thd will not be added into word vocabulary.
                        NOTE that: vocab_max_size and cutoff_thd cannot be set at the same time
+    :type cutoff_word_fre: int
     :param extra_keys: extra keys defined by users that added into the word
-                       dictionary, ususally these keys includes <unk>, start and ending marks
+                       dictionary, ususally these keys include <unk>, start and ending marks
+    :type extra_keys: list
     """
     word_count = defaultdict(int)
     with open(data_file, "r") as f:
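The cutoff_thd rule documented above can be illustrated with a standalone sketch; this is not the repository's build_dict (which also writes the dictionary to save_path), and the function name and toy sentences below are ours.

    # standalone sketch of the documented cutoff_thd rule, not the repo's build_dict
    from collections import defaultdict

    def build_vocab(lines, cutoff_thd=2, insert_extra_words=["<unk>", "<e>"]):
        word_count = defaultdict(int)
        for line in lines:
            for w in line.split():
                word_count[w] += 1
        vocab = list(insert_extra_words)
        # words below the frequency threshold are not added to the vocabulary
        vocab += [w for w, c in sorted(word_count.items()) if c >= cutoff_thd]
        return dict((w, i) for i, w in enumerate(vocab))

    print(build_vocab(["the cat sat", "the cat ran", "a dog ran"]))
    # -> {'<unk>': 0, '<e>': 1, 'cat': 2, 'ran': 3, 'the': 4}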
@@ -53,12 +58,29 @@
 def load_dict(dict_path):
+    """
+    load word dictionary from the given file. Each line of the give file is
+    a word in the word dictionary. The first column of the line, seperated by
+    TAB, is the key, while the line index is the value.
+    :param dict_path: path of word dictionary
+    :type dict_path: str
+    :return: the dictionary
+    :rtype: dict
+    """
     return dict((line.strip().split("\t")[0], idx)
                 for idx, line in enumerate(open(dict_path, "r").readlines()))

 def load_reverse_dict(dict_path):
+    """
+    load word dictionary from the given file. Each line of the give file is
+    a word in the word dictionary. The line index is the key, while the first
+    column of the line, seperated by TAB, is the value.
+    :param dict_path: path of word dictionary
+    :type dict_path: str
+    :return: the dictionary
+    :rtype: dict
+    """
     return dict((idx, line.strip().split("\t")[0])
                 for idx, line in enumerate(open(dict_path, "r").readlines()))
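A runnable illustration of the dictionary format these two functions document (one word per line, the first TAB-separated column is the word, the line index is its id). The temporary file and its contents are ours, not the repository's data; the two dict expressions mirror load_dict and load_reverse_dict above.

    # illustrate the word-dictionary file format used by load_dict / load_reverse_dict
    import tempfile

    words = ["<unk>", "<e>", "the", "cat"]
    with tempfile.NamedTemporaryFile("w", suffix=".dict", delete=False) as f:
        # "word<TAB>0": only the first column matters, extra columns are ignored
        f.write("\n".join("%s\t%d" % (w, 0) for w in words))
        dict_path = f.name

    # same expressions as load_dict / load_reverse_dict above
    word_to_id = dict((line.strip().split("\t")[0], idx)
                      for idx, line in enumerate(open(dict_path, "r").readlines()))
    id_to_word = dict((idx, line.strip().split("\t")[0])
                      for idx, line in enumerate(open(dict_path, "r").readlines()))

    print(word_to_id["cat"])   # -> 3
    print(id_to_word[0])       # -> <unk>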