提交 32b168c7 编写于 作者: Y Yu Yang

Refine code

上级 0eba01c0
import paddle.v2 as paddle
def main():
movie_title_dict = paddle.dataset.movielens.get_movie_title_dict()
title_word_count = len(movie_title_dict)
paddle.layer.mixed
if __name__ == '__main__':
main()
import mnist import mnist
import cifar
import movielens
__all__ = ['mnist'] __all__ = ['mnist', 'cifar', 'movielens']
...@@ -4,7 +4,7 @@ import re ...@@ -4,7 +4,7 @@ import re
import random import random
import functools import functools
__all__ = ['train_creator', 'test_creator'] __all__ = ['train_creator', 'test_creator', 'get_movie_title_dict']
class MovieInfo(object): class MovieInfo(object):
...@@ -40,7 +40,8 @@ USER_INFO = None ...@@ -40,7 +40,8 @@ USER_INFO = None
def __initialize_meta_info__(): def __initialize_meta_info__():
fn = download( fn = download(
url='http://files.grouplens.org/datasets/movielens/ml-1m.zip', url='http://files.grouplens.org/datasets/movielens/ml-1m.zip',
md5='c4d9eecfca2ab87c1945afe126590906') module_name='movielens',
md5sum='c4d9eecfca2ab87c1945afe126590906')
global MOVIE_INFO global MOVIE_INFO
if MOVIE_INFO is None: if MOVIE_INFO is None:
pattern = re.compile(r'^(.*)\((\d+)\)$') pattern = re.compile(r'^(.*)\((\d+)\)$')
...@@ -107,6 +108,11 @@ train_creator = functools.partial(__reader_creator__, is_test=False) ...@@ -107,6 +108,11 @@ train_creator = functools.partial(__reader_creator__, is_test=False)
test_creator = functools.partial(__reader_creator__, is_test=True) test_creator = functools.partial(__reader_creator__, is_test=True)
def get_movie_title_dict():
__initialize_meta_info__()
return MOVIE_TITLE_DICT
def unittest(): def unittest():
for train_count, _ in enumerate(train_creator()()): for train_count, _ in enumerate(train_creator()()):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册