提交 e9a0aa86 编写于 作者: C caoying03

follow comments and rename the directory.

上级 6b0f946d
...@@ -13,7 +13,7 @@ __all__ = ["BeamSearch"] ...@@ -13,7 +13,7 @@ __all__ = ["BeamSearch"]
class BeamSearch(object): class BeamSearch(object):
""" """
generating sequence by using beam search Generating sequence by beam search
NOTE: this class only implements generating one sentence at a time. NOTE: this class only implements generating one sentence at a time.
""" """
...@@ -21,14 +21,14 @@ class BeamSearch(object): ...@@ -21,14 +21,14 @@ class BeamSearch(object):
""" """
constructor method. constructor method.
:param inferer: object of paddle.Inference that represent the entire :param inferer: object of paddle.Inference that represents the entire
network to forward compute the test batch. network to forward compute the test batch
:type inferer: paddle.Inference :type inferer: paddle.Inference
:param word_dict_file: path of word dictionary file :param word_dict_file: path of word dictionary file
:type word_dict_file: str :type word_dict_file: str
:param beam_size: expansion width in each iteration :param beam_size: expansion width in each iteration
:type beam_size: int :type beam_size: int
:param max_gen_len: the maximum number of iterations. :param max_gen_len: the maximum number of iterations
:type max_gen_len: int :type max_gen_len: int
""" """
self.inferer = inferer self.inferer = inferer
...@@ -43,7 +43,7 @@ class BeamSearch(object): ...@@ -43,7 +43,7 @@ class BeamSearch(object):
self.unk_id = next(x[0] for x in self.ids_2_word.iteritems() self.unk_id = next(x[0] for x in self.ids_2_word.iteritems()
if x[1] == "<unk>") if x[1] == "<unk>")
except StopIteration: except StopIteration:
logger.fatal(("the word dictionay must contains an ending mark " logger.fatal(("the word dictionay must contain an ending mark "
"in the text generation task.")) "in the text generation task."))
self.candidate_paths = [] self.candidate_paths = []
...@@ -52,7 +52,7 @@ class BeamSearch(object): ...@@ -52,7 +52,7 @@ class BeamSearch(object):
def _top_k(self, softmax_out, k): def _top_k(self, softmax_out, k):
""" """
get indices of the words with k highest probablities. get indices of the words with k highest probablities.
NOTE: <unk> will be exclued if it is among the top k words, then word NOTE: <unk> will be excluded if it is among the top k words, then word
with (k + 1)th highest probability will be returned. with (k + 1)th highest probability will be returned.
:param softmax_out: probablity over the dictionary :param softmax_out: probablity over the dictionary
...@@ -71,7 +71,7 @@ class BeamSearch(object): ...@@ -71,7 +71,7 @@ class BeamSearch(object):
:params batch: the input data batch :params batch: the input data batch
:type batch: list :type batch: list
:return: probalities of the predicted word :return: probabilities of the predicted word
:rtype: ndarray :rtype: ndarray
""" """
return self.inferer.infer(input=batch, field=["value"]) return self.inferer.infer(input=batch, field=["value"])
......
...@@ -12,12 +12,18 @@ def rnn_lm(vocab_dim, ...@@ -12,12 +12,18 @@ def rnn_lm(vocab_dim,
""" """
RNN language model definition. RNN language model definition.
:param vocab_dim: size of vocab. :param vocab_dim: size of vocabulary.
:param emb_dim: embedding vector"s dimension. :type vocab_dim: int
:param emb_dim: dimension of the embedding vector
:type emb_dim: int
:param rnn_type: the type of RNN cell. :param rnn_type: the type of RNN cell.
:param hidden_size: number of unit. :type rnn_type: int
:param stacked_rnn_num: layer number. :param hidden_size: number of hidden units.
:type hidden_size: int
:param stacked_rnn_num: number of stacked rnn cells.
:type stacked_rnn_num: int
:return: cost and output layer of model. :return: cost and output layer of model.
:rtype: LayerOutput
""" """
# input layers # input layers
......
...@@ -20,12 +20,16 @@ def train(topology, ...@@ -20,12 +20,16 @@ def train(topology,
""" """
train model. train model.
:param model_cost: cost layer of the model to train. :param topology: cost layer of the model to train.
:type topology: LayerOutput
:param train_reader: train data reader. :param train_reader: train data reader.
:type train_reader: collections.Iterable
:param test_reader: test data reader. :param test_reader: test data reader.
:param model_file_name_prefix: model"s prefix name. :type test_reader: collections.Iterable
:param num_passes: epoch. :param model_save_dir: path to save the trained model
:return: :type model_save_dir: str
:param num_passes: number of epochs
:type num_passes: int
""" """
if not os.path.exists(model_save_dir): if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir) os.mkdir(model_save_dir)
......
...@@ -17,14 +17,19 @@ def build_dict(data_file, ...@@ -17,14 +17,19 @@ def build_dict(data_file,
insert_extra_words=["<unk>", "<e>"]): insert_extra_words=["<unk>", "<e>"]):
""" """
:param data_file: path of data file :param data_file: path of data file
:type data_file: str
:param save_path: path to save the word dictionary :param save_path: path to save the word dictionary
:type save_path: str
:param vocab_max_size: if vocab_max_size is set, top vocab_max_size words :param vocab_max_size: if vocab_max_size is set, top vocab_max_size words
will be added into word vocabulary will be added into word vocabulary
:type vocab_max_size: int
:param cutoff_thd: if cutoff_thd is set, words whose frequencies are less :param cutoff_thd: if cutoff_thd is set, words whose frequencies are less
than cutoff_thd will not added into word vocabulary. than cutoff_thd will not be added into word vocabulary.
NOTE that: vocab_max_size and cutoff_thd cannot be set at the same time NOTE that: vocab_max_size and cutoff_thd cannot be set at the same time
:type cutoff_thd: int
:param extra_keys: extra keys defined by users that added into the word :param extra_keys: extra keys defined by users that added into the word
dictionary, ususally these keys includes <unk>, start and ending marks dictionary, ususally these keys include <unk>, start and ending marks
:type extra_keys: list
""" """
word_count = defaultdict(int) word_count = defaultdict(int)
with open(data_file, "r") as f: with open(data_file, "r") as f:
...@@ -53,12 +58,29 @@ def build_dict(data_file, ...@@ -53,12 +58,29 @@ def build_dict(data_file,
def load_dict(dict_path): def load_dict(dict_path):
""" """
load word dictionary from the given file. Each line of the given file is
a word in the word dictionary. The first column of the line, separated by
TAB, is the key, while the line index is the value.
:param dict_path: path of word dictionary :param dict_path: path of word dictionary
:type dict_path: str
:return: the dictionary
:rtype: dict
""" """
return dict((line.strip().split("\t")[0], idx) return dict((line.strip().split("\t")[0], idx)
for idx, line in enumerate(open(dict_path, "r").readlines())) for idx, line in enumerate(open(dict_path, "r").readlines()))
def load_reverse_dict(dict_path): def load_reverse_dict(dict_path):
"""
load word dictionary from the given file. Each line of the given file is
a word in the word dictionary. The line index is the key, while the first
column of the line, separated by TAB, is the value.
:param dict_path: path of word dictionary
:type dict_path: str
:return: the dictionary
:rtype: dict
"""
return dict((idx, line.strip().split("\t")[0]) return dict((idx, line.strip().split("\t")[0])
for idx, line in enumerate(open(dict_path, "r").readlines())) for idx, line in enumerate(open(dict_path, "r").readlines()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册