提交 22c5add9 编写于 作者: Y Yibing

update README & source code

上级 48b4a5cb
此差异已折叠。
...@@ -136,6 +136,7 @@ def train(source_dict_dim, target_dict_dim): ...@@ -136,6 +136,7 @@ def train(source_dict_dim, target_dict_dim):
:param target_dict_dim: size of target dictionary :param target_dict_dim: size of target dictionary
:type target_dict_dim: int :type target_dict_dim: int
''' '''
# initialize model
cost = seq2seq_net(source_dict_dim, target_dict_dim) cost = seq2seq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
...@@ -180,6 +181,8 @@ def generate(source_dict_dim, target_dict_dim, init_models_path): ...@@ -180,6 +181,8 @@ def generate(source_dict_dim, target_dict_dim, init_models_path):
:type source_dict_dim: int :type source_dict_dim: int
:param target_dict_dim: size of target dictionary :param target_dict_dim: size of target dictionary
:type target_dict_dim: int :type target_dict_dim: int
:param init_models_path: path for inital model
:type init_models_path: string
''' '''
# load data samples for generation # load data samples for generation
...@@ -203,8 +206,7 @@ def generate(source_dict_dim, target_dict_dim, init_models_path): ...@@ -203,8 +206,7 @@ def generate(source_dict_dim, target_dict_dim, init_models_path):
# the delimited element of generated sequences is -1, # the delimited element of generated sequences is -1,
# the first element of each generated sequence is the sequence length # the first element of each generated sequence is the sequence length
seq_list = [] seq_list, seq = [], []
seq = []
for w in beam_result[1]: for w in beam_result[1]:
if w != -1: if w != -1:
seq.append(w) seq.append(w)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册