diff --git a/python/paddle/v2/dataset/cifar.py b/python/paddle/v2/dataset/cifar.py index 9a999de7e02aad1c0d09d74e1e650541fd430920..2ac71c6effe9d5f1140d1f574db9c9848b56433a 100644 --- a/python/paddle/v2/dataset/cifar.py +++ b/python/paddle/v2/dataset/cifar.py @@ -5,16 +5,14 @@ URL: https://www.cs.toronto.edu/~kriz/cifar.html the default train_creator, test_creator used for CIFAR-10 dataset. """ -from config import DATA_HOME -import os -import hashlib -import urllib2 -import shutil -import tarfile import cPickle import itertools +import tarfile + import numpy +from config import download + __all__ = [ 'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator', 'test_creator' @@ -47,31 +45,6 @@ def __read_batch__(filename, sub_name): return reader -def download(url, md5): - filename = os.path.split(url)[-1] - assert DATA_HOME is not None - filepath = os.path.join(DATA_HOME, md5) - if not os.path.exists(filepath): - os.makedirs(filepath) - __full_file__ = os.path.join(filepath, filename) - - def __file_ok__(): - if not os.path.exists(__full_file__): - return False - md5_hash = hashlib.md5() - with open(__full_file__, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): - md5_hash.update(chunk) - - return md5_hash.hexdigest() == md5 - - while not __file_ok__(): - response = urllib2.urlopen(url) - with open(__full_file__, mode='wb') as of: - shutil.copyfileobj(fsrc=response, fdst=of) - return __full_file__ - - def cifar_100_train_creator(): fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) return __read_batch__(fn, 'train') diff --git a/python/paddle/v2/dataset/config.py b/python/paddle/v2/dataset/config.py index 69e96d65ef1ef868aff5d46ddf3af250ca11e641..02a009f09c71ccf6a5292a188565adeeb3f875f6 100644 --- a/python/paddle/v2/dataset/config.py +++ b/python/paddle/v2/dataset/config.py @@ -1,8 +1,36 @@ +import hashlib import os +import shutil +import urllib2 -__all__ = ['DATA_HOME'] +__all__ = ['DATA_HOME', 'download'] DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set') if not os.path.exists(DATA_HOME): os.makedirs(DATA_HOME) + + +def download(url, md5): + filename = os.path.split(url)[-1] + assert DATA_HOME is not None + filepath = os.path.join(DATA_HOME, md5) + if not os.path.exists(filepath): + os.makedirs(filepath) + __full_file__ = os.path.join(filepath, filename) + + def __file_ok__(): + if not os.path.exists(__full_file__): + return False + md5_hash = hashlib.md5() + with open(__full_file__, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5_hash.update(chunk) + + return md5_hash.hexdigest() == md5 + + while not __file_ok__(): + response = urllib2.urlopen(url) + with open(__full_file__, mode='wb') as of: + shutil.copyfileobj(fsrc=response, fdst=of) + return __full_file__ diff --git a/python/paddle/v2/dataset/movielens.py b/python/paddle/v2/dataset/movielens.py new file mode 100644 index 0000000000000000000000000000000000000000..314329e91cadf8a74466ed9f385cd596c0ba6f9f --- /dev/null +++ b/python/paddle/v2/dataset/movielens.py @@ -0,0 +1,120 @@ +import zipfile +from config import download +import re +import random +import functools + +__all__ = ['train_creator', 'test_creator'] + + +class MovieInfo(object): + def __init__(self, index, categories, title): + self.index = int(index) + self.categories = categories + self.title = title + + def value(self): + return [ + self.index, [CATEGORIES_DICT[c] for c in self.categories], + [MOVIE_TITLE_DICT[w.lower()] for w in self.title.split()] + ] + + +class UserInfo(object): + def __init__(self, index, gender, age, job_id): + self.index = int(index) + self.is_male = gender == 'M' + self.age = [1, 18, 25, 35, 45, 50, 56].index(int(age)) + self.job_id = int(job_id) + + def value(self): + return [self.index, 0 if self.is_male else 1, self.age, self.job_id] + + +MOVIE_INFO = None +MOVIE_TITLE_DICT = None +CATEGORIES_DICT = None +USER_INFO = None + + +def __initialize_meta_info__(): + fn = download( + url='http://files.grouplens.org/datasets/movielens/ml-1m.zip', + md5='c4d9eecfca2ab87c1945afe126590906') + global MOVIE_INFO + if MOVIE_INFO is None: + pattern = re.compile(r'^(.*)\((\d+)\)$') + with zipfile.ZipFile(file=fn) as package: + for info in package.infolist(): + assert isinstance(info, zipfile.ZipInfo) + MOVIE_INFO = dict() + title_word_set = set() + categories_set = set() + with package.open('ml-1m/movies.dat') as movie_file: + for i, line in enumerate(movie_file): + movie_id, title, categories = line.strip().split('::') + categories = categories.split('|') + for c in categories: + categories_set.add(c) + title = pattern.match(title).group(1) + MOVIE_INFO[int(movie_id)] = MovieInfo( + index=movie_id, categories=categories, title=title) + for w in title.split(): + title_word_set.add(w.lower()) + + global MOVIE_TITLE_DICT + MOVIE_TITLE_DICT = dict() + for i, w in enumerate(title_word_set): + MOVIE_TITLE_DICT[w] = i + + global CATEGORIES_DICT + CATEGORIES_DICT = dict() + for i, c in enumerate(categories_set): + CATEGORIES_DICT[c] = i + + global USER_INFO + USER_INFO = dict() + with package.open('ml-1m/users.dat') as user_file: + for line in user_file: + uid, gender, age, job, _ = line.strip().split("::") + USER_INFO[int(uid)] = UserInfo( + index=uid, gender=gender, age=age, job_id=job) + return fn + + +def __reader__(rand_seed=0, test_ratio=0.1, is_test=False): + fn = __initialize_meta_info__() + rand = random.Random(x=rand_seed) + with zipfile.ZipFile(file=fn) as package: + with package.open('ml-1m/ratings.dat') as rating: + for line in rating: + if (rand.random() < test_ratio) == is_test: + uid, mov_id, rating, _ = line.strip().split("::") + uid = int(uid) + mov_id = int(mov_id) + rating = float(rating) * 2 - 5.0 + + mov = MOVIE_INFO[mov_id] + usr = USER_INFO[uid] + yield usr.value() + mov.value() + [[rating]] + + +def __reader_creator__(**kwargs): + return lambda: __reader__(**kwargs) + + +train_creator = functools.partial(__reader_creator__, is_test=False) +test_creator = functools.partial(__reader_creator__, is_test=True) + + +def unittest(): + for train_count, _ in enumerate(train_creator()()): + pass + for test_count, _ in enumerate(test_creator()()): + pass + + print train_count, test_count + + +if __name__ == '__main__': + unittest()