Commit 48b4a5cb, authored by Yibing

add annotations for functions

Parent 668b1061
@@ -13,6 +13,15 @@ max_length = 50
def seq2seq_net(source_dict_dim, target_dict_dim, generating=False):
'''
Define the network structure of NMT, including encoder and decoder.
:param source_dict_dim: size of source dictionary
:type source_dict_dim: int
:param target_dict_dim: size of target dictionary
:type target_dict_dim: int
:param generating: whether to build the network for generation instead of training
:type generating: bool
'''
decoder_size = encoder_size = latent_chain_dim
#### Encoder
@@ -39,7 +48,14 @@ def seq2seq_net(source_dict_dim, target_dict_dim, generating=False):
input=encoder_last)
# gru step
def gru_decoder_without_attention(enc_vec, current_word):
'''
Step function for the GRU decoder (without attention).
:param enc_vec: encoded vector of the source language
:type enc_vec: layer object
:param current_word: current input word of the decoder
:type current_word: layer object
'''
decoder_mem = paddle.layer.memory(
name='gru_decoder',
size=decoder_size,
@@ -112,6 +128,14 @@ def seq2seq_net(source_dict_dim, target_dict_dim, generating=False):
def train(source_dict_dim, target_dict_dim):
'''
Training function for NMT
:param source_dict_dim: size of source dictionary
:type source_dict_dim: int
:param target_dict_dim: size of target dictionary
:type target_dict_dim: int
'''
cost = seq2seq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost)
@@ -149,6 +173,15 @@ def train(source_dict_dim, target_dict_dim):
def generate(source_dict_dim, target_dict_dim, init_models_path):
'''
Generating function for NMT
:param source_dict_dim: size of source dictionary
:type source_dict_dim: int
:param target_dict_dim: size of target dictionary
:type target_dict_dim: int
:param init_models_path: path to the trained model used to initialize generation
:type init_models_path: str
'''
# load data samples for generation
gen_creator = paddle.dataset.wmt14.gen(source_dict_dim)
gen_data = []
@@ -203,7 +236,8 @@ def main():
else:
usage_helper()
# initialize paddle
paddle.init(use_gpu=False, trainer_count=1)
source_language_dict_dim = 30000
target_language_dict_dim = 30000
...
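
For reference, below is a minimal sketch of how the functions annotated in this commit could be driven from main(), using the paddle.init call shown in the diff. It assumes train, generate, and usage_helper from the same script are in scope; the --train/--generate flags and the way init_models_path is read from the command line are illustrative assumptions, not necessarily the original script's exact logic.

import sys
import paddle.v2 as paddle  # import style assumed from PaddlePaddle v2 examples

def main():
    # Dispatch on the first command-line argument (flag names are assumptions).
    if len(sys.argv) < 2:
        usage_helper()
    # Initialize PaddlePaddle on CPU with a single trainer, as in this diff.
    paddle.init(use_gpu=False, trainer_count=1)
    source_language_dict_dim = 30000
    target_language_dict_dim = 30000
    if sys.argv[1] == '--train':
        train(source_language_dict_dim, target_language_dict_dim)
    elif sys.argv[1] == '--generate':
        # init_models_path points to a previously trained model; passing it as
        # the second command-line argument is a hypothetical choice for this sketch.
        generate(source_language_dict_dim, target_language_dict_dim,
                 init_models_path=sys.argv[2])
    else:
        usage_helper()

if __name__ == '__main__':
    main()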