diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index ee63a93f5ad918b5bbc949ae6ba29082b3f6abd5..29e45bb124ba58252823ea13aab21ecab4ea7400 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -15,8 +15,10 @@ wmt14 dataset """ import tarfile +import gzip from paddle.v2.dataset.common import download +from paddle.v2.parameters import Parameters __all__ = ['train', 'test', 'build_dict'] @@ -25,6 +27,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' # this is a small set of data for test. The original data is too large and will be add later. URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' +# this is the pretrained model, whose bleu = 26.92 +URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz' +MD5_MODEL = '6b097d23e15654608c6f74923e975535' START = "" END = "" @@ -103,5 +108,13 @@ def test(dict_size): download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size) +def model(): + tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL) + with gzip.open(tar_file, 'r') as f: + parameters = Parameters.from_tar(f) + return parameters + + def fetch(): download(URL_TRAIN, 'wmt14', MD5_TRAIN) + download(URL_MODEL, 'wmt14', MD5_MODEL)