提交 cd161926 编写于 作者: Y Yi Wang

Merge branch 'develop' of https://github.com/paddlepaddle/paddle into memory_cpu_allocator

...@@ -166,11 +166,21 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, ...@@ -166,11 +166,21 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
outArgStream_ = HPPL_STREAM_1; outArgStream_ = HPPL_STREAM_1;
start();
}
void MultiGradientMachine::start() {
for (auto& thread : threads_) { for (auto& thread : threads_) {
thread->start(); thread->start();
} }
} }
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
std::vector<const std::vector<ParameterPtr>*> std::vector<const std::vector<ParameterPtr>*>
MultiGradientMachine::getSlaveParameters() { MultiGradientMachine::getSlaveParameters() {
std::vector<const std::vector<ParameterPtr>*> vec; std::vector<const std::vector<ParameterPtr>*> vec;
...@@ -326,12 +336,6 @@ void MultiGradientMachine::onPassEnd() { ...@@ -326,12 +336,6 @@ void MultiGradientMachine::onPassEnd() {
} }
} }
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
Evaluator* MultiGradientMachine::makeEvaluator() const { Evaluator* MultiGradientMachine::makeEvaluator() const {
return threads_[0]->getGradientMachine()->makeEvaluator(); return threads_[0]->getGradientMachine()->makeEvaluator();
} }
...@@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config, ...@@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
gradStream_ = HPPL_STREAM_2; gradStream_ = HPPL_STREAM_2;
valueStream_ = HPPL_STREAM_3; valueStream_ = HPPL_STREAM_3;
stopping_ = false; stopping_ = true;
updateCounter_ = 0; updateCounter_ = 0;
parameterUpdated_ = false; parameterUpdated_ = false;
} }
...@@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config, ...@@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config,
TrainerThread::~TrainerThread() { stop(); } TrainerThread::~TrainerThread() { stop(); }
void TrainerThread::start() { void TrainerThread::start() {
if (!stopping_) return;
stopping_ = false;
gradientMachine_->start(); gradientMachine_->start();
computeThread_.reset(new std::thread([this]() { computeThread(); })); computeThread_.reset(new std::thread([this]() { computeThread(); }));
......
...@@ -176,6 +176,10 @@ public: ...@@ -176,6 +176,10 @@ public:
explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); explicit MultiGradientMachine(const ModelConfig& config, bool useGpu);
virtual void start();
virtual void finish();
virtual void prefetch(const std::vector<Argument>& inArgs); virtual void prefetch(const std::vector<Argument>& inArgs);
virtual void forward(const std::vector<Argument>& inArgs, virtual void forward(const std::vector<Argument>& inArgs,
...@@ -193,8 +197,6 @@ public: ...@@ -193,8 +197,6 @@ public:
virtual void onPassEnd(); virtual void onPassEnd();
virtual void finish();
virtual Evaluator* makeEvaluator() const; virtual Evaluator* makeEvaluator() const;
virtual void eval(Evaluator* evaluator) const; virtual void eval(Evaluator* evaluator) const;
......
...@@ -31,10 +31,10 @@ images per class. ...@@ -31,10 +31,10 @@ images per class.
import cPickle import cPickle
import itertools import itertools
import numpy import numpy
from common import download import paddle.v2.dataset.common
import tarfile import tarfile
__all__ = ['train100', 'test100', 'train10', 'test10'] __all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/' URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz' CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
...@@ -75,7 +75,8 @@ def train100(): ...@@ -75,7 +75,8 @@ def train100():
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(
download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'train') paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'train')
def test100(): def test100():
...@@ -88,7 +89,9 @@ def test100(): ...@@ -88,7 +89,9 @@ def test100():
:return: Test reader creator. :return: Test reader creator.
:rtype: callable :rtype: callable
""" """
return reader_creator(download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'test') return reader_creator(
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'test')
def train10(): def train10():
...@@ -102,7 +105,8 @@ def train10(): ...@@ -102,7 +105,8 @@ def train10():
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(
download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch') paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch')
def test10(): def test10():
...@@ -116,9 +120,20 @@ def test10(): ...@@ -116,9 +120,20 @@ def test10():
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(
download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'test_batch') paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch')
def fetch(): def fetch():
download(CIFAR10_URL, 'cifar', CIFAR10_MD5) paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
download(CIFAR100_URL, 'cifar', CIFAR100_MD5) paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.v2.dataset.common.convert(path, train100(), 10, "cifar_train100")
paddle.v2.dataset.common.convert(path, test100(), 10, "cifar_test100")
paddle.v2.dataset.common.convert(path, train10(), 10, "cifar_train10")
paddle.v2.dataset.common.convert(path, test10(), 10, "cifar_test10")
...@@ -23,7 +23,10 @@ import paddle.v2.dataset ...@@ -23,7 +23,10 @@ import paddle.v2.dataset
import cPickle import cPickle
import glob import glob
__all__ = ['DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader'] __all__ = [
'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
'convert'
]
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
......
...@@ -23,9 +23,9 @@ to initialize SRL model. ...@@ -23,9 +23,9 @@ to initialize SRL model.
import tarfile import tarfile
import gzip import gzip
import itertools import itertools
from common import download import paddle.v2.dataset.common
__all__ = ['test, get_dict', 'get_embedding'] __all__ = ['test, get_dict', 'get_embedding', 'convert']
DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz' DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc' DATA_MD5 = '387719152ae52d60422c016e92a742fc'
...@@ -182,9 +182,15 @@ def get_dict(): ...@@ -182,9 +182,15 @@ def get_dict():
""" """
Get the word, verb and label dictionary of Wikipedia corpus. Get the word, verb and label dictionary of Wikipedia corpus.
""" """
word_dict = load_dict(download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)) word_dict = load_dict(
verb_dict = load_dict(download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)) paddle.v2.dataset.common.download(WORDDICT_URL, 'conll05st',
label_dict = load_dict(download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)) WORDDICT_MD5))
verb_dict = load_dict(
paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st',
VERBDICT_MD5))
label_dict = load_dict(
paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st',
TRGDICT_MD5))
return word_dict, verb_dict, label_dict return word_dict, verb_dict, label_dict
...@@ -192,7 +198,7 @@ def get_embedding(): ...@@ -192,7 +198,7 @@ def get_embedding():
""" """
Get the trained word vector based on Wikipedia corpus. Get the trained word vector based on Wikipedia corpus.
""" """
return download(EMB_URL, 'conll05st', EMB_MD5) return paddle.v2.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)
def test(): def test():
...@@ -209,15 +215,23 @@ def test(): ...@@ -209,15 +215,23 @@ def test():
""" """
word_dict, verb_dict, label_dict = get_dict() word_dict, verb_dict, label_dict = get_dict()
reader = corpus_reader( reader = corpus_reader(
download(DATA_URL, 'conll05st', DATA_MD5), paddle.v2.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5),
words_name='conll05st-release/test.wsj/words/test.wsj.words.gz', words_name='conll05st-release/test.wsj/words/test.wsj.words.gz',
props_name='conll05st-release/test.wsj/props/test.wsj.props.gz') props_name='conll05st-release/test.wsj/props/test.wsj.props.gz')
return reader_creator(reader, word_dict, verb_dict, label_dict) return reader_creator(reader, word_dict, verb_dict, label_dict)
def fetch(): def fetch():
download(WORDDICT_URL, 'conll05st', WORDDICT_MD5) paddle.v2.dataset.common.download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)
download(VERBDICT_URL, 'conll05st', VERBDICT_MD5) paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)
download(TRGDICT_URL, 'conll05st', TRGDICT_MD5) paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)
download(EMB_URL, 'conll05st', EMB_MD5) paddle.v2.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)
download(DATA_URL, 'conll05st', DATA_MD5) paddle.v2.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.v2.dataset.common.convert(path, test(), 10, "conl105_train")
paddle.v2.dataset.common.convert(path, test(), 10, "conl105_test")
...@@ -28,7 +28,7 @@ import re ...@@ -28,7 +28,7 @@ import re
import string import string
import threading import threading
__all__ = ['build_dict', 'train', 'test'] __all__ = ['build_dict', 'train', 'test', 'convert']
URL = 'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz' URL = 'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a' MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
...@@ -166,3 +166,12 @@ def word_dict(): ...@@ -166,3 +166,12 @@ def word_dict():
def fetch(): def fetch():
paddle.v2.dataset.common.download(URL, 'imdb', MD5) paddle.v2.dataset.common.download(URL, 'imdb', MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
w = word_dict()
paddle.v2.dataset.common.convert(path, lambda: train(w), 10, "imdb_train")
paddle.v2.dataset.common.convert(path, lambda: test(w), 10, "imdb_test")
...@@ -22,7 +22,7 @@ import paddle.v2.dataset.common ...@@ -22,7 +22,7 @@ import paddle.v2.dataset.common
import collections import collections
import tarfile import tarfile
__all__ = ['train', 'test', 'build_dict'] __all__ = ['train', 'test', 'build_dict', 'convert']
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d' MD5 = '30177ea32e27c525793142b6bf2c8e2d'
...@@ -146,3 +146,15 @@ def test(word_idx, n, data_type=DataType.NGRAM): ...@@ -146,3 +146,15 @@ def test(word_idx, n, data_type=DataType.NGRAM):
def fetch(): def fetch():
paddle.v2.dataset.common.download(URL, "imikolov", MD5) paddle.v2.dataset.common.download(URL, "imikolov", MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
N = 5
word_dict = build_dict()
paddle.v2.dataset.common.convert(path,
train(word_dict, N), 10, "imikolov_train")
paddle.v2.dataset.common.convert(path,
test(word_dict, N), 10, "imikolov_test")
...@@ -21,7 +21,7 @@ import paddle.v2.dataset.common ...@@ -21,7 +21,7 @@ import paddle.v2.dataset.common
import subprocess import subprocess
import numpy import numpy
import platform import platform
__all__ = ['train', 'test'] __all__ = ['train', 'test', 'convert']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
...@@ -113,3 +113,11 @@ def fetch(): ...@@ -113,3 +113,11 @@ def fetch():
paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5) paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5) paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5) paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.v2.dataset.common.convert(path, train(), 10, "minist_train")
paddle.v2.dataset.common.convert(path, test(), 10, "minist_test")
...@@ -23,14 +23,15 @@ set and test set into paddle reader creators. ...@@ -23,14 +23,15 @@ set and test set into paddle reader creators.
""" """
import zipfile import zipfile
from common import download import paddle.v2.dataset.common
import re import re
import random import random
import functools import functools
__all__ = [ __all__ = [
'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id', 'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id',
'age_table', 'movie_categories', 'max_job_id', 'user_info', 'movie_info' 'age_table', 'movie_categories', 'max_job_id', 'user_info', 'movie_info',
'convert'
] ]
age_table = [1, 18, 25, 35, 45, 50, 56] age_table = [1, 18, 25, 35, 45, 50, 56]
...@@ -99,7 +100,7 @@ USER_INFO = None ...@@ -99,7 +100,7 @@ USER_INFO = None
def __initialize_meta_info__(): def __initialize_meta_info__():
fn = download(URL, "movielens", MD5) fn = paddle.v2.dataset.common.download(URL, "movielens", MD5)
global MOVIE_INFO global MOVIE_INFO
if MOVIE_INFO is None: if MOVIE_INFO is None:
pattern = re.compile(r'^(.*)\((\d+)\)$') pattern = re.compile(r'^(.*)\((\d+)\)$')
...@@ -246,7 +247,15 @@ def unittest(): ...@@ -246,7 +247,15 @@ def unittest():
def fetch(): def fetch():
download(URL, "movielens", MD5) paddle.v2.dataset.common.download(URL, "movielens", MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.v2.dataset.common.convert(path, train(), 10, "movielens_train")
paddle.v2.dataset.common.convert(path, test(), 10, "movielens_test")
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -26,9 +26,9 @@ from itertools import chain ...@@ -26,9 +26,9 @@ from itertools import chain
import nltk import nltk
from nltk.corpus import movie_reviews from nltk.corpus import movie_reviews
import common import paddle.v2.dataset.common
__all__ = ['train', 'test', 'get_word_dict'] __all__ = ['train', 'test', 'get_word_dict', 'convert']
NUM_TRAINING_INSTANCES = 1600 NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000 NUM_TOTAL_INSTANCES = 2000
...@@ -39,12 +39,13 @@ def download_data_if_not_yet(): ...@@ -39,12 +39,13 @@ def download_data_if_not_yet():
""" """
try: try:
# make sure that nltk can find the data # make sure that nltk can find the data
if common.DATA_HOME not in nltk.data.path: if paddle.v2.dataset.common.DATA_HOME not in nltk.data.path:
nltk.data.path.append(common.DATA_HOME) nltk.data.path.append(paddle.v2.dataset.common.DATA_HOME)
movie_reviews.categories() movie_reviews.categories()
except LookupError: except LookupError:
print "Downloading movie_reviews data set, please wait....." print "Downloading movie_reviews data set, please wait....."
nltk.download('movie_reviews', download_dir=common.DATA_HOME) nltk.download(
'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
print "Download data set success....." print "Download data set success....."
print "Path is " + nltk.data.find('corpora/movie_reviews').path print "Path is " + nltk.data.find('corpora/movie_reviews').path
...@@ -128,4 +129,13 @@ def test(): ...@@ -128,4 +129,13 @@ def test():
def fetch(): def fetch():
nltk.download('movie_reviews', download_dir=common.DATA_HOME) nltk.download(
'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.v2.dataset.common.convert(path, train, 10, "sentiment_train")
paddle.v2.dataset.common.convert(path, test, 10, "sentiment_test")
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
""" """
UCI Housing dataset. UCI Housing dataset.
This module will download dataset from This module will paddle.v2.dataset.common.download dataset from
https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and
parse training set and test set into paddle reader creators. parse training set and test set into paddle reader creators.
""" """
import numpy as np import numpy as np
import os import os
from common import download import paddle.v2.dataset.common
__all__ = ['train', 'test'] __all__ = ['train', 'test']
...@@ -29,7 +29,7 @@ URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing ...@@ -29,7 +29,7 @@ URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing
MD5 = 'd4accdce7a25600298819f8e28e8d593' MD5 = 'd4accdce7a25600298819f8e28e8d593'
feature_names = [ feature_names = [
'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
'PTRATIO', 'B', 'LSTAT' 'PTRATIO', 'B', 'LSTAT', 'convert'
] ]
UCI_TRAIN_DATA = None UCI_TRAIN_DATA = None
...@@ -82,7 +82,7 @@ def train(): ...@@ -82,7 +82,7 @@ def train():
:rtype: callable :rtype: callable
""" """
global UCI_TRAIN_DATA global UCI_TRAIN_DATA
load_data(download(URL, 'uci_housing', MD5)) load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))
def reader(): def reader():
for d in UCI_TRAIN_DATA: for d in UCI_TRAIN_DATA:
...@@ -102,7 +102,7 @@ def test(): ...@@ -102,7 +102,7 @@ def test():
:rtype: callable :rtype: callable
""" """
global UCI_TEST_DATA global UCI_TEST_DATA
load_data(download(URL, 'uci_housing', MD5)) load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))
def reader(): def reader():
for d in UCI_TEST_DATA: for d in UCI_TEST_DATA:
...@@ -112,4 +112,12 @@ def test(): ...@@ -112,4 +112,12 @@ def test():
def fetch(): def fetch():
download(URL, 'uci_housing', MD5) paddle.v2.dataset.common.download(URL, 'uci_housing', MD5)
def convert(path):
"""
Converts dataset to recordio format
"""
paddle.v2.dataset.common.convert(path, train(), 10, "uci_housing_train")
paddle.v2.dataset.common.convert(path, test(), 10, "uci_houseing_test")
...@@ -22,10 +22,10 @@ parse training set and test set into paddle reader creators. ...@@ -22,10 +22,10 @@ parse training set and test set into paddle reader creators.
import tarfile import tarfile
import gzip import gzip
from paddle.v2.dataset.common import download import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters from paddle.v2.parameters import Parameters
__all__ = ['train', 'test', 'build_dict'] __all__ = ['train', 'test', 'build_dict', 'convert']
URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
...@@ -115,7 +115,8 @@ def train(dict_size): ...@@ -115,7 +115,8 @@ def train(dict_size):
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'train/train', dict_size) paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
'train/train', dict_size)
def test(dict_size): def test(dict_size):
...@@ -130,16 +131,18 @@ def test(dict_size): ...@@ -130,16 +131,18 @@ def test(dict_size):
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size) paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
'test/test', dict_size)
def gen(dict_size): def gen(dict_size):
return reader_creator( return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'gen/gen', dict_size) paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
'gen/gen', dict_size)
def model(): def model():
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL) tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'wmt14', MD5_MODEL)
with gzip.open(tar_file, 'r') as f: with gzip.open(tar_file, 'r') as f:
parameters = Parameters.from_tar(f) parameters = Parameters.from_tar(f)
return parameters return parameters
...@@ -148,7 +151,7 @@ def model(): ...@@ -148,7 +151,7 @@ def model():
def get_dict(dict_size, reverse=True): def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...} # if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...} # else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN) tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
if reverse: if reverse:
src_dict = {v: k for k, v in src_dict.items()} src_dict = {v: k for k, v in src_dict.items()}
...@@ -157,5 +160,14 @@ def get_dict(dict_size, reverse=True): ...@@ -157,5 +160,14 @@ def get_dict(dict_size, reverse=True):
def fetch(): def fetch():
download(URL_TRAIN, 'wmt14', MD5_TRAIN) paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
download(URL_MODEL, 'wmt14', MD5_MODEL) paddle.v2.dataset.common.download(URL_MODEL, 'wmt14', MD5_MODEL)
def convert(path):
"""
Converts dataset to recordio format
"""
dict_size = 30000
paddle.v2.dataset.common.convert(path, train(dict_size), 10, "wmt14_train")
paddle.v2.dataset.common.convert(path, test(dict_size), 10, "wmt14_test")
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import unittest import unittest
import numpy as np import numpy as np
import paddle.v2.reader.creator import paddle.v2.reader.creator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册