提交 9ab860e6 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #1750 from luotao1/wmt14

add wmt14 pretrained model
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
wmt14 dataset wmt14 dataset
""" """
import tarfile import tarfile
import gzip
from paddle.v2.dataset.common import download from paddle.v2.dataset.common import download
from paddle.v2.parameters import Parameters
__all__ = ['train', 'test', 'build_dict'] __all__ = ['train', 'test', 'build_dict']
...@@ -25,6 +27,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' ...@@ -25,6 +27,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be added later. # this is a small set of data for test. The original data is too large and will be added later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
# this is the pretrained model, whose bleu = 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '6b097d23e15654608c6f74923e975535'
START = "<s>" START = "<s>"
END = "<e>" END = "<e>"
...@@ -103,5 +108,13 @@ def test(dict_size): ...@@ -103,5 +108,13 @@ def test(dict_size):
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size) download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
def model():
    """
    Load the pretrained wmt14 model parameters.

    Fetches the archive at URL_MODEL through ``download`` (MD5_MODEL is
    passed along with it, presumably as the expected checksum -- confirm
    against ``download``), opens the result with gzip and hands the
    opened stream to ``Parameters.from_tar``.

    :return: whatever ``Parameters.from_tar`` builds from the archive.
    """
    model_path = download(URL_MODEL, 'wmt14', MD5_MODEL)
    # The gzip handle is closed by the ``with`` block once from_tar returns.
    with gzip.open(model_path, 'r') as archive:
        return Parameters.from_tar(archive)
def fetch():
    """
    Trigger the downloads for the wmt14 resources used by this module.

    Calls ``download`` once for the training archive and once for the
    pretrained model archive, in that order, both under the 'wmt14'
    module name. Nothing is returned.
    """
    resources = (
        (URL_TRAIN, MD5_TRAIN),
        (URL_MODEL, MD5_MODEL),
    )
    for url, md5sum in resources:
        download(url, 'wmt14', md5sum)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册