# Personalized Recommendation
## Background
With the continuous development of network technology and the growth of e-commerce, the number and variety of goods keep increasing, and users need more and more time to find what they want to buy. This is the problem of information overload, and personalized recommender systems (Recommender System) emerged to address it.
The simplest recommendation strategy is to show whatever is currently popular, for example the best sellers of an online store. A slightly more personal variant recommends items or news similar to what the user has already browsed. The problem is that most goods and news articles have never been browsed by the user, and what has been browsed may already have been bought or may no longer be needed. How, then, should we recommend?
A personalized recommender system is a subclass of information filtering systems (Information Filtering System). By analyzing and mining user behavior, it discovers each user's needs and interests and recommends the information or goods that the user is likely to be interested in. Unlike a traditional search engine, a recommender system does not require the user to describe the need precisely; it models the user's historical behavior and proactively offers information that matches the user's interests and needs. Personalized recommendations generally attract more clicks and purchases than generic ones, and are used in many domains such as movies, music, e-commerce and feed streams.
Traditional recommendation approaches mainly include:
- Collaborative filtering recommendation[1] (Collaborative Filtering Recommendation): one of the most widely used techniques in recommender systems. It typically uses the *k*-nearest-neighbors method (*k*-NN): collect and analyze a user's historical behavior, activities and preferences, compute the similarity or distance between that user and other users, and predict the target user's preference for a particular item from a similarity-weighted average of the nearest neighbors' ratings (see the sketch after this list). Collaborative filtering can recommend products the user has never bought, but it suffers from the cold-start problem for new users without any behavior history, and from data sparsity when there are too few user-item interactions.
- Content-based filtering[2] (Content-based Filtering): abstract meaningful features from the item's content description and recommend by computing the similarity between the user's interests and the item description. Content-based recommendation is simple and direct, does not depend on other users' ratings and can compare items with one another, but it requires explicit item attributes and still faces the cold-start problem for new users without any behavior history.
- Hybrid recommendation[3] (Hybrid Recommendation): combine different inputs and techniques so that the weaknesses of the individual techniques compensate for one another.
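To make the collaborative-filtering idea above concrete, here is a minimal, self-contained sketch of user-based *k*-NN rating prediction with cosine similarity. It is illustrative only and not part of this demo's code; the toy rating matrix and the helper names (`cosine_sim`, `predict_rating`) are made up.
```python
import numpy as np

# Toy user-item rating matrix (0 means "not rated"); rows are users, columns are items.
ratings = np.array([
    [5, 3, 0, 1],
    [4, 0, 0, 1],
    [1, 1, 0, 5],
    [1, 0, 0, 4],
    [0, 1, 5, 4],
], dtype=float)

def cosine_sim(a, b):
    # Cosine similarity between two users' rating vectors.
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denom if denom > 0 else 0.0

def predict_rating(user, item, k=2):
    # Weighted average of the ratings given to `item` by the k most similar
    # users who have actually rated it.
    sims = [(cosine_sim(ratings[user], ratings[other]), other)
            for other in range(ratings.shape[0])
            if other != user and ratings[other, item] > 0]
    sims.sort(reverse=True)
    top = sims[:k]
    if not top:
        return 0.0  # cold start: no neighbor has rated this item
    return sum(s * ratings[o, item] for s, o in top) / sum(abs(s) for s, _ in top)

print(predict_rating(user=0, item=2))  # predicted rating of item 2 for user 0
```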
Deep learning is very good at extracting features automatically; it can learn multi-level abstract feature representations and learn from heterogeneous or cross-domain content, which helps with the cold-start problem that recommender systems commonly face[4].
## Result Demo
In a movie recommendation scenario, we can sort all movies by their predicted scores and recommend the ones the user is most likely to be interested in.
We use a dataset that contains user information, movie information and movie ratings as the application scenario. After the model is trained, entering a user ID and a movie ID produces a matching score; sorting all movies by this score yields the movies to recommend to that user.
```
Input movie_id: 9
Input user_id: 4
Prediction Score is 2.56
Input movie_id: 8
Input user_id: 2
Prediction Score is 3.13
```
## Model Overview
Before building our own recommender system, let's look at a few models used in industry.
### YouTube's Deep Neural Network Recommender System
YouTube is the world's largest platform for uploading, sharing and discovering videos. Its recommender system personalizes content from an ever-growing video corpus for more than a billion users. The system consists of two neural networks: one generates candidate videos and the other ranks them. The overall structure is shown below[5]:
![](image/YouTube_Overview.png)
The candidate-generation network extracts information from the user's YouTube activity history and retrieves a few hundred videos relevant to the user from the video corpus. It treats recommendation as a multiclass classification problem with an extremely large number of classes.
As shown in the figure below, the deep candidate-generation model averages the embeddings of the sparse features, concatenates the result with some dense features, and turns them into a fixed-width vector suitable as input to the hidden layers. All hidden layers are fully connected. During training, the cross-entropy loss over a sampled softmax output is minimized with gradient descent. At serving time, an approximate nearest-neighbor lookup is used to generate hundreds of candidate video recommendations.
![](image/Deep_candidate_generation_model_architecture.png)
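The toy snippet below (ours, not from the paper or from this demo; all dimensions and values are invented) illustrates the input construction just described: the embeddings of a variable-length sparse feature are averaged into a fixed-width vector, concatenated with dense features, and passed through a fully connected ReLU layer into a softmax over the whole video vocabulary.
```python
import numpy as np

rng = np.random.RandomState(0)
vocab_size, embed_dim, dense_dim, hidden_dim = 1000, 16, 4, 32

video_embeddings = rng.randn(vocab_size, embed_dim) * 0.01  # shared embedding table

def average_embedding(ids):
    # Variable-length list of watched-video IDs -> fixed-width vector.
    return video_embeddings[ids].mean(axis=0)

watch_history = [3, 17, 256, 999]       # example watch history (video IDs)
dense_features = rng.randn(dense_dim)   # e.g. age, gender, time since last watch

x = np.concatenate([average_embedding(watch_history), dense_features])

# One ReLU hidden layer, then a softmax over all videos ("extreme multiclass").
W1, b1 = rng.randn(x.size, hidden_dim) * 0.01, np.zeros(hidden_dim)
W2, b2 = rng.randn(hidden_dim, vocab_size) * 0.01, np.zeros(vocab_size)
h = np.maximum(x @ W1 + b1, 0.0)
logits = h @ W2 + b2
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(probs.argsort()[-5:][::-1])  # the 5 highest-scoring candidate videos
```
At serving time the full softmax is not evaluated; instead the learned video vectors are searched with an approximate nearest-neighbor index, as mentioned above.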
In the ranking model, features are divided into univalent features, which contribute a single value, and multivalent features, which contribute a set of values. For example, a video ID is a univalent feature, while the set of IDs of the last N videos the user watched is the corresponding multivalent feature.
As shown below, embedded categorical features (both univalent and multivalent, with shared embeddings) are fed in together with normalized continuous features and powers of them. All layers are fully connected. Finally, logistic regression assigns a score to each video, and the sorted list is returned to the user. In practice several hundred features are fed into the network.
![](image/Deep_ranking_network_architecture.png)
### RNN-based Recommendation Models
In this scenario[6], we first find the target user's nearest-neighbor users based on their behavior, and then turn those users' product-review texts into word vectors as training data. The model outputs how likely the user is to like a given product.
#### Bidirectional RNN Model
A bidirectional RNN has access not only to past context but also to future context. As shown in the figure below, in the single-layer model the forward RNN reads the words from the first to the last, and the backward RNN reads them in the opposite order. The outputs of the two directions are averaged and, after a linear transform, used as the input of a softmax. In a multi-layer bidirectional RNN, each layer takes the previous layer's memory sequence as input and computes its own memory representation; the output of the last layer is processed in the same way as in the single-layer case.
![](image/BiRNN_with_GRU_Cell.png)
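As an illustration only (this is not the PaddlePaddle implementation, and a plain tanh RNN cell stands in for the GRU shown in the figure), the sketch below runs a forward and a backward pass over a sequence of word vectors, averages the two output sequences, and applies a linear transform followed by a softmax.
```python
import numpy as np

rng = np.random.RandomState(1)
embed_dim, hidden_dim, num_classes, seq_len = 8, 16, 2, 5

Wx = rng.randn(embed_dim, hidden_dim) * 0.1    # input-to-hidden weights
Wh = rng.randn(hidden_dim, hidden_dim) * 0.1   # hidden-to-hidden weights
Wo = rng.randn(hidden_dim, num_classes) * 0.1  # pre-softmax output weights

def rnn(inputs):
    # Simple tanh RNN; returns the hidden state at every time step.
    h = np.zeros(hidden_dim)
    outputs = []
    for x in inputs:
        h = np.tanh(x @ Wx + h @ Wh)
        outputs.append(h)
    return np.array(outputs)

words = rng.randn(seq_len, embed_dim)   # word vectors of one review
fwd = rnn(words)                        # first word -> last word
bwd = rnn(words[::-1])[::-1]            # last word -> first word, re-aligned
avg = ((fwd + bwd) / 2.0).mean(axis=0)  # average the two directions, then over time
logits = avg @ Wo                       # linear transform feeding the softmax
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(probs)  # e.g. probability that the user likes the item
```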
#### Attention-based Model
On top of the bidirectional RNN model above, we introduce an attention mechanism. The states of each forward and backward unit are concatenated into one output vector, which a set of attention weights then turns into a scalar. The per-unit scalars are concatenated into a new vector, and this vector is fed into the final prediction layer to produce the result.
![](image/Attention_Based_BiRNN_with_GRU_cell.png)
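The sketch below follows the description above step by step. It is our own toy code: randomly generated values stand in for the real forward/backward GRU states, and a simple linear scoring vector stands in for whatever parameterization the cited report uses. The concatenated state at each step is reduced to a scalar, and the vector of per-step scalars feeds the prediction layer.
```python
import numpy as np

rng = np.random.RandomState(2)
hidden_dim, num_classes, seq_len = 16, 2, 5

fwd = rng.randn(seq_len, hidden_dim)  # forward GRU states (computed elsewhere)
bwd = rng.randn(seq_len, hidden_dim)  # backward GRU states
states = np.concatenate([fwd, bwd], axis=1)    # one output vector per time step

attn_vector = rng.randn(2 * hidden_dim) * 0.1  # attention weight vector
scalars = states @ attn_vector                 # one scalar per time step

Wo = rng.randn(seq_len, num_classes) * 0.1     # final prediction layer
logits = scalars @ Wo                          # the scalar vector feeds the prediction layer
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(probs)
```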
Next, we will build our own recommender system with a neural network.
## Data Preparation
### Dataset Introduction and Download
This tutorial uses the [MovieLens dataset](http://grouplens.org/datasets/movielens/), which contains user information, movie information and movie ratings, collected and maintained by the GroupLens Research lab.
The dataset comes in several sizes. Here we use the [MovieLens 1M dataset (ml-1m)](http://files.grouplens.org/datasets/movielens/ml-1m.zip) as an example; it contains 1,000,000 ratings of 4,000 movies by 6,000 users and was released in February 2003. When a new user joins MovieLens, they are asked to rate 15 movies on a scale of 1 to 5 in steps of 0.5. When the user looks at a movie, the MovieLens recommender predicts their rating of it from their previous ratings[7].
Run `data/getdata.sh` to download the data. The layout of `data/ml-1m` is:
```
+--ml-1m
      +--- movies.dat  # movie features
      +--- ratings.dat # ratings
      +--- users.dat   # user features
      +--- README      # dataset description
```
The data files in ml-1m use "::" as the field delimiter. The formats are as follows (see the README inside ml-1m for more details; a parsing example follows the list):
- Ratings (ratings.dat): UserID::MovieID::Rating::Timestamp
- Movie features (movies.dat): MovieID::Title::Genres
- User features (users.dat): UserID::Gender::Age::Occupation::Zip-code
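For example, a ratings.dat line such as `1::1193::5::978300760` splits on the `::` delimiter into the four fields above; this is the same parsing that `dataprovider.py` performs later.
```python
line = "1::1193::5::978300760"  # UserID::MovieID::Rating::Timestamp
user_id, movie_id, rating, timestamp = line.strip().split("::")
print(int(user_id), int(movie_id), int(rating))  # -> 1 1193 5
```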
### Data Preprocessing
First install the Python dependencies (using virtualenv is recommended):
```shell
pip install -r data/requirements.txt
```
The preprocessing has two parts: building the user and movie input features (serializing the data files) and splitting the data into a training set and a test set. Running `./preprocess.sh` performs both.
While building the input features, a field configuration is produced for each data file (movies/users) and converted into a meta file that can parse the dataset; the meta file is then used to parse the dataset into Python objects, which are serialized.
For the train/test split, the data in `ratings.dat` are divided into two parts, one used to train the model and one used to test it.
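The split keeps every user in both sets: the split script called by `preprocess.sh` groups the rating lines by user and holds out a fixed fraction of each user's ratings for testing. Below is a simplified sketch of that logic, not the script itself; the function name and the sample lines are illustrative.
```python
import collections
import random

def split_ratings(lines, test_ratio=0.1, delimiter="::"):
    # Group rating lines by user, then hold out test_ratio of each user's ratings.
    by_user = collections.defaultdict(list)
    for line in lines:
        by_user[line.split(delimiter)[0]].append(line)
    train, test = [], []
    for user_lines in by_user.values():
        random.shuffle(user_lines)
        cut = int(len(user_lines) * test_ratio)
        test.extend(user_lines[:cut])
        train.extend(user_lines[cut:])
    return train, test

sample = ["1::1193::5::978300760", "1::661::3::978302109", "2::1357::5::978298709"]
train, test = split_ratings(sample, test_ratio=0.5)
print(len(train), len(test))  # -> 2 1
```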
### Providing Data to PaddlePaddle
The data provider script `dataprovider.py` reads `meta.bin` and the rating files and generates the training samples. In this script we need to set:
- obj.slots: the type and dimension of every feature.
- use_seq: whether the data in `dataprovider.py` are in sequence mode.
```python
from paddle.trainer.PyDataProvider2 import *
import common_utils  # parsing helpers
def __list_to_map__(lst):  # output format: a dict of slot name -> value
ret_val = dict()
for each in lst:
k, v = each
ret_val[k] = v
return ret_val
```
```python
def hook(settings, meta, **kwargs):
"""
    The init hook is invoked before the data are processed. It sets obj.slots and
    stores the meta data.
    :param settings: the global object; it is passed on to the process routine.
    :type settings: object
    :param meta: the meta file object passed from trainer_config; it records the
                 movie/user features.
    :param kwargs: other, unused arguments.
    """
    del kwargs  # unused
    # The header defines the slots used by paddle:
    #   the first part holds the movie features,
    #   the second part holds the user features,
    #   the last part is the rating score.
    # header is a list of [USE_SEQ_OR_NOT?, SlotType].
movie_headers = list(common_utils.meta_to_header(meta, 'movie'))
settings.movie_names = [h[0] for h in movie_headers]
headers = movie_headers
user_headers = list(common_utils.meta_to_header(meta, 'user'))
settings.user_names = [h[0] for h in user_headers]
headers.extend(user_headers)
    headers.append(("rating", dense_vector(1)))  # the score
    # slot types
settings.input_types = __list_to_map__(headers)
settings.meta = meta
```
Next, the `process` function feeds the data to PaddlePaddle record by record.
```python
@provider(init_hook=hook, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename):
with open(filename, 'r') as f:
for line in f:
            # Get a rating from the file.
            user_id, movie_id, score = map(int, line.split('::')[:-1])
            # Scale the score to [-5, +5].
            score = float(score) * 2 - 5.0
            # Look up the movie/user features by movie_id and user_id.
            movie_meta = settings.meta['movie'][movie_id]
            user_meta = settings.meta['user'][user_id]
            outputs = [('movie_id', movie_id - 1)]
            # Add the movie features.
            for i, each_meta in enumerate(movie_meta):
                outputs.append((settings.movie_names[i + 1], each_meta))
            # Add the user id.
            outputs.append(('user_id', user_id - 1))
            # Add the user features.
            for i, each_meta in enumerate(user_meta):
                outputs.append((settings.user_names[i + 1], each_meta))
            # Finally, add the score.
            outputs.append(('rating', [score]))
            # Hand the record to paddle.
yield __list_to_map__(outputs)
```
## Model Configuration
### Data Definition
The data are read from the dataprovider through `define_py_data_sources2`:
```python
define_py_data_sources2(
'data/train.list',
'data/test.list',
module='dataprovider',
obj='process',
args={'meta': meta})
```
### Algorithm Configuration
Here we set the batch size and the initial learning rate of the network, and choose RMSProp as the adaptive learning-rate optimization method.
```python
settings(
batch_size=1600, learning_rate=1e-3, learning_method=RMSPropOptimizer())
```
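For reference, the standard RMSProp update (the generic textbook form, not a transcription of PaddlePaddle's internal implementation) keeps a moving average $r$ of the squared gradients and scales each step by its square root:
$$ r_t = \rho\, r_{t-1} + (1-\rho)\, g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{r_t} + \epsilon}\, g_t $$
where $g_t$ is the gradient, $\rho$ is the decay rate, $\eta$ is the learning rate (1e-3 above) and $\epsilon$ is a small constant for numerical stability.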
### Model Architecture
The network structure is shown in the figure below:
![](image/rec_regression_network.png)
The `construct_feature` function in `trainer_config.py` builds the movie/user features; every feature type is mapped into a feature vector:
```python
def construct_feature(name):
__meta__ = meta[name]['__meta__']['raw_meta']
fusion = []
    for each_meta in __meta__:  # iterate over the raw feature meta
type_name = each_meta['type']
slot_name = each_meta.get('name', '%s_id' % name)
        if type_name == 'id':  # id: a simple embedding followed by a fully connected layer
slot_dim = each_meta['max']
embedding = embedding_layer(
input=data_layer(
slot_name, size=slot_dim), size=256)
fusion.append(fc_layer(input=embedding, size=256))
        elif type_name == 'embedding':  # embedding: for a sequence, embed it, run a text convolution and pool the result; otherwise embed it and add a fully connected layer
is_seq = each_meta['seq'] == 'sequence'
slot_dim = len(each_meta['dict'])
din = data_layer(slot_name, slot_dim)
embedding = embedding_layer(input=din, size=256)
if is_seq:
fusion.append(
text_conv_pool(
input=embedding, context_len=5, hidden_size=256))
else:
fusion.append(fc_layer(input=embedding, size=256))
        elif type_name == 'one_hot_dense':  # one_hot_dense: two fully connected layers
slot_dim = len(each_meta['dict'])
hidden = fc_layer(input=data_layer(slot_name, slot_dim), size=256)
fusion.append(fc_layer(input=hidden, size=256))
    # Gather all feature vectors, combine them with a fully connected layer and return the result.
return fc_layer(name="%s_fusion" % name, input=fusion, size=256)
```
We then compute the cosine similarity of these two feature vectors and use it as the output.
```python
movie_feature = construct_feature("movie")
user_feature = construct_feature("user")
similarity = cos_sim(a=movie_feature, b=user_feature)
if not is_predict:
outputs(
regression_cost(
input=similarity, label=data_layer(
'rating', size=1)))
define_py_data_sources2(
'data/train.list',
'data/test.list',
module='dataprovider',
obj='process',
args={'meta': meta})
else:
outputs(similarity)
```
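When `is_predict` is false the network is trained as a regression. Writing $m_i$ and $u_i$ for the 256-dimensional movie and user fusion vectors and $y_i = 2 r_i - 5$ for the rescaled rating produced in `dataprovider.py`, the cost minimized by `regression_cost` is (up to averaging over the batch) the squared error between the cosine similarity and the rescaled rating:
$$ L = \frac{1}{N} \sum_{i=1}^{N} \big( \cos(m_i, u_i) - y_i \big)^2 $$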
## Model Training
Run `sh train.sh` to start training; the log is written to `log.txt` and printed to the screen. The script specifies a total of 50 passes.
```shell
set -e
paddle train \
  --config=trainer_config.py \          # network configuration file
  --save_dir=./output \                 # directory for saving the model
  --use_gpu=false \                     # whether to use the GPU (off by default)
  --trainer_count=4 \                   # number of trainer threads on one machine
  --test_all_data_in_one_period=true \  # test on all of the test data in each period; otherwise test on batch_size batches per period
  --log_period=100 \                    # print a log line every log_period batches
  --dot_period=1 \                      # print a "." every dot_period batches
--num_passes=50 2>&1 | tee 'log.txt'
```
If training starts successfully, the output will look similar to this:
```shell
I0601 08:07:22.832059 10549 TrainerInternal.cpp:157] Batch=100 samples=160000 AvgCost=4.13494 CurrentCost=4.13494 Eval: CurrentEval:
I0601 08:07:50.672627 10549 TrainerInternal.cpp:157] Batch=200 samples=320000 AvgCost=3.80957 CurrentCost=3.48421 Eval: CurrentEval:
I0601 08:08:18.877369 10549 TrainerInternal.cpp:157] Batch=300 samples=480000 AvgCost=3.68145 CurrentCost=3.42519 Eval: CurrentEval:
I0601 08:08:46.863963 10549 TrainerInternal.cpp:157] Batch=400 samples=640000 AvgCost=3.6007 CurrentCost=3.35847 Eval: CurrentEval:
I0601 08:09:15.413025 10549 TrainerInternal.cpp:157] Batch=500 samples=800000 AvgCost=3.54811 CurrentCost=3.33773 Eval: CurrentEval:
I0601 08:09:36.058670 10549 TrainerInternal.cpp:181] Pass=0 Batch=565 samples=902826 AvgCost=3.52368 Eval:
I0601 08:09:46.215489 10549 Tester.cpp:101] Test samples=97383 cost=3.32155 Eval:
I0601 08:09:46.215966 10549 GradientMachine.cpp:132] Saving parameters to ./output/model/pass-00000
I0601 08:09:46.233397 10549 ParamUtil.cpp:99] save dir ./output/model/pass-00000
I0601 08:09:46.233438 10549 Util.cpp:209] copy trainer_config.py to ./output/model/pass-00000
I0601 08:09:46.233541 10549 ParamUtil.cpp:147] fileName trainer_config.py
```
## Applying the Model
After training for several passes you can evaluate the model; the best pass is the one whose parameters give the lowest test cost. Just run:
```shell
./evaluate.py log.txt
```
You will see:
```shell
Best pass is 00009, error is 3.06949, which means predict get error as 0.875998002281
evaluating from pass output/pass-00009
```
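The second number follows directly from the first: `evaluate.py` picks the pass with the lowest test cost and reports `sqrt(cost) / 2` as the prediction error on the original rating scale, the factor of 2 undoing the `score * 2 - 5` rescaling in `dataprovider.py`.
```python
import math
print(math.sqrt(3.06949) / 2)  # -> 0.875998..., the "predict ... error" shown above
```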
You can then predict any user's rating of any movie by running:
```shell
python prediction.py 'output/pass-00009/'
```
The prediction program reads the user's input and prints the predicted score (note that this is the final predicted rating, not the rescaled score used inside the dataprovider).
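Concretely, `prediction.py` maps the raw network output back to the rating scale by inverting that rescaling. The helper below reproduces the arithmetic; `to_rating` is our own name, and 0.12 is just a raw output value chosen to reproduce the sample score.
```python
def to_rating(raw_output):
    # Invert the "score * 2 - 5" rescaling done in dataprovider.py.
    return (raw_output + 5) / 2.0

print("Prediction Score is %.2f" % to_rating(0.12))  # -> Prediction Score is 2.56
```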
The prediction command-line interface looks like this:
```
Input movie_id: 9
Input user_id: 4
Prediction Score is 2.56
Input movie_id: 8
Input user_id: 2
Prediction Score is 3.13
```
## Summary
This chapter introduced traditional recommendation approaches, YouTube's deep neural network recommender system and RNN-based recommendation models, and, taking movie recommendation as an example, trained a personalized recommendation model with a neural network. Recommender systems touch almost every aspect of e-commerce, social networks, advertising and search, and deep learning, which already plays an important role in image processing and natural language processing, will shine in recommendation as well.
## References
1. Breese, John S., David Heckerman, and Carl Kadie. ["Empirical analysis of predictive algorithms for collaborative filtering."](https://arxiv.org/pdf/1301.7363v1.pdf) Proceedings of the Fourteenth Conference on Uncertainty in Artificial Intelligence. Morgan Kaufmann Publishers Inc., 1998.
2. [Peter Brusilovsky](https://en.wikipedia.org/wiki/Peter_Brusilovsky) (2007). *The Adaptive Web*. p. 325. [ISBN](https://en.wikipedia.org/wiki/International_Standard_Book_Number) [978-3-540-72078-2](https://en.wikipedia.org/wiki/Special:BookSources/978-3-540-72078-2).
3. Robin Burke , [Hybrid Web Recommender Systems](http://www.dcs.warwick.ac.uk/~acristea/courses/CS411/2010/Book%20-%20The%20Adaptive%20Web/HybridWebRecommenderSystems.pdf), pp. 377-408, The Adaptive Web, Peter Brusilovsky, Alfred Kobsa, Wolfgang Nejdl (Ed.), Lecture Notes in Computer Science, Springer-Verlag, Berlin, Germany, Lecture Notes in Computer Science, Vol. 4321, May 2007, 978-3-540-72078-2.
4. Yuan, Jianbo, et al. ["Solving Cold-Start Problem in Large-scale Recommendation Engines: A Deep Learning Approach."](https://arxiv.org/pdf/1611.05480v1.pdf) *arXiv preprint arXiv:1611.05480* (2016).
5. Covington P, Adams J, Sargin E. [Deep neural networks for youtube recommendations](http://delivery.acm.org/10.1145/2960000/2959190/p191-covington.pdf?ip=113.225.222.231&id=2959190&acc=OA&key=4D4702B0C3E38B35%2E4D4702B0C3E38B35%2E4D4702B0C3E38B35%2E5945DC2EABF3343C&CFID=713293170&CFTOKEN=33777789&__acm__=1483689091_3196ba42120e35d98a6adbf5feed64a0)[C]//Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016: 191-198.
6. https://cs224d.stanford.edu/reports/LiuSingh.pdf
7. https://zh.wikipedia.org/wiki/MovieLens
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer.PyDataProvider2 import *
def meta_to_header(meta, name):
metas = meta[name]['__meta__']['raw_meta']
for each_meta in metas:
slot_name = each_meta.get('name', '%s_id' % name)
if each_meta['type'] == 'id':
yield slot_name, integer_value(each_meta['max'])
elif each_meta['type'] == 'embedding':
is_seq = each_meta['seq'] == 'sequence'
yield slot_name, integer_value(
len(each_meta['dict']),
seq_type=SequenceType.SEQUENCE
if is_seq else SequenceType.NO_SEQUENCE)
elif each_meta['type'] == 'one_hot_dense':
yield slot_name, dense_vector(len(each_meta['dict']))
{
"user": {
"file": {
"name": "users.dat",
"delimiter": "::"
},
"fields": ["id", "gender", "age", "occupation"]
},
"movie": {
"file": {
"name": "movies.dat",
"delimiter": "::"
},
"fields": ["id", "title", "genres"]
}
}
#!/bin/env python2
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
config_generator.py
Usage:
./config_generator.py <config_file> [--output_format=<output_format>]
./config_generator.py -h | --help
Options:
-h --help Show this screen.
--output_format=<output_format> Output Config format(json or yaml) [default: json].
"""
import json
import docopt
import copy
DEFAULT_FILE = {"type": "split", "delimiter": ","}
DEFAULT_FIELD = {
"id": {
"type": "id"
},
"gender": {
"name": "gender",
"type": "embedding",
"dict": {
"type": "char_based"
}
},
"age": {
"name": "age",
"type": "embedding",
"dict": {
"type": "whole_content",
"sort": True
}
},
"occupation": {
"name": "occupation",
"type": "embedding",
"dict": {
"type": "whole_content",
"sort": "true"
}
},
"title": {
"regex": {
"pattern": r"^(.*)\((\d+)\)$",
"group_id": 1,
"strip": True
},
"name": "title",
"type": {
"name": "embedding",
"seq_type": "sequence",
},
"dict": {
"type": "char_based"
}
},
"genres": {
"type": "one_hot_dense",
"dict": {
"type": "split",
"delimiter": "|"
},
"name": "genres"
}
}
def merge_dict(master_dict, slave_dict):
return dict(((k, master_dict.get(k) or slave_dict.get(k))
for k in set(slave_dict) | set(master_dict)))
def main(filename, fmt):
with open(filename, 'r') as f:
conf = json.load(f)
obj = dict()
for k in conf:
val = conf[k]
file_dict = val['file']
file_dict = merge_dict(file_dict, DEFAULT_FILE)
fields = []
for pos, field_key in enumerate(val['fields']):
assert isinstance(field_key, basestring)
field = copy.deepcopy(DEFAULT_FIELD[field_key])
field['pos'] = pos
fields.append(field)
obj[k] = {"file": file_dict, "fields": fields}
meta = {"meta": obj}
# print meta
if fmt == 'json':
def formatter(x):
import json
return json.dumps(x, indent=2)
elif fmt == 'yaml':
def formatter(x):
import yaml
return yaml.safe_dump(x, default_flow_style=False)
else:
raise NotImplementedError("Dump format %s is not implemented" % fmt)
print formatter(meta)
if __name__ == '__main__':
args = docopt.docopt(__doc__, version="0.1.0")
main(args["<config_file>"], args["--output_format"])
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -ex
cd "$(dirname "$0")"
# download the dataset
wget http://files.grouplens.org/datasets/movielens/ml-1m.zip
# unzip the dataset
unzip ml-1m.zip
# remove the unused zip file
rm ml-1m.zip
#!/bin/env python2
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess Movielens dataset, to get movie/user object.
Usage:
./preprocess.py <dataset_dir> <binary_filename> [--config=<config_file>]
./preprocess.py -h | --help
Options:
-h --help Show this screen.
--version Show version.
--config=<config_file> Get MetaData config file [default: config.json].
"""
import docopt
import os
import sys
import re
import collections
try:
import cPickle as pickle
except ImportError:
import pickle
class UniqueIDGenerator(object):
def __init__(self):
self.pool = collections.defaultdict(self.__next_id__)
self.next_id = 0
def __next_id__(self):
tmp = self.next_id
self.next_id += 1
return tmp
def __call__(self, k):
return self.pool[k]
def to_list(self):
ret_val = [None] * len(self.pool)
for k in self.pool.keys():
ret_val[self.pool[k]] = k
return ret_val
class SortedIDGenerator(object):
def __init__(self):
self.__key_set__ = set()
self.dict = None
def scan(self, key):
self.__key_set__.add(key)
def finish_scan(self, compare=None, key=None, reverse=False):
self.__key_set__ = sorted(
list(self.__key_set__), cmp=compare, key=key, reverse=reverse)
self.dict = dict()
for idx, each_key in enumerate(self.__key_set__):
self.dict[each_key] = idx
def __call__(self, key):
return self.dict[key]
def to_list(self):
return self.__key_set__
class SplitFileReader(object):
def __init__(self, work_dir, config):
assert isinstance(config, dict)
self.filename = config['name']
self.delimiter = config.get('delimiter', ',')
self.work_dir = work_dir
def read(self):
with open(os.path.join(self.work_dir, self.filename), 'r') as f:
for line in f:
line = line.strip()
if isinstance(self.delimiter, unicode):
self.delimiter = str(self.delimiter)
yield line.split(self.delimiter)
@staticmethod
def create(work_dir, config):
assert isinstance(config, dict)
if config['type'] == 'split':
return SplitFileReader(work_dir, config)
class IFileReader(object):
READERS = [SplitFileReader]
def read(self):
raise NotImplementedError()
@staticmethod
def create(work_dir, config):
for reader_cls in IFileReader.READERS:
val = reader_cls.create(work_dir, config)
if val is not None:
return val
class IDFieldParser(object):
TYPE = 'id'
def __init__(self, config):
self.__max_id__ = -sys.maxint - 1
self.__min_id__ = sys.maxint
self.__id_count__ = 0
def scan(self, line):
idx = int(line)
self.__max_id__ = max(self.__max_id__, idx)
self.__min_id__ = min(self.__min_id__, idx)
self.__id_count__ += 1
def parse(self, line):
return int(line)
def meta_field(self):
return {
"is_key": True,
'max': self.__max_id__,
'min': self.__min_id__,
'count': self.__id_count__,
'type': 'id'
}
class SplitEmbeddingDict(object):
def __init__(self, delimiter):
self.__id__ = UniqueIDGenerator()
self.delimiter = delimiter
def scan(self, multi):
for val in multi.split(self.delimiter):
self.__id__(val)
def parse(self, multi):
return map(self.__id__, multi.split(self.delimiter))
def meta_field(self):
return self.__id__.to_list()
class EmbeddingFieldParser(object):
TYPE = 'embedding'
NO_SEQUENCE = "no_sequence"
SEQUENCE = "sequence"
class CharBasedEmbeddingDict(object):
def __init__(self, is_seq=True):
self.__id__ = UniqueIDGenerator()
self.is_seq = is_seq
def scan(self, s):
for ch in s:
self.__id__(ch)
def parse(self, s):
return map(self.__id__, s) if self.is_seq else self.__id__(s[0])
def meta_field(self):
return self.__id__.to_list()
class WholeContentDict(object):
def __init__(self, need_sort=True):
assert need_sort
self.__id__ = SortedIDGenerator()
self.__has_finished__ = False
def scan(self, txt):
self.__id__.scan(txt)
def meta_field(self):
if not self.__has_finished__:
self.__id__.finish_scan()
self.__has_finished__ = True
return self.__id__.to_list()
def parse(self, txt):
return self.__id__(txt)
def __init__(self, config):
try:
self.seq_type = config['type']['seq_type']
except TypeError:
self.seq_type = EmbeddingFieldParser.NO_SEQUENCE
if config['dict']['type'] == 'char_based':
self.dict = EmbeddingFieldParser.CharBasedEmbeddingDict(
self.seq_type == EmbeddingFieldParser.SEQUENCE)
elif config['dict']['type'] == 'split':
self.dict = SplitEmbeddingDict(config['dict'].get('delimiter', ','))
elif config['dict']['type'] == 'whole_content':
self.dict = EmbeddingFieldParser.WholeContentDict(config['dict'][
'sort'])
else:
print config
assert False
self.name = config['name']
def scan(self, s):
self.dict.scan(s)
def meta_field(self):
return {
'name': self.name,
'dict': self.dict.meta_field(),
'type': 'embedding',
'seq': self.seq_type
}
def parse(self, s):
return self.dict.parse(s)
class OneHotDenseFieldParser(object):
TYPE = 'one_hot_dense'
def __init__(self, config):
if config['dict']['type'] == 'split':
self.dict = SplitEmbeddingDict(config['dict']['delimiter'])
self.name = config['name']
def scan(self, s):
self.dict.scan(s)
def meta_field(self):
# print self.dict.meta_field()
return {
'dict': self.dict.meta_field(),
'name': self.name,
'type': 'one_hot_dense'
}
def parse(self, s):
ids = self.dict.parse(s)
retv = [0.0] * len(self.dict.meta_field())
for idx in ids:
retv[idx] = 1.0
# print retv
return retv
class FieldParserFactory(object):
PARSERS = [IDFieldParser, EmbeddingFieldParser, OneHotDenseFieldParser]
@staticmethod
def create(config):
if isinstance(config['type'], basestring):
config_type = config['type']
elif isinstance(config['type'], dict):
config_type = config['type']['name']
assert config_type is not None
for each_parser_cls in FieldParserFactory.PARSERS:
if config_type == each_parser_cls.TYPE:
return each_parser_cls(config)
print config
class CompositeFieldParser(object):
def __init__(self, parser, extractor):
self.extractor = extractor
self.parser = parser
def scan(self, *args, **kwargs):
self.parser.scan(self.extractor.extract(*args, **kwargs))
def parse(self, *args, **kwargs):
return self.parser.parse(self.extractor.extract(*args, **kwargs))
def meta_field(self):
return self.parser.meta_field()
class PositionContentExtractor(object):
def __init__(self, pos):
self.pos = pos
def extract(self, line):
assert isinstance(line, list)
return line[self.pos]
class RegexPositionContentExtractor(PositionContentExtractor):
def __init__(self, pos, pattern, group_id, strip=True):
PositionContentExtractor.__init__(self, pos)
pattern = pattern.strip()
self.pattern = re.compile(pattern)
self.group_id = group_id
self.strip = strip
def extract(self, line):
line = PositionContentExtractor.extract(self, line)
match = self.pattern.match(line)
# print line, self.pattern.pattern, match
assert match is not None
txt = match.group(self.group_id)
if self.strip:
            txt = txt.strip()
return txt
class ContentExtractorFactory(object):
def extract(self, line):
pass
@staticmethod
def create(config):
if 'pos' in config:
if 'regex' not in config:
return PositionContentExtractor(config['pos'])
else:
extra_args = config['regex']
return RegexPositionContentExtractor(
pos=config['pos'], **extra_args)
class MetaFile(object):
def __init__(self, work_dir):
self.work_dir = work_dir
self.obj = dict()
def parse(self, config):
config = config['meta']
ret_obj = dict()
for key in config.keys():
val = config[key]
assert 'file' in val
reader = IFileReader.create(self.work_dir, val['file'])
assert reader is not None
assert 'fields' in val and isinstance(val['fields'], list)
fields_config = val['fields']
field_parsers = map(MetaFile.__field_config_mapper__, fields_config)
for each_parser in field_parsers:
assert each_parser is not None
for each_block in reader.read():
for each_parser in field_parsers:
each_parser.scan(each_block)
metas = map(lambda x: x.meta_field(), field_parsers)
# print metas
key_index = filter(
lambda x: x is not None,
map(lambda (idx, meta): idx if 'is_key' in meta and meta['is_key'] else None,
enumerate(metas)))[0]
key_map = []
for i in range(min(key_index, len(metas))):
key_map.append(i)
for i in range(key_index + 1, len(metas)):
key_map.append(i)
obj = {'__meta__': {'raw_meta': metas, 'feature_map': key_map}}
for each_block in reader.read():
idx = field_parsers[key_index].parse(each_block)
val = []
for i, each_parser in enumerate(field_parsers):
if i != key_index:
val.append(each_parser.parse(each_block))
obj[idx] = val
ret_obj[key] = obj
self.obj = ret_obj
return ret_obj
@staticmethod
def __field_config_mapper__(conf):
assert isinstance(conf, dict)
extrator = ContentExtractorFactory.create(conf)
field_parser = FieldParserFactory.create(conf)
assert extrator is not None
assert field_parser is not None
return CompositeFieldParser(field_parser, extrator)
def dump(self, fp):
pickle.dump(self.obj, fp, pickle.HIGHEST_PROTOCOL)
def preprocess(binary_filename, dataset_dir, config, **kwargs):
assert isinstance(config, str)
with open(config, 'r') as config_file:
file_loader = None
if config.lower().endswith('.yaml'):
import yaml
file_loader = yaml
elif config.lower().endswith('.json'):
import json
file_loader = json
config = file_loader.load(config_file)
meta = MetaFile(dataset_dir)
meta.parse(config)
with open(binary_filename, 'wb') as outf:
meta.dump(outf)
if __name__ == '__main__':
args = docopt.docopt(__doc__, version='0.1.0')
kwargs = dict()
for key in args.keys():
if key != '--help':
param_name = key
assert isinstance(param_name, str)
param_name = param_name.replace('<', '')
param_name = param_name.replace('>', '')
param_name = param_name.replace('--', '')
kwargs[param_name] = args[key]
preprocess(**kwargs)
#!/bin/env python2
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Separate movielens 1m dataset to train/test file.
Usage:
./separate.py <input_file> [--test_ratio=<test_ratio>] [--delimiter=<delimiter>]
./separate.py -h | --help
Options:
-h --help Show this screen.
--version Show version.
--test_ratio=<test_ratio> Test ratio for separate [default: 0.1].
--delimiter=<delimiter> File delimiter [default: ,].
"""
import docopt
import collections
import random
def process(test_ratio, input_file, delimiter, **kwargs):
test_ratio = float(test_ratio)
rating_dict = collections.defaultdict(list)
with open(input_file, 'r') as f:
for line in f:
user_id = int(line.split(delimiter)[0])
rating_dict[user_id].append(line.strip())
with open(input_file + ".train", 'w') as train_file:
with open(input_file + ".test", 'w') as test_file:
for k in rating_dict.keys():
lines = rating_dict[k]
assert isinstance(lines, list)
random.shuffle(lines)
test_len = int(len(lines) * test_ratio)
for line in lines[:test_len]:
print >> test_file, line
for line in lines[test_len:]:
print >> train_file, line
if __name__ == '__main__':
args = docopt.docopt(__doc__, version='0.1.0')
kwargs = dict()
for key in args.keys():
if key != '--help':
param_name = key
assert isinstance(param_name, str)
param_name = param_name.replace('<', '')
param_name = param_name.replace('>', '')
param_name = param_name.replace('--', '')
kwargs[param_name] = args[key]
process(**kwargs)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer.PyDataProvider2 import *
import common_utils # parse
def __list_to_map__(lst):
ret_val = dict()
for each in lst:
k, v = each
ret_val[k] = v
return ret_val
def hook(settings, meta, **kwargs):
"""
Init hook is invoked before process data. It will set obj.slots and store
data meta.
:param obj: global object. It will passed to process routine.
:type obj: object
:param meta: the meta file object, which passed from trainer_config. Meta
file record movie/user features.
:param kwargs: unused other arguments.
"""
del kwargs # unused kwargs
# Header define slots that used for paddle.
# first part is movie features.
# second part is user features.
# final part is rating score.
# header is a list of [USE_SEQ_OR_NOT?, SlotType]
movie_headers = list(common_utils.meta_to_header(meta, 'movie'))
settings.movie_names = [h[0] for h in movie_headers]
headers = movie_headers
user_headers = list(common_utils.meta_to_header(meta, 'user'))
settings.user_names = [h[0] for h in user_headers]
headers.extend(user_headers)
headers.append(("rating", dense_vector(1))) # Score
# slot types.
settings.input_types = __list_to_map__(headers)
settings.meta = meta
@provider(init_hook=hook, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename):
with open(filename, 'r') as f:
for line in f:
# Get a rating from file.
user_id, movie_id, score = map(int, line.split('::')[:-1])
# Scale score to [-5, +5]
score = float(score) * 2 - 5.0
# Get movie/user features by movie_id, user_id
movie_meta = settings.meta['movie'][movie_id]
user_meta = settings.meta['user'][user_id]
outputs = [('movie_id', movie_id - 1)]
# Then add movie features
for i, each_meta in enumerate(movie_meta):
outputs.append((settings.movie_names[i + 1], each_meta))
# Then add user id.
outputs.append(('user_id', user_id - 1))
# Then add user features.
for i, each_meta in enumerate(user_meta):
outputs.append((settings.user_names[i + 1], each_meta))
# Finally, add score
outputs.append(('rating', [score]))
# Return data to paddle
yield __list_to_map__(outputs)
#!/usr/bin/python
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import re
import math
def get_best_pass(log_filename):
with open(log_filename, 'r') as f:
text = f.read()
pattern = re.compile('Test.*? cost=([0-9]+\.[0-9]+).*?pass-([0-9]+)',
re.S)
results = re.findall(pattern, text)
sorted_results = sorted(results, key=lambda result: float(result[0]))
return sorted_results[0]
log_filename = sys.argv[1]
log = get_best_pass(log_filename)
predict_error = math.sqrt(float(log[0])) / 2
print 'Best pass is %s, error is %s, which means predict get error as %f' % (
log[1], log[0], predict_error)
evaluate_pass = "output/pass-%s" % log[1]
print "evaluating from pass %s" % evaluate_pass
#!/bin/env python2
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from py_paddle import swig_paddle, DataProviderConverter
from common_utils import *
from paddle.trainer.config_parser import parse_config
try:
import cPickle as pickle
except ImportError:
import pickle
import sys
if __name__ == '__main__':
model_path = sys.argv[1]
swig_paddle.initPaddle('--use_gpu=0')
conf = parse_config("trainer_config.py", "is_predict=1")
network = swig_paddle.GradientMachine.createFromConfigProto(
conf.model_config)
assert isinstance(network, swig_paddle.GradientMachine)
network.loadParameters(model_path)
with open('./data/meta.bin', 'rb') as f:
meta = pickle.load(f)
headers = [h[1] for h in meta_to_header(meta, 'movie')]
headers.extend([h[1] for h in meta_to_header(meta, 'user')])
cvt = DataProviderConverter(headers)
while True:
movie_id = int(raw_input("Input movie_id: "))
user_id = int(raw_input("Input user_id: "))
movie_meta = meta['movie'][movie_id] # Query Data From Meta.
user_meta = meta['user'][user_id]
data = [movie_id - 1]
data.extend(movie_meta)
data.append(user_id - 1)
data.extend(user_meta)
print "Prediction Score is %.2f" % (
(network.forwardTest(cvt.convert([data]))[0]['value'][0][0] + 5)
/ 2)
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
UNAME_STR=`uname`
if [[ ${UNAME_STR} == 'Linux' ]]; then
SHUF_PROG='shuf'
else
SHUF_PROG='gshuf'
fi
cd "$(dirname "$0")"
delimiter='::'
dir=ml-1m
cd data
echo 'generate meta config file'
python config_generator.py config.json > meta_config.json
echo 'generate meta file'
python meta_generator.py $dir meta.bin --config=meta_config.json
echo 'split train/test file'
python split.py $dir/ratings.dat --delimiter=${delimiter} --test_ratio=0.1
echo 'shuffle train file'
${SHUF_PROG} $dir/ratings.dat.train > ratings.dat.train
cp $dir/ratings.dat.test .
echo "./data/ratings.dat.train" > train.list
echo "./data/ratings.dat.test" > test.list
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
paddle train \
--config=trainer_config.py \
--save_dir=./output \
--use_gpu=false \
--trainer_count=4\
--test_all_data_in_one_period=true \
--log_period=100 \
--dot_period=1 \
--num_passes=50 2>&1 | tee 'log.txt'
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
try:
import cPickle as pickle
except ImportError:
import pickle
is_predict = get_config_arg('is_predict', bool, False)
META_FILE = 'data/meta.bin'
with open(META_FILE, 'rb') as f:
# load meta file
meta = pickle.load(f)
settings(
batch_size=1600, learning_rate=1e-3, learning_method=RMSPropOptimizer())
def construct_feature(name):
"""
Construct movie/user features.
This method read from meta data. Then convert feature to neural network due
to feature type. The map relation as follow.
* id: embedding => fc
* embedding:
is_sequence: embedding => context_projection => fc => pool
not sequence: embedding => fc
* one_hot_dense: fc => fc
Then gather all features vector, and use a fc layer to combined them as
return.
:param name: 'movie' or 'user'
:type name: basestring
:return: combined feature output
:rtype: LayerOutput
"""
__meta__ = meta[name]['__meta__']['raw_meta']
fusion = []
for each_meta in __meta__:
type_name = each_meta['type']
slot_name = each_meta.get('name', '%s_id' % name)
if type_name == 'id':
slot_dim = each_meta['max']
embedding = embedding_layer(
input=data_layer(
slot_name, size=slot_dim), size=256)
fusion.append(fc_layer(input=embedding, size=256))
elif type_name == 'embedding':
is_seq = each_meta['seq'] == 'sequence'
slot_dim = len(each_meta['dict'])
din = data_layer(slot_name, slot_dim)
embedding = embedding_layer(input=din, size=256)
if is_seq:
fusion.append(
text_conv_pool(
input=embedding, context_len=5, hidden_size=256))
else:
fusion.append(fc_layer(input=embedding, size=256))
elif type_name == 'one_hot_dense':
slot_dim = len(each_meta['dict'])
hidden = fc_layer(input=data_layer(slot_name, slot_dim), size=256)
fusion.append(fc_layer(input=hidden, size=256))
return fc_layer(name="%s_fusion" % name, input=fusion, size=256)
movie_feature = construct_feature("movie")
user_feature = construct_feature("user")
similarity = cos_sim(a=movie_feature, b=user_feature)
if not is_predict:
outputs(
regression_cost(
input=similarity, label=data_layer(
'rating', size=1)))
define_py_data_sources2(
'data/train.list',
'data/test.list',
module='dataprovider',
obj='process',
args={'meta': meta})
else:
outputs(similarity)