提交 a42233c2 编写于 作者: L Luo Tao

add wmt14 trg_dict

上级 caffcc83
......@@ -126,33 +126,28 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
def main():
paddle.init(use_gpu=False, trainer_count=1)
is_generating = True
# source and target dict dim.
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
# define network topology
# train the network
if not is_generating:
cost = seqToseq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost)
# define optimize method and trainer
optimizer = paddle.optimizer.Adam(
learning_rate=5e-5,
regularization=paddle.optimizer.L2Regularization(rate=1e-3))
regularization=paddle.optimizer.L2Regularization(rate=8e-4))
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
# define data reader
feeding = {
'source_language_word': 0,
'target_language_word': 1,
'target_language_next_word': 2
}
wmt14_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192),
paddle.dataset.wmt14.train(dict_size), buf_size=8192),
batch_size=5)
# define event_handler callback
......@@ -160,17 +155,28 @@ def main():
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
event.pass_id, event.batch_id, event.cost,
event.metrics)
else:
sys.stdout.write('.')
sys.stdout.flush()
# start to train
trainer.train(
reader=wmt14_reader,
event_handler=event_handler,
num_passes=10000,
feeding=feeding)
reader=wmt14_reader, event_handler=event_handler, num_passes=2)
# generate a english sequence to french
else:
gen_creator = paddle.dataset.wmt14.test(dict_size)
gen_data = []
for item in gen_creator():
gen_data.append((item[0], ))
if len(gen_data) == 3:
break
beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating)
parameters = paddle.dataset.wmt14.model()
trg_dict = paddle.dataset.wmt14.trg_dict(dict_size)
if __name__ == '__main__':
......
......@@ -29,7 +29,7 @@ URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
# this is the pretrained model, whose bleu = 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '6b097d23e15654608c6f74923e975535'
MD5_MODEL = '4ce14a26607fb8a1cc23bcdedb1895e4'
START = "<s>"
END = "<e>"
......@@ -115,6 +115,12 @@ def model():
return parameters
def trg_dict(dict_size):
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
return trg_dict
def fetch():
download(URL_TRAIN, 'wmt14', MD5_TRAIN)
download(URL_MODEL, 'wmt14', MD5_MODEL)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册