From bb93d5c09349a30dcce55b1795374771632f4e10 Mon Sep 17 00:00:00 2001
From: wwhu
Date: Mon, 8 May 2017 15:50:48 +0800
Subject: [PATCH] correct the code style

---
 .../random_schedule_generator.py         |  2 --
 scheduled_sampling/scheduled_sampling.py | 24 +++++++++++++------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/scheduled_sampling/random_schedule_generator.py b/scheduled_sampling/random_schedule_generator.py
index b86c867e..046dce63 100644
--- a/scheduled_sampling/random_schedule_generator.py
+++ b/scheduled_sampling/random_schedule_generator.py
@@ -1,7 +1,6 @@
 import numpy as np
 import math
 import pdb
-
 '''
 The random sampling rate for scheduled sampling algoithm, which uses devcayed
 sampling rate.
@@ -55,4 +54,3 @@ if __name__ == "__main__":
     schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000)
     true_token_flag = schedule_generator.processBatch(5)
     pdb.set_trace()
-    pass
\ No newline at end of file
diff --git a/scheduled_sampling/scheduled_sampling.py b/scheduled_sampling/scheduled_sampling.py
index 24c15756..e641c448 100644
--- a/scheduled_sampling/scheduled_sampling.py
+++ b/scheduled_sampling/scheduled_sampling.py
@@ -2,7 +2,6 @@
 import sys
 import paddle.v2 as paddle
 from random_schedule_generator import RandomScheduleGenerator
-
 
 schedule_generator = RandomScheduleGenerator("linear", 0.75, 1000000)
 
@@ -19,6 +18,7 @@ def gen_schedule_data(reader):
     :return: the new reader with the field "true_token_flag".
     :rtype: callable
     """
+
     def data_reader():
         for src_ids, trg_ids, trg_ids_next in reader():
             yield src_ids, trg_ids, trg_ids_next, \
@@ -62,7 +62,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
         decoder_boot += paddle.layer.full_matrix_projection(
             input=backward_first)
 
-    def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word, true_token_flag):
+    def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,
+                                         true_token_flag):
         decoder_mem = paddle.layer.memory(
             name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)
 
@@ -82,7 +83,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
             size=word_vector_dim,
             param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
 
-        current_word = paddle.layer.multiplex(input=[true_token_flag, true_word, generated_word_emb])
+        current_word = paddle.layer.multiplex(
+            input=[true_token_flag, true_word, generated_word_emb])
 
         with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
             decoder_inputs += paddle.layer.full_matrix_projection(input=context)
@@ -208,8 +210,12 @@ def main():
                 paddle.dataset.wmt14.train(dict_size), buf_size=8192)),
             batch_size=5)
 
-        feeding = {'source_language_word': 0, 'target_language_word': 1,
-                   'target_language_next_word': 2, 'true_token_flag': 3}
+        feeding = {
+            'source_language_word': 0,
+            'target_language_word': 1,
+            'target_language_next_word': 2,
+            'true_token_flag': 3
+        }
 
         # define event_handler callback
         def event_handler(event):
@@ -223,12 +229,16 @@ def main():
                     sys.stdout.flush()
             if isinstance(event, paddle.event.EndPass):
                 # save parameters
-                with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
+                with gzip.open('params_pass_%d.tar.gz' % event.pass_id,
+                               'w') as f:
                     parameters.to_tar(f)
 
         # start to train
         trainer.train(
-            reader=wmt14_reader, event_handler=event_handler, feeding=feeding, num_passes=2)
+            reader=wmt14_reader,
+            event_handler=event_handler,
+            feeding=feeding,
+            num_passes=2)
 
     # generate a english sequence to french
     else:
--
GitLab
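
For context on the scheduled-sampling pieces this patch touches, the sketch below illustrates the mechanism that RandomScheduleGenerator("linear", 0.75, 1000000) and gen_schedule_data provide: the probability of feeding the ground-truth word decays as more data is processed, and a decorated reader appends a per-token true_token_flag that paddle.layer.multiplex later uses to pick between the true word and the previously generated word. This is a minimal, hypothetical sketch: the class name LinearScheduleSketch, the exact decay formula, and the process_batch/decorate_reader names are assumptions for illustration, not the repository's actual implementation.

import numpy as np


class LinearScheduleSketch(object):
    """Hypothetical stand-in for RandomScheduleGenerator("linear", a, b)."""

    def __init__(self, max_sampling_rate, decay_steps):
        self.max_sampling_rate = max_sampling_rate
        self.decay_steps = decay_steps
        self.processed = 0

    def sampling_rate(self):
        # Assumed linear decay of the probability of feeding the true token.
        progress = min(1.0, float(self.processed) / self.decay_steps)
        return 1.0 - self.max_sampling_rate * progress

    def process_batch(self, batch_size):
        # One 0/1 flag per position: 1 -> use the ground-truth word,
        # 0 -> use the word generated at the previous decoder step.
        flags = (np.random.rand(batch_size) < self.sampling_rate()).astype('int32')
        self.processed += batch_size
        return flags.tolist()


def decorate_reader(reader, scheduler):
    # Same shape as gen_schedule_data in the patch: append a true_token_flag
    # field to every (src_ids, trg_ids, trg_ids_next) sample.
    def data_reader():
        for src_ids, trg_ids, trg_ids_next in reader():
            yield src_ids, trg_ids, trg_ids_next, \
                scheduler.process_batch(len(trg_ids))
    return data_reader


if __name__ == "__main__":
    scheduler = LinearScheduleSketch(0.75, 1000000)
    print(scheduler.process_batch(5))  # e.g. [1, 1, 0, 1, 1]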