提交 bb93d5c0 编写于 作者: W wwhu

correct the code style

上级 363b62d1
import numpy as np import numpy as np
import math import math
import pdb import pdb
''' '''
The random sampling rate for the scheduled sampling algorithm, which uses a decayed The random sampling rate for the scheduled sampling algorithm, which uses a decayed
sampling rate. sampling rate.
...@@ -55,4 +54,3 @@ if __name__ == "__main__": ...@@ -55,4 +54,3 @@ if __name__ == "__main__":
schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000) schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000)
true_token_flag = schedule_generator.processBatch(5) true_token_flag = schedule_generator.processBatch(5)
pdb.set_trace() pdb.set_trace()
pass
\ No newline at end of file
...@@ -2,7 +2,6 @@ import sys ...@@ -2,7 +2,6 @@ import sys
import paddle.v2 as paddle import paddle.v2 as paddle
from random_schedule_generator import RandomScheduleGenerator from random_schedule_generator import RandomScheduleGenerator
schedule_generator = RandomScheduleGenerator("linear", 0.75, 1000000) schedule_generator = RandomScheduleGenerator("linear", 0.75, 1000000)
...@@ -19,6 +18,7 @@ def gen_schedule_data(reader): ...@@ -19,6 +18,7 @@ def gen_schedule_data(reader):
:return: the new reader with the field "true_token_flag". :return: the new reader with the field "true_token_flag".
:rtype: callable :rtype: callable
""" """
def data_reader(): def data_reader():
for src_ids, trg_ids, trg_ids_next in reader(): for src_ids, trg_ids, trg_ids_next in reader():
yield src_ids, trg_ids, trg_ids_next, \ yield src_ids, trg_ids, trg_ids_next, \
...@@ -62,7 +62,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): ...@@ -62,7 +62,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
decoder_boot += paddle.layer.full_matrix_projection( decoder_boot += paddle.layer.full_matrix_projection(
input=backward_first) input=backward_first)
def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word, true_token_flag): def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,
true_token_flag):
decoder_mem = paddle.layer.memory( decoder_mem = paddle.layer.memory(
name='gru_decoder', size=decoder_size, boot_layer=decoder_boot) name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)
...@@ -82,7 +83,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): ...@@ -82,7 +83,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
size=word_vector_dim, size=word_vector_dim,
param_attr=paddle.attr.ParamAttr(name='_target_language_embedding')) param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
current_word = paddle.layer.multiplex(input=[true_token_flag, true_word, generated_word_emb]) current_word = paddle.layer.multiplex(
input=[true_token_flag, true_word, generated_word_emb])
with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs: with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
decoder_inputs += paddle.layer.full_matrix_projection(input=context) decoder_inputs += paddle.layer.full_matrix_projection(input=context)
...@@ -208,8 +210,12 @@ def main(): ...@@ -208,8 +210,12 @@ def main():
paddle.dataset.wmt14.train(dict_size), buf_size=8192)), paddle.dataset.wmt14.train(dict_size), buf_size=8192)),
batch_size=5) batch_size=5)
feeding = {'source_language_word': 0, 'target_language_word': 1, feeding = {
'target_language_next_word': 2, 'true_token_flag': 3} 'source_language_word': 0,
'target_language_word': 1,
'target_language_next_word': 2,
'true_token_flag': 3
}
# define event_handler callback # define event_handler callback
def event_handler(event): def event_handler(event):
...@@ -223,12 +229,16 @@ def main(): ...@@ -223,12 +229,16 @@ def main():
sys.stdout.flush() sys.stdout.flush()
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
# save parameters # save parameters
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f: with gzip.open('params_pass_%d.tar.gz' % event.pass_id,
'w') as f:
parameters.to_tar(f) parameters.to_tar(f)
# start to train # start to train
trainer.train( trainer.train(
reader=wmt14_reader, event_handler=event_handler, feeding=feeding, num_passes=2) reader=wmt14_reader,
event_handler=event_handler,
feeding=feeding,
num_passes=2)
# generate an English sequence to French # generate an English sequence to French
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册