diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index 29fa4b48b0745d62ee524cd9b687d5d5b940b885..2809054e7d3a367f441188fe7f91037cfa5f1579 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -3,7 +3,7 @@ import sys import paddle.v2 as paddle -def seqToseq_net(source_dict_dim, target_dict_dim, is_generating): +def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): ### Network Architecture word_vector_dim = 512 # dimension of word vector decoder_size = 512 # dimension of hidden unit in GRU Decoder network @@ -120,13 +120,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating): eos_id=1, beam_size=beam_size, max_length=max_length) - # - # seqtext_printer_evaluator( - # input=beam_gen, - # id_input=data_layer( - # name="sent_id", size=1), - # dict_file=trg_dict_path, - # result_file=gen_trans_file) + return beam_gen @@ -138,7 +132,7 @@ def main(): source_dict_dim = target_dict_dim = dict_size # define network topology - cost = seqToseq_net(source_dict_dim, target_dict_dim, False) + cost = seqToseq_net(source_dict_dim, target_dict_dim) parameters = paddle.parameters.create(cost) # define optimize method and trainer diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index affb375ebfdae5e507d34d322cdbc066f18fce01..384de9b9d57f88e84ab6067846174bb037502dc0 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -526,8 +526,8 @@ def beam_search(step, assert num_results_per_sample <= beam_size # logger.warning("num_results_per_sample should be less than beam_size") - if isinstance(input, StaticInputV2) or isinstance( - input, BaseGeneratedInputV2): + if isinstance(input, StaticInputV2) or isinstance(input, + BaseGeneratedInputV2): input = [input] generated_input_index = -1 @@ -574,8 +574,7 @@ def beam_search(step, # reverse=False, # name=name, # is_generating=True) - tmp = recurrent_group( - step=__real_step__, input=real_input, name=name) + tmp = recurrent_group(step=__real_step__, input=real_input, name=name) return tmp