提交 bb93d5c0 编写于 作者: W wwhu

correct the code style

上级 363b62d1
import numpy as np
import math
import pdb
'''
The random sampling rate for scheduled sampling algoithm, which uses devcayed
sampling rate.
......@@ -55,4 +54,3 @@ if __name__ == "__main__":
schedule_generator = RandomScheduleGenerator("linear", 0.1, 500000)
true_token_flag = schedule_generator.processBatch(5)
pdb.set_trace()
pass
\ No newline at end of file
......@@ -2,7 +2,6 @@ import sys
import paddle.v2 as paddle
from random_schedule_generator import RandomScheduleGenerator
schedule_generator = RandomScheduleGenerator("linear", 0.75, 1000000)
......@@ -19,6 +18,7 @@ def gen_schedule_data(reader):
:return: the new reader with the field "true_token_flag".
:rtype: callable
"""
def data_reader():
for src_ids, trg_ids, trg_ids_next in reader():
yield src_ids, trg_ids, trg_ids_next, \
......@@ -62,7 +62,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
decoder_boot += paddle.layer.full_matrix_projection(
input=backward_first)
def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word, true_token_flag):
def gru_decoder_with_attention_train(enc_vec, enc_proj, true_word,
true_token_flag):
decoder_mem = paddle.layer.memory(
name='gru_decoder', size=decoder_size, boot_layer=decoder_boot)
......@@ -82,7 +83,8 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
size=word_vector_dim,
param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
current_word = paddle.layer.multiplex(input=[true_token_flag, true_word, generated_word_emb])
current_word = paddle.layer.multiplex(
input=[true_token_flag, true_word, generated_word_emb])
with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs:
decoder_inputs += paddle.layer.full_matrix_projection(input=context)
......@@ -208,8 +210,12 @@ def main():
paddle.dataset.wmt14.train(dict_size), buf_size=8192)),
batch_size=5)
feeding = {'source_language_word': 0, 'target_language_word': 1,
'target_language_next_word': 2, 'true_token_flag': 3}
feeding = {
'source_language_word': 0,
'target_language_word': 1,
'target_language_next_word': 2,
'true_token_flag': 3
}
# define event_handler callback
def event_handler(event):
......@@ -223,12 +229,16 @@ def main():
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
# save parameters
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
with gzip.open('params_pass_%d.tar.gz' % event.pass_id,
'w') as f:
parameters.to_tar(f)
# start to train
trainer.train(
reader=wmt14_reader, event_handler=event_handler, feeding=feeding, num_passes=2)
reader=wmt14_reader,
event_handler=event_handler,
feeding=feeding,
num_passes=2)
# generate a english sequence to french
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册