diff --git a/scheduled_sampling/random_schedule_generator.py b/scheduled_sampling/random_schedule_generator.py index b86c867e4aa13fb285ac81af78df87f34ce47ee4..046dce63520aa961ace5b538305b6134ee042a79 100644 --- a/scheduled_sampling/random_schedule_generator.py +++ b/scheduled_sampling/random_schedule_generator.py @@ -1,7 +1,6 @@ import numpy as np import math import pdb - ''' The random sampling rate for scheduled sampling algoithm, which uses devcayed sampling rate. @@ -55,4 +54,3 @@ if __name__ == "__main__": schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000) true_token_flag = schedule_generator.processBatch(5) pdb.set_trace() - pass \ No newline at end of file diff --git a/scheduled_sampling/scheduled_sampling.py b/scheduled_sampling/scheduled_sampling.py index 24c15756b693d6acdbce4435db96bec91dfd623d..e641c44846d4a3ebb47190ef2a9293a9ecfd2f73 100644 --- a/scheduled_sampling/scheduled_sampling.py +++ b/scheduled_sampling/scheduled_sampling.py @@ -2,7 +2,6 @@ import sys import paddle.v2 as paddle from random_schedule_generator import RandomScheduleGenerator - schedule_generator = RandomScheduleGenerator("linear", 0.75, 1000000) @@ -19,6 +18,7 @@ def gen_schedule_data(reader): :return: the new reader with the field "true_token_flag". :rtype: callable """ + def data_reader(): for src_ids, trg_ids, trg_ids_next in reader(): yield src_ids, trg_ids, trg_ids_next, \ @@ -62,7 +62,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): decoder_boot += paddle.layer.full_matrix_projection( input=backward_first) - def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word, true_token_flag): + def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word, + true_token_flag): decoder_mem = paddle.layer.memory( name='gru_decoder', size=decoder_size, boot_layer=decoder_boot) @@ -82,7 +83,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): size=word_vector_dim, param_attr=paddle.attr.ParamAttr(name='_target_language_embedding')) - current_word = paddle.layer.multiplex(input=[true_token_flag, true_word, generated_word_emb]) + current_word = paddle.layer.multiplex( + input=[true_token_flag, true_word, generated_word_emb]) with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs: decoder_inputs += paddle.layer.full_matrix_projection(input=context) @@ -208,8 +210,12 @@ def main(): paddle.dataset.wmt14.train(dict_size), buf_size=8192)), batch_size=5) - feeding = {'source_language_word': 0, 'target_language_word': 1, - 'target_language_next_word': 2, 'true_token_flag': 3} + feeding = { + 'source_language_word': 0, + 'target_language_word': 1, + 'target_language_next_word': 2, + 'true_token_flag': 3 + } # define event_handler callback def event_handler(event): @@ -223,12 +229,16 @@ def main(): sys.stdout.flush() if isinstance(event, paddle.event.EndPass): # save parameters - with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f: + with gzip.open('params_pass_%d.tar.gz' % event.pass_id, + 'w') as f: parameters.to_tar(f) # start to train trainer.train( - reader=wmt14_reader, event_handler=event_handler, feeding=feeding, num_passes=2) + reader=wmt14_reader, + event_handler=event_handler, + feeding=feeding, + num_passes=2) # generate a english sequence to french else: