From a42233c2c21ab8afde85cdf6bf174ae66f93dbd3 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 12 Apr 2017 12:49:32 +0800 Subject: [PATCH] add wmt14 trg_dict --- demo/seqToseq/api_train_v2.py | 86 +++++++++++++++++-------------- python/paddle/v2/dataset/wmt14.py | 8 ++- 2 files changed, 53 insertions(+), 41 deletions(-) diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index 2809054e7d3..ac2665b5b35 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -126,51 +126,57 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): def main(): paddle.init(use_gpu=False, trainer_count=1) + is_generating = True # source and target dict dim. dict_size = 30000 source_dict_dim = target_dict_dim = dict_size - # define network topology - cost = seqToseq_net(source_dict_dim, target_dict_dim) - parameters = paddle.parameters.create(cost) - - # define optimize method and trainer - optimizer = paddle.optimizer.Adam( - learning_rate=5e-5, - regularization=paddle.optimizer.L2Regularization(rate=1e-3)) - trainer = paddle.trainer.SGD(cost=cost, - parameters=parameters, - update_equation=optimizer) - - # define data reader - feeding = { - 'source_language_word': 0, - 'target_language_word': 1, - 'target_language_next_word': 2 - } - - wmt14_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192), - batch_size=5) - - # define event_handler callback - def event_handler(event): - if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 10 == 0: - print "\nPass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) - else: - sys.stdout.write('.') - sys.stdout.flush() - - # start to train - trainer.train( - reader=wmt14_reader, - event_handler=event_handler, - num_passes=10000, - feeding=feeding) + # train the network + if not is_generating: + cost = seqToseq_net(source_dict_dim, target_dict_dim) + parameters = paddle.parameters.create(cost) + + # define optimize method and trainer + optimizer = paddle.optimizer.Adam( + learning_rate=5e-5, + regularization=paddle.optimizer.L2Regularization(rate=8e-4)) + trainer = paddle.trainer.SGD(cost=cost, + parameters=parameters, + update_equation=optimizer) + # define data reader + wmt14_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.wmt14.train(dict_size), buf_size=8192), + batch_size=5) + + # define event_handler callback + def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % 10 == 0: + print "\nPass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, + event.metrics) + else: + sys.stdout.write('.') + sys.stdout.flush() + + # start to train + trainer.train( + reader=wmt14_reader, event_handler=event_handler, num_passes=2) + + # generate a english sequence to french + else: + gen_creator = paddle.dataset.wmt14.test(dict_size) + gen_data = [] + for item in gen_creator(): + gen_data.append((item[0], )) + if len(gen_data) == 3: + break + + beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating) + parameters = paddle.dataset.wmt14.model() + trg_dict = paddle.dataset.wmt14.trg_dict(dict_size) if __name__ == '__main__': diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index 29e45bb124b..ad9f5f18d67 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -29,7 +29,7 @@ URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' # this is the pretrained model, whose bleu = 26.92 URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' -MD5_MODEL = '6b097d23e15654608c6f74923e975535' +MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4' START = "" END = "" @@ -115,6 +115,12 @@ def model(): return parameters +def trg_dict(dict_size): + tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN) + src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) + return trg_dict + + def fetch(): download(URL_TRAIN, 'wmt14', MD5_TRAIN) download(URL_MODEL, 'wmt14', MD5_MODEL) -- GitLab