提交 07a86b52 编写于 作者: Y Yu Yang

Refine

上级 aecdc61b
......@@ -3,9 +3,39 @@ import paddle.v2 as paddle
def main():
    """Declare the input and embedding layers for the MovieLens model.

    Only layer objects are constructed here; nothing is trained and
    nothing is returned.
    """
    movie_title_dict = paddle.dataset.movielens.get_movie_title_dict()
    title_word_count = len(movie_title_dict)

    # ---- user-side inputs -------------------------------------------
    # Value range is max id + 1, presumably so the largest id is itself
    # a valid value -- TODO confirm against integer_value semantics.
    uid = paddle.layer.data(
        name='user_id',
        type=paddle.data_type.integer_value(
            paddle.dataset.movielens.max_user_id() + 1))
    usr_emb = paddle.layer.embedding(input=uid, size=32)

    usr_gender_id = paddle.layer.data(
        name='gender_id', type=paddle.data_type.integer_value(2))
    usr_gender_emb = paddle.layer.embedding(input=usr_gender_id, size=16)

    # One value per entry of the dataset's age_table.
    usr_age_id = paddle.layer.data(
        name='age_id',
        type=paddle.data_type.integer_value(
            len(paddle.dataset.movielens.age_table)))
    usr_age_emb = paddle.layer.embedding(input=usr_age_id, size=16)

    # Fuse the three user embeddings through one fully connected layer.
    usr_combined_features = paddle.layer.fc(
        input=[usr_emb, usr_gender_emb, usr_age_emb],
        size=200,
        act=paddle.activation.Tanh())

    # ---- movie-side inputs ------------------------------------------
    mov_id = paddle.layer.data(
        name='movie_id',
        type=paddle.data_type.integer_value(
            paddle.dataset.movielens.max_movie_id() + 1))
    mov_emb = paddle.layer.embedding(input=mov_id, size=32)

    mov_title_id = paddle.layer.data(
        name='movie_title',
        type=paddle.data_type.integer_value(title_word_count))
    mov_title_emb = paddle.layer.embedding(input=mov_title_id, size=32)

    # Placeholder: the mixed layer is opened but nothing is added to it.
    # NOTE(review): an empty mixed layer may be rejected by Paddle when
    # the context exits -- confirm before relying on this.
    with paddle.layer.mixed() as mixed:
        pass
if __name__ == '__main__':
......
......@@ -4,7 +4,12 @@ import re
import random
import functools
# Public names of this module.
__all__ = [
    'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id',
    'age_table'
]

# Raw age values accepted for a user; UserInfo stores the position of a
# user's age in this list rather than the raw value.
age_table = [1, 18, 25, 35, 45, 50, 56]
class MovieInfo(object):
......@@ -24,7 +29,7 @@ class UserInfo(object):
def __init__(self, index, gender, age, job_id):
    """Build a user record from raw field values.

    :param index: user id; converted with ``int``.
    :param gender: gender code; ``'M'`` sets ``is_male`` to True, any
        other value sets it to False.
    :param age: raw age value; must be one of the entries of
        ``age_table``.
    :param job_id: occupation id; converted with ``int``.
    :raises ValueError: if ``int(age)`` is not in ``age_table``, or a
        numeric field cannot be converted.
    """
    self.index = int(index)
    self.is_male = gender == 'M'
    # Store the position in the shared table, not the raw age.
    self.age = age_table.index(int(age))
    self.job_id = int(job_id)
def value(self):
......@@ -104,8 +109,8 @@ def __reader_creator__(**kwargs):
return lambda: __reader__(**kwargs)
# Reader factories: calling one forwards any extra keyword arguments to
# __reader_creator__ with is_test pinned to False or True.
train = functools.partial(__reader_creator__, is_test=False)
test = functools.partial(__reader_creator__, is_test=True)

# The older *_creator spellings stay bound to the same factories so
# callers using those names keep working.
train_creator = train
test_creator = test
def get_movie_title_dict():
......@@ -113,10 +118,27 @@ def get_movie_title_dict():
return MOVIE_TITLE_DICT
def __max_index_info__(a, b):
    """Return whichever of ``a`` and ``b`` has the larger ``index``.

    On a tie ``b`` is returned. Shaped as a two-argument reducer so it
    can be folded over a collection with ``reduce``.
    """
    return a if a.index > b.index else b
def max_movie_id():
    """Return the largest ``index`` among the values of ``MOVIE_INFO``."""
    # Runs first; presumably this is what populates MOVIE_INFO.
    __initialize_meta_info__()
    movies = MOVIE_INFO.viewvalues()
    largest = reduce(__max_index_info__, movies)
    return largest.index
def max_user_id():
    """Return the largest ``index`` among the values of ``USER_INFO``."""
    # Runs first; presumably this is what populates USER_INFO.
    __initialize_meta_info__()
    users = USER_INFO.viewvalues()
    largest = reduce(__max_index_info__, users)
    return largest.index
def unittest():
for train_count, _ in enumerate(train_creator()()):
for train_count, _ in enumerate(train()()):
pass
for test_count, _ in enumerate(test_creator()()):
for test_count, _ in enumerate(test()()):
pass
print train_count, test_count
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册