提交 fbea3918 编写于 作者: Tao Luo 提交者: GitHub

Merge pull request #1774 from luotao1/wmt14

add wmt14 trg_dict
...@@ -126,51 +126,57 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False): ...@@ -126,51 +126,57 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=False, trainer_count=1)
is_generating = True
# source and target dict dim. # source and target dict dim.
dict_size = 30000 dict_size = 30000
source_dict_dim = target_dict_dim = dict_size source_dict_dim = target_dict_dim = dict_size
# define network topology # train the network
cost = seqToseq_net(source_dict_dim, target_dict_dim) if not is_generating:
parameters = paddle.parameters.create(cost) cost = seqToseq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost)
# define optimize method and trainer
optimizer = paddle.optimizer.Adam( # define optimize method and trainer
learning_rate=5e-5, optimizer = paddle.optimizer.Adam(
regularization=paddle.optimizer.L2Regularization(rate=1e-3)) learning_rate=5e-5,
trainer = paddle.trainer.SGD(cost=cost, regularization=paddle.optimizer.L2Regularization(rate=8e-4))
parameters=parameters, trainer = paddle.trainer.SGD(cost=cost,
update_equation=optimizer) parameters=parameters,
update_equation=optimizer)
# define data reader # define data reader
feeding = { wmt14_reader = paddle.batch(
'source_language_word': 0, paddle.reader.shuffle(
'target_language_word': 1, paddle.dataset.wmt14.train(dict_size), buf_size=8192),
'target_language_next_word': 2 batch_size=5)
}
# define event_handler callback
wmt14_reader = paddle.batch( def event_handler(event):
paddle.reader.shuffle( if isinstance(event, paddle.event.EndIteration):
paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192), if event.batch_id % 10 == 0:
batch_size=5) print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost,
# define event_handler callback event.metrics)
def event_handler(event): else:
if isinstance(event, paddle.event.EndIteration): sys.stdout.write('.')
if event.batch_id % 10 == 0: sys.stdout.flush()
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics) # start to train
else: trainer.train(
sys.stdout.write('.') reader=wmt14_reader, event_handler=event_handler, num_passes=2)
sys.stdout.flush()
# generate a French translation of an English sequence
# start to train else:
trainer.train( gen_creator = paddle.dataset.wmt14.test(dict_size)
reader=wmt14_reader, gen_data = []
event_handler=event_handler, for item in gen_creator():
num_passes=10000, gen_data.append((item[0], ))
feeding=feeding) if len(gen_data) == 3:
break
beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
parameters = paddle.dataset.wmt14.model()
trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -29,7 +29,7 @@ URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz ...@@ -29,7 +29,7 @@ URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
# this is the pretrained model, whose bleu = 26.92 # this is the pretrained model, whose bleu = 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '6b097d23e15654608c6f74923e975535' MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
START = "<s>" START = "<s>"
END = "<e>" END = "<e>"
...@@ -115,6 +115,12 @@ def model(): ...@@ -115,6 +115,12 @@ def model():
return parameters return parameters
def trg_dict(dict_size):
    """Return the target-side dictionary built from the WMT14 training archive.

    Calls ``download(URL_TRAIN, 'wmt14', MD5_TRAIN)`` and passes its result,
    together with ``dict_size``, to ``__read_to_dict__``. That helper yields
    a two-element (source, target) pair; only the target element is returned.

    :param dict_size: forwarded unchanged to ``__read_to_dict__``
        (presumably the vocabulary size limit -- TODO confirm against
        that helper).
    :return: the second element produced by ``__read_to_dict__``.
    """
    tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
    # The source-side dictionary is not needed here, so it is discarded.
    # A distinct local name also avoids shadowing this function's own name.
    _, target = __read_to_dict__(tar_file, dict_size)
    return target
def fetch(): def fetch():
download(URL_TRAIN, 'wmt14', MD5_TRAIN) download(URL_TRAIN, 'wmt14', MD5_TRAIN)
download(URL_MODEL, 'wmt14', MD5_MODEL) download(URL_MODEL, 'wmt14', MD5_MODEL)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册