提交 e9a0aa86 编写于 作者: C caoying03

follow comments and rename the directory.

上级 6b0f946d
......@@ -13,7 +13,7 @@ __all__ = ["BeamSearch"]
class BeamSearch(object):
"""
generating sequence by using beam search
Generating sequence by beam search
NOTE: this class only implements generating one sentence at a time.
"""
......@@ -21,14 +21,14 @@ class BeamSearch(object):
"""
constructor method.
:param inferer: object of paddle.Inference that represent the entire
network to forward compute the test batch.
:param inferer: object of paddle.Inference that represents the entire
network to forward compute the test batch
:type inferer: paddle.Inference
:param word_dict_file: path of word dictionary file
:type word_dict_file: str
:param beam_size: expansion width in each iteration
:type beam_size: int
:param max_gen_len: the maximum number of iterations.
:param max_gen_len: the maximum number of iterations
:type max_gen_len: int
"""
self.inferer = inferer
......@@ -43,7 +43,7 @@ class BeamSearch(object):
self.unk_id = next(x[0] for x in self.ids_2_word.iteritems()
if x[1] == "<unk>")
except StopIteration:
logger.fatal(("the word dictionay must contains an ending mark "
logger.fatal(("the word dictionay must contain an ending mark "
"in the text generation task."))
self.candidate_paths = []
......@@ -52,7 +52,7 @@ class BeamSearch(object):
def _top_k(self, softmax_out, k):
"""
get indices of the words with k highest probabilities.
NOTE: <unk> will be exclued if it is among the top k words, then word
NOTE: <unk> will be excluded if it is among the top k words, then word
with (k + 1)th highest probability will be returned.
:param softmax_out: probablity over the dictionary
......@@ -71,7 +71,7 @@ class BeamSearch(object):
:params batch: the input data batch
:type batch: list
:return: probalities of the predicted word
:return: probabilities of the predicted word
:rtype: ndarray
"""
return self.inferer.infer(input=batch, field=["value"])
......
......@@ -12,12 +12,18 @@ def rnn_lm(vocab_dim,
"""
RNN language model definition.
:param vocab_dim: size of vocab.
:param emb_dim: embedding vector"s dimension.
:param vocab_dim: size of vocabulary.
:type vocab_dim: int
:param emb_dim: dimension of the embedding vector
:type emb_dim: int
:param rnn_type: the type of RNN cell.
:param hidden_size: number of unit.
:param stacked_rnn_num: layer number.
:type rnn_type: int
:param hidden_size: number of hidden unit.
:type hidden_size: int
:param stacked_rnn_num: number of stacked rnn cell.
:type stacked_rnn_num: int
:return: cost and output layer of model.
:rtype: LayerOutput
"""
# input layers
......
......@@ -20,12 +20,16 @@ def train(topology,
"""
train model.
:param model_cost: cost layer of the model to train.
:param topology: cost layer of the model to train.
:type topology: LayerOuput
:param train_reader: train data reader.
:type trainer_reader: collections.Iterable
:param test_reader: test data reader.
:param model_file_name_prefix: model"s prefix name.
:param num_passes: epoch.
:return:
:type test_reader: collections.Iterable
:param model_save_dir: path to save the trained model
:type model_save_dir: str
:param num_passes: number of epoch
:type num_passes: int
"""
if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
......
......@@ -17,14 +17,19 @@ def build_dict(data_file,
insert_extra_words=["<unk>", "<e>"]):
"""
:param data_file: path of data file
:type data_file: str
:param save_path: path to save the word dictionary
:type save_path: str
:param vocab_max_size: if vocab_max_size is set, top vocab_max_size words
will be added into word vocabulary
:type vocab_max_size: int
:param cutoff_thd: if cutoff_thd is set, words whose frequencies are less
than cutoff_thd will not added into word vocabulary.
than cutoff_thd will not be added into word vocabulary.
NOTE that: vocab_max_size and cutoff_thd cannot be set at the same time
:type cutoff_word_fre: int
:param extra_keys: extra keys defined by users that added into the word
dictionary, ususally these keys includes <unk>, start and ending marks
dictionary, usually these keys include <unk>, start and ending marks
:type extra_keys: list
"""
word_count = defaultdict(int)
with open(data_file, "r") as f:
......@@ -53,12 +58,29 @@ def build_dict(data_file,
def load_dict(dict_path):
    """
    Load a word dictionary from the given file. Each line of the given file
    is one word; the first TAB-separated column of the line is the key,
    while the zero-based line index is the value.

    :param dict_path: path of the word dictionary file
    :type dict_path: str
    :return: dictionary mapping each word to its line index
    :rtype: dict
    """
    # Use a context manager so the file handle is closed deterministically;
    # the original `open(...).readlines()` relied on the garbage collector
    # and also materialized the whole file needlessly.
    with open(dict_path, "r") as dict_file:
        return dict((line.strip().split("\t")[0], idx)
                    for idx, line in enumerate(dict_file))
def load_reverse_dict(dict_path):
    """
    Load a reverse word dictionary from the given file. Each line of the
    given file is one word; the zero-based line index is the key, while the
    first TAB-separated column of the line is the value.

    :param dict_path: path of the word dictionary file
    :type dict_path: str
    :return: dictionary mapping each line index to its word
    :rtype: dict
    """
    # Use a context manager so the file handle is closed deterministically;
    # the original `open(...).readlines()` relied on the garbage collector
    # and also materialized the whole file needlessly.
    with open(dict_path, "r") as dict_file:
        return dict((idx, line.strip().split("\t")[0])
                    for idx, line in enumerate(dict_file))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册