提交 66b6ae16 编写于 作者: Y Yu Yang

Complete api_train_v2

上级 dd47da0d
import paddle.v2 as paddle
import cPickle
import copy
def main():
paddle.init(use_gpu=False, trainer_count=3)
paddle.init(use_gpu=False)
movie_title_dict = paddle.dataset.movielens.get_movie_title_dict()
uid = paddle.layer.data(
name='user_id',
......@@ -86,19 +88,40 @@ def main():
if event.batch_id % 100 == 0:
print "Pass %d Batch %d Cost %.2f" % (
event.pass_id, event.batch_id, event.cost)
elif isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.movielens.test(), batch_size=256))
print result.cost
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(
paddle.dataset.movielens.train(), buf_size=8192),
paddle.reader.firstn(
paddle.reader.shuffle(
paddle.dataset.movielens.train(), buf_size=8192),
n=1000),
batch_size=256),
event_handler=event_handler,
reader_dict=reader_dict,
num_passes=10)
num_passes=1)
user_id = 234
movie_id = 345
user = paddle.dataset.movielens.user_info()[user_id]
movie = paddle.dataset.movielens.movie_info()[movie_id]
feature = user.value() + movie.value()
def reader():
yield feature
infer_dict = copy.copy(reader_dict)
del infer_dict['score']
print infer_dict
prediction = paddle.infer(
output=inference,
parameters=parameters,
reader=paddle.reader.batched(
reader, batch_size=32),
reader_dict=infer_dict)
print prediction
if __name__ == '__main__':
......
......@@ -6,7 +6,7 @@ import functools
__all__ = [
'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id',
'age_table', 'movie_categories', 'max_job_id'
'age_table', 'movie_categories', 'max_job_id', 'user_info', 'movie_info'
]
age_table = [1, 18, 25, 35, 45, 50, 56]
......@@ -24,6 +24,13 @@ class MovieInfo(object):
[MOVIE_TITLE_DICT[w.lower()] for w in self.title.split()]
]
def __str__(self):
return "<MovieInfo id(%d), title(%s), categories(%s)>" % (
self.index, self.title, self.categories)
def __repr__(self):
return self.__str__()
class UserInfo(object):
def __init__(self, index, gender, age, job_id):
......@@ -35,6 +42,14 @@ class UserInfo(object):
def value(self):
return [self.index, 0 if self.is_male else 1, self.age, self.job_id]
def __str__(self):
return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
self.index, "M"
if self.is_male else "F", age_table[self.age], self.job_id)
def __repr__(self):
return str(self)
MOVIE_INFO = None
MOVIE_TITLE_DICT = None
......@@ -152,6 +167,16 @@ def movie_categories():
return CATEGORIES_DICT
def user_info():
__initialize_meta_info__()
return USER_INFO
def movie_info():
__initialize_meta_info__()
return MOVIE_INFO
def unittest():
for train_count, _ in enumerate(train()()):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册