From 32b168c78ff807ca178bfb016b3d178178b66202 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 28 Feb 2017 12:35:22 +0800 Subject: [PATCH] Refine code --- demo/recommendation/api_train_v2.py | 12 ++++++++++++ python/paddle/v2/dataset/__init__.py | 4 +++- python/paddle/v2/dataset/movielens.py | 10 ++++++++-- 3 files changed, 23 insertions(+), 3 deletions(-) create mode 100644 demo/recommendation/api_train_v2.py diff --git a/demo/recommendation/api_train_v2.py b/demo/recommendation/api_train_v2.py new file mode 100644 index 000000000..64dff9ae4 --- /dev/null +++ b/demo/recommendation/api_train_v2.py @@ -0,0 +1,12 @@ +import paddle.v2 as paddle + + +def main(): + movie_title_dict = paddle.dataset.movielens.get_movie_title_dict() + title_word_count = len(movie_title_dict) + + paddle.layer.mixed + + +if __name__ == '__main__': + main() diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 9647e9850..a947edd2c 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -1,3 +1,5 @@ import mnist +import cifar +import movielens -__all__ = ['mnist'] +__all__ = ['mnist', 'cifar', 'movielens'] diff --git a/python/paddle/v2/dataset/movielens.py b/python/paddle/v2/dataset/movielens.py index dcffcff2f..b66448faf 100644 --- a/python/paddle/v2/dataset/movielens.py +++ b/python/paddle/v2/dataset/movielens.py @@ -4,7 +4,7 @@ import re import random import functools -__all__ = ['train_creator', 'test_creator'] +__all__ = ['train_creator', 'test_creator', 'get_movie_title_dict'] class MovieInfo(object): @@ -40,7 +40,8 @@ USER_INFO = None def __initialize_meta_info__(): fn = download( url='http://files.grouplens.org/datasets/movielens/ml-1m.zip', - md5='c4d9eecfca2ab87c1945afe126590906') + module_name='movielens', + md5sum='c4d9eecfca2ab87c1945afe126590906') global MOVIE_INFO if MOVIE_INFO is None: pattern = re.compile(r'^(.*)\((\d+)\)$') @@ -107,6 +108,11 @@ train_creator = functools.partial(__reader_creator__, is_test=False) test_creator = functools.partial(__reader_creator__, is_test=True) +def get_movie_title_dict(): + __initialize_meta_info__() + return MOVIE_TITLE_DICT + + def unittest(): for train_count, _ in enumerate(train_creator()()): pass -- GitLab