From 07a86b52f4a792c3f2350fb1fadb3f34513c37b1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 28 Feb 2017 16:42:51 +0800 Subject: [PATCH] Refine --- demo/recommendation/api_train_v2.py | 34 +++++++++++++++++++++++++-- python/paddle/v2/dataset/movielens.py | 34 ++++++++++++++++++++++----- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/demo/recommendation/api_train_v2.py b/demo/recommendation/api_train_v2.py index 64dff9ae4..c726fa5bd 100644 --- a/demo/recommendation/api_train_v2.py +++ b/demo/recommendation/api_train_v2.py @@ -3,9 +3,39 @@ import paddle.v2 as paddle def main(): movie_title_dict = paddle.dataset.movielens.get_movie_title_dict() - title_word_count = len(movie_title_dict) + uid = paddle.layer.data( + name='user_id', + type=paddle.data_type.integer_value( + paddle.dataset.movielens.max_user_id() + 1)) + usr_emb = paddle.layer.embedding(input=uid, size=32) - paddle.layer.mixed + usr_gender_id = paddle.layer.data( + name='gender_id', type=paddle.data_type.integer_value(2)) + usr_gender_emb = paddle.layer.embedding(input=usr_gender_id, size=16) + + usr_age_id = paddle.layer.data( + name='age_id', + type=paddle.data_type.integer_value( + len(paddle.dataset.movielens.age_table))) + usr_age_emb = paddle.embedding(input=usr_age_id, size=16) + + usr_combined_features = paddle.fc( + input=[usr_emb, usr_gender_emb, usr_age_emb], + size=200, + act=paddle.activation.Tanh()) + + mov_id = paddle.layer.data( + name='movie_id', + type=paddle.data_type.integer_value( + paddle.dataset.movielens.max_movie_id() + 1)) + mov_emb = paddle.layer.embedding(input=mov_id, size=32) + + mov_title_id = paddle.layer.data( + name='movie_title', + type=paddle.data_type.integer_value(len(movie_title_dict))) + mov_title_emb = paddle.embedding(input=mov_title_id, size=32) + with paddle.layer.mixed() as mixed: + pass if __name__ == '__main__': diff --git a/python/paddle/v2/dataset/movielens.py b/python/paddle/v2/dataset/movielens.py index b66448faf..6efe42adb 100644 --- a/python/paddle/v2/dataset/movielens.py +++ b/python/paddle/v2/dataset/movielens.py @@ -4,7 +4,12 @@ import re import random import functools -__all__ = ['train_creator', 'test_creator', 'get_movie_title_dict'] +__all__ = [ + 'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id', + 'age_table' +] + +age_table = [1, 18, 25, 35, 45, 50, 56] class MovieInfo(object): @@ -24,7 +29,7 @@ class UserInfo(object): def __init__(self, index, gender, age, job_id): self.index = int(index) self.is_male = gender == 'M' - self.age = [1, 18, 25, 35, 45, 50, 56].index(int(age)) + self.age = age_table.index(int(age)) self.job_id = int(job_id) def value(self): @@ -104,8 +109,8 @@ def __reader_creator__(**kwargs): return lambda: __reader__(**kwargs) -train_creator = functools.partial(__reader_creator__, is_test=False) -test_creator = functools.partial(__reader_creator__, is_test=True) +train = functools.partial(__reader_creator__, is_test=False) +test = functools.partial(__reader_creator__, is_test=True) def get_movie_title_dict(): @@ -113,10 +118,27 @@ def get_movie_title_dict(): return MOVIE_TITLE_DICT +def __max_index_info__(a, b): + if a.index > b.index: + return a + else: + return b + + +def max_movie_id(): + __initialize_meta_info__() + return reduce(__max_index_info__, MOVIE_INFO.viewvalues()).index + + +def max_user_id(): + __initialize_meta_info__() + return reduce(__max_index_info__, USER_INFO.viewvalues()).index + + def unittest(): - for train_count, _ in enumerate(train_creator()()): + for train_count, _ in enumerate(train()()): pass - for test_count, _ in enumerate(test_creator()()): + for test_count, _ in enumerate(test()()): pass print train_count, test_count -- GitLab