提交 14eb5b8e 编写于 作者: Y Yancey1989

rename fetch_all to fetch; add fetch_all function

上级 7b72c792
...@@ -20,7 +20,7 @@ TODO(yuyang18): Complete the comments. ...@@ -20,7 +20,7 @@ TODO(yuyang18): Complete the comments.
import cPickle import cPickle
import itertools import itertools
import numpy import numpy
import paddle.v2.dataset.common from common import download
import tarfile import tarfile
__all__ = ['train100', 'test100', 'train10', 'test10'] __all__ = ['train100', 'test100', 'train10', 'test10']
...@@ -55,28 +55,23 @@ def reader_creator(filename, sub_name): ...@@ -55,28 +55,23 @@ def reader_creator(filename, sub_name):
def train100(): def train100():
return reader_creator( return reader_creator(
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5), download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'train')
'train')
def test100(): def test100():
return reader_creator( return reader_creator(download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'test')
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'test')
def train10(): def train10():
return reader_creator( return reader_creator(
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch')
'data_batch')
def test10(): def test10():
return reader_creator( return reader_creator(
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'test_batch')
'test_batch')
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5) download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5) download(CIFAR100_URL, 'cifar', CIFAR100_MD5)
...@@ -17,6 +17,8 @@ import hashlib ...@@ -17,6 +17,8 @@ import hashlib
import os import os
import shutil import shutil
import sys import sys
import importlib
import paddle.v2.dataset
__all__ = ['DATA_HOME', 'download', 'md5file'] __all__ = ['DATA_HOME', 'download', 'md5file']
...@@ -69,3 +71,13 @@ def dict_add(a_dict, ele): ...@@ -69,3 +71,13 @@ def dict_add(a_dict, ele):
a_dict[ele] += 1 a_dict[ele] += 1
else: else:
a_dict[ele] = 1 a_dict[ele] = 1
def fetch_all():
for module_name in filter(lambda x: not x.startswith("__"),
dir(paddle.v2.dataset)):
if "fetch" in dir(
importlib.import_module("paddle.v2.dataset.%s" % module_name)):
getattr(
importlib.import_module("paddle.v2.dataset.%s" % module_name),
"fetch")()
...@@ -198,9 +198,9 @@ def test(): ...@@ -198,9 +198,9 @@ def test():
return reader_creator(reader, word_dict, verb_dict, label_dict) return reader_creator(reader, word_dict, verb_dict, label_dict)
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(WORDDICT_URL, 'conll05st', WORDDICT_MD5) download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)
paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st', VERBDICT_MD5) download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)
paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st', TRGDICT_MD5) download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)
paddle.v2.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5) download(EMB_URL, 'conll05st', EMB_MD5)
paddle.v2.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5) download(DATA_URL, 'conll05st', DATA_MD5)
...@@ -125,5 +125,5 @@ def word_dict(): ...@@ -125,5 +125,5 @@ def word_dict():
re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150) re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150)
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(URL, 'imdb', MD5) paddle.v2.dataset.common.download(URL, 'imdb', MD5)
...@@ -91,5 +91,5 @@ def test(word_idx, n): ...@@ -91,5 +91,5 @@ def test(word_idx, n):
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n)
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(URL, "imikolov", MD5) paddle.v2.dataset.common.download(URL, "imikolov", MD5)
...@@ -108,6 +108,8 @@ def test(): ...@@ -108,6 +108,8 @@ def test():
TEST_LABEL_MD5), 100) TEST_LABEL_MD5), 100)
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5) paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5) paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
...@@ -205,8 +205,8 @@ def unittest(): ...@@ -205,8 +205,8 @@ def unittest():
print train_count, test_count print train_count, test_count
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(URL, "movielens", MD5) download(URL, "movielens", MD5)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -26,7 +26,7 @@ from itertools import chain ...@@ -26,7 +26,7 @@ from itertools import chain
import nltk import nltk
from nltk.corpus import movie_reviews from nltk.corpus import movie_reviews
import paddle.v2.dataset.common import common
__all__ = ['train', 'test', 'get_word_dict'] __all__ = ['train', 'test', 'get_word_dict']
NUM_TRAINING_INSTANCES = 1600 NUM_TRAINING_INSTANCES = 1600
...@@ -127,5 +127,5 @@ def test(): ...@@ -127,5 +127,5 @@ def test():
return reader_creator(data_set[NUM_TRAINING_INSTANCES:]) return reader_creator(data_set[NUM_TRAINING_INSTANCES:])
def fetch_data(): def fetch():
nltk.download('movie_reviews', download_dir=common.DATA_HOME) nltk.download('movie_reviews', download_dir=common.DATA_HOME)
...@@ -91,5 +91,5 @@ def test(): ...@@ -91,5 +91,5 @@ def test():
return reader return reader
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(URL, 'uci_housing', MD5) download(URL, 'uci_housing', MD5)
...@@ -16,7 +16,7 @@ wmt14 dataset ...@@ -16,7 +16,7 @@ wmt14 dataset
""" """
import tarfile import tarfile
import paddle.v2.dataset.common from paddle.v2.dataset.common import download
__all__ = ['train', 'test', 'build_dict'] __all__ = ['train', 'test', 'build_dict']
...@@ -95,15 +95,13 @@ def reader_creator(tar_file, file_name, dict_size): ...@@ -95,15 +95,13 @@ def reader_creator(tar_file, file_name, dict_size):
def train(dict_size): def train(dict_size):
return reader_creator( return reader_creator(
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN), download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'train/train', dict_size)
'train/train', dict_size)
def test(dict_size): def test(dict_size):
return reader_creator( return reader_creator(
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN), download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
'test/test', dict_size)
def fetch_data(): def fetch():
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) download(URL_TRAIN, 'wmt14', MD5_TRAIN)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册