提交 b669b5fc 编写于 作者: Q qiaolongfei

fix style problem

上级 07a8f0ef
...@@ -3,7 +3,7 @@ import sys ...@@ -3,7 +3,7 @@ import sys
import paddle.v2 as paddle import paddle.v2 as paddle
def seqToseq_net(source_dict_dim, target_dict_dim, is_generating): def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
### Network Architecture ### Network Architecture
word_vector_dim = 512 # dimension of word vector word_vector_dim = 512 # dimension of word vector
decoder_size = 512 # dimension of hidden unit in GRU Decoder network decoder_size = 512 # dimension of hidden unit in GRU Decoder network
...@@ -120,13 +120,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating): ...@@ -120,13 +120,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating):
eos_id=1, eos_id=1,
beam_size=beam_size, beam_size=beam_size,
max_length=max_length) max_length=max_length)
#
# seqtext_printer_evaluator(
# input=beam_gen,
# id_input=data_layer(
# name="sent_id", size=1),
# dict_file=trg_dict_path,
# result_file=gen_trans_file)
return beam_gen return beam_gen
...@@ -138,7 +132,7 @@ def main(): ...@@ -138,7 +132,7 @@ def main():
source_dict_dim = target_dict_dim = dict_size source_dict_dim = target_dict_dim = dict_size
# define network topology # define network topology
cost = seqToseq_net(source_dict_dim, target_dict_dim, False) cost = seqToseq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# define optimize method and trainer # define optimize method and trainer
......
...@@ -526,8 +526,8 @@ def beam_search(step, ...@@ -526,8 +526,8 @@ def beam_search(step,
assert num_results_per_sample <= beam_size assert num_results_per_sample <= beam_size
# logger.warning("num_results_per_sample should be less than beam_size") # logger.warning("num_results_per_sample should be less than beam_size")
if isinstance(input, StaticInputV2) or isinstance( if isinstance(input, StaticInputV2) or isinstance(input,
input, BaseGeneratedInputV2): BaseGeneratedInputV2):
input = [input] input = [input]
generated_input_index = -1 generated_input_index = -1
...@@ -574,8 +574,7 @@ def beam_search(step, ...@@ -574,8 +574,7 @@ def beam_search(step,
# reverse=False, # reverse=False,
# name=name, # name=name,
# is_generating=True) # is_generating=True)
tmp = recurrent_group( tmp = recurrent_group(step=__real_step__, input=real_input, name=name)
step=__real_step__, input=real_input, name=name)
return tmp return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册