提交 48b4a5cb 编写于 作者: Y Yibing

add annotations for functions

上级 668b1061
......@@ -13,6 +13,15 @@ max_length = 50
def seq2seq_net(source_dict_dim, target_dict_dim, generating=False):
'''
Define the network structure of NMT, including encoder and decoder.
:param source_dict_dim: size of source dictionary
:type source_dict_dim : int
:param target_dict_dim: size of target dictionary
:type target_dict_dim: int
'''
decoder_size = encoder_size = latent_chain_dim
#### Encoder
......@@ -39,7 +48,14 @@ def seq2seq_net(source_dict_dim, target_dict_dim, generating=False):
input=encoder_last)
# gru step
def gru_decoder_without_attention(enc_vec, current_word):
'''
Step function for gru decoder
:param enc_vec: encoded vector of source language
:type enc_vec: layer object
:param current_word: current input of decoder
:type current_word: layer object
'''
decoder_mem = paddle.layer.memory(
name='gru_decoder',
size=decoder_size,
......@@ -112,6 +128,14 @@ def seq2seq_net(source_dict_dim, target_dict_dim, generating=False):
def train(source_dict_dim, target_dict_dim):
'''
Training function for NMT
:param source_dict_dim: size of source dictionary
:type source_dict_dim: int
:param target_dict_dim: size of target dictionary
:type target_dict_dim: int
'''
cost = seq2seq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost)
......@@ -149,6 +173,15 @@ def train(source_dict_dim, target_dict_dim):
def generate(source_dict_dim, target_dict_dim, init_models_path):
'''
Generating function for NMT
:param source_dict_dim: size of source dictionary
:type source_dict_dim: int
:param target_dict_dim: size of target dictionary
:type target_dict_dim: int
'''
# load data samples for generation
gen_creator = paddle.dataset.wmt14.gen(source_dict_dim)
gen_data = []
......@@ -203,7 +236,8 @@ def main():
else:
usage_helper()
paddle.init(use_gpu=False, trainer_count=4)
# initialize paddle
paddle.init(use_gpu=False, trainer_count=1)
source_language_dict_dim = 30000
target_language_dict_dim = 30000
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册