提交 b669b5fc 编写于 作者: Q qiaolongfei

fix style problem

上级 07a8f0ef
......@@ -3,7 +3,7 @@ import sys
import paddle.v2 as paddle
def seqToseq_net(source_dict_dim, target_dict_dim, is_generating):
def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
### Network Architecture
word_vector_dim = 512 # dimension of word vector
decoder_size = 512 # dimension of hidden unit in GRU Decoder network
......@@ -120,13 +120,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating):
eos_id=1,
beam_size=beam_size,
max_length=max_length)
#
# seqtext_printer_evaluator(
# input=beam_gen,
# id_input=data_layer(
# name="sent_id", size=1),
# dict_file=trg_dict_path,
# result_file=gen_trans_file)
return beam_gen
......@@ -138,7 +132,7 @@ def main():
source_dict_dim = target_dict_dim = dict_size
# define network topology
cost = seqToseq_net(source_dict_dim, target_dict_dim, False)
cost = seqToseq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost)
# define optimize method and trainer
......
......@@ -526,8 +526,8 @@ def beam_search(step,
assert num_results_per_sample <= beam_size
# logger.warning("num_results_per_sample should be less than beam_size")
if isinstance(input, StaticInputV2) or isinstance(
input, BaseGeneratedInputV2):
if isinstance(input, StaticInputV2) or isinstance(input,
BaseGeneratedInputV2):
input = [input]
generated_input_index = -1
......@@ -574,8 +574,7 @@ def beam_search(step,
# reverse=False,
# name=name,
# is_generating=True)
tmp = recurrent_group(
step=__real_step__, input=real_input, name=name)
tmp = recurrent_group(step=__real_step__, input=real_input, name=name)
return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册