提交 95c0277d 编写于 作者: K kuke 提交者: Yibing

Rewrite bidi encoder & code cleanup

上级 0fa397ea
......@@ -2,7 +2,6 @@
import sys
import gzip
import sqlite3
import paddle.v2 as paddle
### Parameters
......@@ -16,29 +15,21 @@ max_length = 50
def seq2seq_net(source_dict_dim, target_dict_dim, generating=False):
decoder_size = encoder_size = latent_chain_dim
### Encoder
#### Encoder
src_word_id = paddle.layer.data(
name='source_language_word',
type=paddle.data_type.integer_value_sequence(source_dict_dim))
src_embedding = paddle.layer.embedding(
input=src_word_id, size=word_vector_dim)
encoder_forward = paddle.networks.simple_gru(
input=src_embedding,
act=paddle.activation.Tanh(),
gate_act=paddle.activation.Sigmoid(),
size=encoder_size,
reverse=False)
encoder_backward = paddle.networks.simple_gru(
# use bidirectional_gru
encoded_vector = paddle.networks.bidirectional_gru(
input=src_embedding,
act=paddle.activation.Tanh(),
gate_act=paddle.activation.Sigmoid(),
size=encoder_size,
reverse=True)
encoded_vector = paddle.layer.concat(
input=[encoder_forward, encoder_backward])
fwd_act=paddle.activation.Tanh(),
fwd_gate_act=paddle.activation.Sigmoid(),
bwd_act=paddle.activation.Tanh(),
bwd_gate_act=paddle.activation.Sigmoid(),
return_seq=True)
#### Decoder
encoder_last = paddle.layer.last_seq(input=encoded_vector)
with paddle.layer.mixed(
......@@ -146,18 +137,8 @@ def train(source_dict_dim, target_dict_dim):
parameters.to_tar(f)
if event.batch_id % 10 == 0:
# wmt14_test_batch = paddle.batch(
# paddle.reader.shuffle(
# paddle.dataset.wmt14.test(source_dict_dim),
# buf_size=8192), batch_size=1)
#test_result = trainer.test(wmt14_test_batch)
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id,
event.batch_id,
event.cost,
event.metrics, # test_result.cost, test_result.metrics
)
event.pass_id, event.batch_id, event.cost, event.metrics)
else:
sys.stdout.write('.')
sys.stdout.flush()
......@@ -167,7 +148,7 @@ def train(source_dict_dim, target_dict_dim):
reader=wmt14_reader, event_handler=event_handler, num_passes=2)
def generate(source_dict_dim, target_dict_dim):
def generate(source_dict_dim, target_dict_dim, init_models_path):
# load data samples for generation
gen_creator = paddle.dataset.wmt14.gen(source_dict_dim)
gen_data = []
......@@ -175,8 +156,7 @@ def generate(source_dict_dim, target_dict_dim):
gen_data.append((item[0], ))
beam_gen = seq2seq_net(source_dict_dim, target_dict_dim, True)
with gzip.open('models/nmt_without_att_params_batch_400.tar.gz') as f:
with gzip.open(init_models_path) as f:
parameters = paddle.parameters.Parameters.from_tar(f)
# prob is the prediction probabilities, and id is the prediction word.
beam_result = paddle.infer(
......@@ -208,15 +188,37 @@ def generate(source_dict_dim, target_dict_dim):
print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j]
def usage_helper():
print "Please specify training/generating phase!"
print "Usage: python nmt_without_attention_v2.py --train/generate"
exit(1)
def main():
if not (len(sys.argv) == 2):
usage_helper()
if sys.argv[1] == '--train':
generating = False
elif sys.argv[1] == '--generate':
generating = True
else:
usage_helper()
paddle.init(use_gpu=False, trainer_count=4)
source_language_dict_dim = 30000
target_language_dict_dim = 30000
generating = True
if generating:
generate(source_language_dict_dim, target_language_dict_dim)
# shoud pass the right generated model's path here
init_models_path = 'models/nmt_without_att_params_batch_400.tar.gz'
if not os.path.exists(init_models_path):
print "Cannot find models for generation"
exit(1)
generate(source_language_dict_dim, target_language_dict_dim,
init_models_path)
else:
if not os.path.exists('./models'):
os.system('mkdir ./models')
train(source_language_dict_dim, target_language_dict_dim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册