提交 a42233c2 编写于 作者: L Luo Tao

add wmt14 trg_dict

上级 caffcc83
...@@ -126,33 +126,28 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): ...@@ -126,33 +126,28 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=False, trainer_count=1)
is_generating = True
# source and target dict dim. # source and target dict dim.
dict_size = 30000 dict_size = 30000
source_dict_dim = target_dict_dim = dict_size source_dict_dim = target_dict_dim = dict_size
# define network topology # train the network
if not is_generating:
cost = seqToseq_net(source_dict_dim, target_dict_dim) cost = seqToseq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# define optimize method and trainer # define optimize method and trainer
optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.Adam(
learning_rate=5e-5, learning_rate=5e-5,
regularization=paddle.optimizer.L2Regularization(rate=1e-3)) regularization=paddle.optimizer.L2Regularization(rate=8e-4))
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer) update_equation=optimizer)
# define data reader # define data reader
feeding = {
'source_language_word': 0,
'target_language_word': 1,
'target_language_next_word': 2
}
wmt14_reader = paddle.batch( wmt14_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192), paddle.dataset.wmt14.train(dict_size), buf_size=8192),
batch_size=5) batch_size=5)
# define event_handler callback # define event_handler callback
...@@ -160,17 +155,28 @@ def main(): ...@@ -160,17 +155,28 @@ def main():
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0: if event.batch_id % 10 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % ( print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost,
event.metrics)
else: else:
sys.stdout.write('.') sys.stdout.write('.')
sys.stdout.flush() sys.stdout.flush()
# start to train # start to train
trainer.train( trainer.train(
reader=wmt14_reader, reader=wmt14_reader, event_handler=event_handler, num_passes=2)
event_handler=event_handler,
num_passes=10000, # generate a english sequence to french
feeding=feeding) else:
gen_creator = paddle.dataset.wmt14.test(dict_size)
gen_data = []
for item in gen_creator():
gen_data.append((item[0], ))
if len(gen_data) == 3:
break
beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
parameters = paddle.dataset.wmt14.model()
trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -29,7 +29,7 @@ URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz ...@@ -29,7 +29,7 @@ URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
# this is the pretrained model, whose bleu = 26.92 # this is the pretrained model, whose bleu = 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '6b097d23e15654608c6f74923e975535' MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
START = "<s>" START = "<s>"
END = "<e>" END = "<e>"
...@@ -115,6 +115,12 @@ def model(): ...@@ -115,6 +115,12 @@ def model():
return parameters return parameters
def trg_dict(dict_size):
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
return trg_dict
def fetch(): def fetch():
download(URL_TRAIN, 'wmt14', MD5_TRAIN) download(URL_TRAIN, 'wmt14', MD5_TRAIN)
download(URL_MODEL, 'wmt14', MD5_MODEL) download(URL_MODEL, 'wmt14', MD5_MODEL)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册