提交 927d956d 编写于 作者: zhaoyijin666's avatar zhaoyijin666

first commit

上级 84b2855e
# Youtube DNN推荐模型
以下是本例目录包含的文件以及对应说明:
```
├── README.md # 文档
├── README.cn.md # 中文文档
├── data # 示例数据
│   ├── download.sh # 数据下载脚本
├── infer.py # 预测脚本
├── network_conf.py # 模型网络配置
├── reader.py # data reader
├── train.py # 训练脚本
└── utils.py # 工具
└── data_processer.py # 数据预处理脚本
```
## 背景介绍
Youtube是世界最大的视频网站之一, 其推荐系统帮助10亿以上的用户,从海量视频中,发现个性化的内容。该推荐系统主要面临以下三个挑战:
- 规模: 许多现有的推荐算法证明在小数据量下运行良好,但不能满足YouTube这样庞大的用户群和内容库的场景,因此需要高度专业化的分布式学习算法和高效的线上服务。
- 新鲜度: YouTube内容库更新频率极高,每秒上传小时级别视频。 系统应及时追踪新上传的视频和用户的实时的行为,并且模型在推荐新/旧视频上有好的平衡能力。
- 噪音: 噪音来自于两方面,其一,用户历史行为稀疏,且有各种不可观测的外部因素,以及用户满意度不明确。其二,内容本身的数据是非结构化的。因此算法应更具有鲁棒性。
下图展示了整个推荐系统框图:
<p align="center">
<img src="images/recommendation_system.png" width="500" height="300" hspace='10'/> <br/>
Figure 1. 推荐系统框图
</p>
整个推荐系统有两部分组成: 召回(candidate generation)和排序(ranking)。
- 召回模型: 输入用户的历史行为, 从大规模的内容库中获得一个小集合(百级别)。召回出的视频与用户高度相关。一个用户是用其历史点击过的视频,搜索过的关键词,和人口统计相关的特征来表征。
- 排序模型: 采用更精细的特征计算得到排序分,对召回得到的候选集合中的视频排序。
## 召回模型简介
该推荐问题可以被建模成一个"超大规模多分类"问题。即在时刻$$t$$,为用户$$U$$(已知上下文信息$$C$$)在视频库$$V$$中预测出观看视频i的类别,
$$P(\omega_t=i|U,C)=\frac{e^{v_iu}}{\sum_{j\in V}^{ }e^{v_ju}}$$
其中$$u\in \mathbb{R}^N$$,是<用户,上下文信息>的高维向量表示。$$v_j\in \mathbb{R}^N$$是视频`j`的高维向量表示。DNN模型的目标是以用户信息和上下文信息为输入条件下,学习用户的高维向量表示,以此输入softmax分类器,来预测视频库中各个视频(类别)的观看概率。
下图展示了召回模型的网络结构:
<p align="center">
<img src="images/model_network.png" width="500" height="400" hspace='10'/> <br/>
Figure 2. 召回模型网络结构
</p>
- 输入层:用户的浏览序列、搜索序列、人口统计学特征、和其他上下文信息等
- embedding层:将用户浏览视频序列接embedding层,再做时间序列上的平均。对于搜索序列同样处理。
- 隐层:包含三个隐层,用RELU激活函数,最后一层隐层的输出即为高维向量表示$$u$$。
- 输出层: softmax层,输出视频库中各个视频(类别)的观看概率。在线上预测时,提取模型训练得到的softmax层内部的参数,作为视频$$v_j$$的高维向量表示。可利用类似局部敏感哈希(Locality Sensitive Hashing)用$$u$$查询最相关的N个视频。
## 数据预处理
本例模拟了用户的视频点击日志,作为样本数据。格式如下:
```
用户Id \t 所在省份 \t 所在城市 \t 历史点击的视频序列信息 \t 手机型号
历史点击的视频序列信息的格式为 视频信息1;视频信息2;...;视频信息K
视频信息的格式为 视频id:视频类目:视频标签1_视频标签2_视频标签3_...视频标签M
例如:
USER_ID_15 上海市 上海市 VIDEO_42:CATEGORY_9:TAG115;VIDEO_43:CATEGORY_9:TAG116_TAG115;VIDEO_44:CATEGORY_2:TAG117_TAG71 GO T5
```
运行以下命令可下载该样本数据。
```
sh download.sh
```
然后,脚本`data_preprocess.py`将对训练数据做预处理。具体使用方法参考如下说明:
```
usage: data_processor.py [-h] --train_set_path TRAIN_SET_PATH --output_dir
OUTPUT_DIR [--feat_appear_limit FEAT_APPEAR_LIMIT]
PaddlePaddle Deep Candidate Generation Example
optional arguments:
-h, --help show this help message and exit
--train_set_path TRAIN_SET_PATH
path of the train set
--output_dir OUTPUT_DIR
directory to output
--feat_appear_limit FEAT_APPEAR_LIMIT
the minimum number of feature values appears (default:
20)
```
该脚本的作用如下:
- 借鉴\[[2](#参考文献)\]中对特征的处理,过滤低频特征(样本中出现次数低于`feat_appear_limit`)。
- 对特征进行编码,生成字典`feature_dict.pkl`
- 统计每个视频出现的概率,保存至`item_freq.pkl`,提供给nce层使用。
例如可执行下列命令, 完成数据预处理:
```shell
python data_processor.py --train_set_path=./data/train.txt \
--output_dir=./output \
--feat_appear_limit=20
```
## 模型实现
下面是网络中各个部分的具体实现,相关代码均包含在 `./network_conf.py` 中。
### 输入层
```python
def _build_input_layer(self):
"""
build input layer
"""
self._history_clicked_items = paddle.layer.data(
name="history_clicked_items", type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_items'])))
self._history_clicked_categories = paddle.layer.data(
name="history_clicked_categories", type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_categories'])))
self._history_clicked_tags = paddle.layer.data(
name="history_clicked_tags", type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_tags'])))
self._user_id = paddle.layer.data(
name="user_id", type=paddle.data_type.integer_value(
len(self._feature_dict['user_id'])))
self._province = paddle.layer.data(
name="province", type=paddle.data_type.integer_value(
len(self._feature_dict['province'])))
self._city = paddle.layer.data(
name="city", type=paddle.data_type.integer_value(len(self._feature_dict['city'])))
self._phone = paddle.layer.data(
name="phone", type=paddle.data_type.integer_value(len(self._feature_dict['phone'])))
self._target_item = paddle.layer.data(
name="target_item", type=paddle.data_type.integer_value(
len(self._feature_dict['history_clicked_items'])))
```
### Embedding层
每个输入特征都被embedding到固定维度的向量中。
```python
def _create_emb_attr(self, name):
"""
create embedding parameter
"""
return paddle.attr.Param(
name=name, initial_std=0.001, learning_rate=1, l2_rate=0, sparse_update=True)
def _build_embedding_layer(self):
"""
build embedding layer
"""
self._user_id_emb = paddle.layer.embedding(input=self._user_id,
size=64,
param_attr=self._create_emb_attr(
'_proj_user_id'))
self._province_emb = paddle.layer.embedding(input=self._province,
size=8,
param_attr=self._create_emb_attr(
'_proj_province'))
self._city_emb = paddle.layer.embedding(input=self._city,
size=16,
param_attr=self._create_emb_attr('_proj_city'))
self._phone_emb = paddle.layer.embedding(input=self._phone,
size=16,
param_attr=self._create_emb_attr('_proj_phone'))
self._history_clicked_items_emb = paddle.layer.embedding(
input=self._history_clicked_items,
size=64,
param_attr=self._create_emb_attr('_proj_history_clicked_items'))
self._history_clicked_categories_emb = paddle.layer.embedding(
input=self._history_clicked_categories,
size=8,
param_attr=self._create_emb_attr('_proj_history_clicked_categories'))
self._history_clicked_tags_emb = paddle.layer.embedding(
input=self._history_clicked_tags,
size=64,
param_attr=self._create_emb_attr('_proj_history_clicked_tags'))
```
### 隐层
我们对原paper中做了改进,历史用户点击视频序列,经过embedding后,不再是加权求平均。而是连接lstm层,将用户点击的先后次序纳入模型,再在时间序列上做最大池化,得到定长的向量表示,从而使模型学习到与点击时序相关的隐藏信息。考虑到数据规模与训练性能,我们只用了两个Relu层,也有不错的效果。
```python
self._rnn_cell = paddle.networks.simple_lstm(input=self._history_clicked_items_emb, size=64)
self._lstm_last = paddle.layer.pooling(
input=self._rnn_cell, pooling_type=paddle.pooling.Max())
self._avg_emb_cats = paddle.layer.pooling(input=self._history_clicked_categories_emb,
pooling_type=paddle.pooling.Avg())
self._avg_emb_tags = paddle.layer.pooling(input=self._history_clicked_tags_emb,
pooling_type=paddle.pooling.Avg())
self._fc_0 = paddle.layer.fc(
name="Relu1",
input=[self._lstm_last, self._user_id_emb,
self._city_emb, self._phone_emb],
size=self._dnn_layer_dims[0],
act=paddle.activation.Relu())
self._fc_1 = paddle.layer.fc(
name="Relu2",
input=self._fc_0,
size=self._dnn_layer_dims[1],
act=paddle.activation.Relu())
```
### 输出层
为了提高模型训练速度,使用噪声对比估计(Noise-contrastive estimation, NCE)\[[3](#参考文献)\]。将[数据预处理](#数据预处理)中产出的item_freq.pkl,也就是负样例的分布,作为nce层的参数。
```python
return paddle.layer.nce(
input=self._fc_1,
label=self._target_item,
num_classes=len(self._feature_dict['history_clicked_items']),
param_attr=paddle.attr.Param(name="nce_w"),
bias_attr=paddle.attr.Param(name="nce_b"),
act=paddle.activation.Sigmoid(),
num_neg_samples=5,
neg_distribution=self._item_freq)
```
## 训练
首先,准备`reader.py`,负责将输入原始数据中的特征,转为编码后的特征id。对一条训练数据,根据`window_size`产出多条训练样本给trainer,例如:
```
window_size=2
原始数据:
用户Id \t 所在省份 \t 所在城市 \t 视频信息1;视频信息2;...;视频信息K \t 手机型号
多条训练样本:
用户Id,所在省份,所在城市,[<unk>,历史点击视频1],[<unk>,历史点击视频类目1],[<unk>,历史点击视频标签1],手机型号,历史点击视频2
用户Id,所在省份,所在城市,[历史点击视频1,历史点击视频2],[历史点击视频类目1,历史点击视频类目2],[历史点击视频标签1,历史点击视频标签2],手机型号,历史点击视频3
用户Id,所在省份,所在城市,[历史点击视频2,历史点击视频3],[历史点击视频类目2,历史点击视频类目3],[历史点击视频标签2,历史点击视频标签3],手机型号,历史点击视频4
......
```
相关代码如下:
```python
for i in range(1, len(history_clicked_items_all)):
start = max(0, i - self._window_size)
history_clicked_items = history_clicked_items_all[start:i]
history_clicked_categories = history_clicked_categories_all[start:i]
history_clicked_tags_str = history_clicked_tags_all[start:i]
history_clicked_tags = []
for tags_a in history_clicked_tags_str:
for tag in tags_a.split("_"):
history_clicked_tags.append(int(tag))
target_item = history_clicked_items_all[i]
yield user_id, province, city, \
history_clicked_items, history_clicked_categories, \
history_clicked_tags, phone, target_item
```
```python
reader = Reader(feature_dict, args.window_size)
trainer.train(
paddle.batch(
paddle.reader.shuffle(
lambda: reader.train(args.train_set_path),
buf_size=7000), args.batch_size),
num_passes=args.num_passes,
feeding=feeding,
event_handler=event_handler)
```
接下去就可以开始训练了,可执行以下命令:
```shell
python train.py --train_set_path='./data/train.txt' \
--test_set_path='./data/test.txt' \
--model_output_dir='./output/model/' \
--feature_dict='./output/feature_dict.pkl' \
--item_freq='./output/item_freq.pkl'
```
## 预测
输入用户相关的特征,输出topN个最可能观看的视频,可执行以下命令:
```shell
python infer.py --infer_set_path='./data/infer.txt' \
--model_path='./output/model/model_pass_00000.tar.gz' \
--feature_dict='./output/feature_dict.pkl' \
--batch_size=50
```
## 参考文献
1. Covington, Paul, Jay Adams, and Emre Sargin. "Deep neural networks for youtube recommendations." Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016.
2. https://code.google.com/archive/p/word2vec/
3. http://paddlepaddle.org/docs/develop/models/nce_cost/README.html
# Deep Neural Networks for YouTube Recommendations
## Introduction\[[1](#References)\]
YouTube is the world's largest platform for creating, sharing and discovering video content. Youtube recommendations are responsible for helping more than a billion users discover personalized content from an ever-growing corpus of videos.
- Scale: Many existing recommendation algorithm proven to work well on small problems fail to operate on massive scale. Highly specialized distributed learning algorithms and efficient serving systems are essential.
- Freshness: YouTube has a very dynamic corpus with many hours of video are uploaded per second. The recommendation system should model newly uploaded content as well as the latest actions taken by user.
- Noise: Historical user behavior on YouTube is inherently difficult to predict due to sparsity and a variety of unobservable external factors. Furthermore, the noisy implicit feedback signals instead of the ground truth of user satisfaction is observed, and metadata associated with content is poorly structured, which forces the algorithms to be robust.
The overall structure of the recommendation system is illustrated in Figure 1.
<p align="center">
<img src="images/recommendation_system.png" width="500" height="300" hspace='10'/> <br/>
Figure 1. Recommendation system architecture
</p>
The system is comprised of two neural networks: one for candidate generation and one for ranking.
- The candidate generation network: It takes events from the user's YouTube activity history as input and retrieves a small subset(hundreds) of videos, highly relevant to the user, from a large corpus. The similarity between users is expressed in terms of coarse features such as IDs of video watches, search query tokens and demographics.
- The ranking network: It accomplishes this task by assigning a score to each video according to a desired objective function using a rich set of features describing the video and user.
## Candidate Generation
Here, candidate generation is modeled as extreme multiclass classification where the prediction problem becomes accurately classifying a specific video watch $$\omega_t$$ at time $$t$$ among millions of video $$i$$ (classes) from a corpus $$V$$ based on user $$U$$ and context $$C$$,
$$P(\omega_t=i|U,C)=\frac{e^{v_iu}}{\sum_{j\in V}^{ }e^{v_ju}}$$
where $$u\in \mathbb{R}^N$$ represents a high-dimensional "embedding" of the user, context pair and the $$v_j\in \mathbb{R}^N$$ represent embeddings of each candidate video. The task of the deep neural network is to learn user embeddings $$u$$ as a function of the user's history and context that are useful for discriminating among videos with a softmax classifier.
Figure 2 shows the general network architecture of candidate generation model:
<p align="center">
<img src="images/model_network.png" width="500" height="400" hspace='10'/> <br/>
Figure 2. Candidate generation model architecture
</p>
- Input layer: A user's watch history is represented by a variable-length sequence of sparse video IDs, and search history is similarly represented by a variable-length sequence of search tokens.
- Embedding layer: The input features each is mapped to a fixed-sized dense vector representation via the embeddings, and then simply averaging the embeddings. The embeddings are learned jointly with all other model parameters through normal gradient descent back-propagation updates.
- Hidden layer: Features are concatenated into a wide first layer, followed by several layers of fully connected Rectified Linear Units (ReLU). The output of the last ReLU layer is the previous mentioned high-dimensional "embedding" of the user $$u$$, so called user vector.
- Output layer: A softmax classifier is connected to do discriminating millions of classes (videos). To speed up training process, a technique is applied that samples negative classes from background distribution with importance weighting. The previous mentioned high-dimensional "embedding" of the candidate video $$v_j$$ is obtained by weight and bias of the softmax layer. At serving time, the most likely N classes (videos) is computed for presenting to the user. To Score millions of items under a strict serving laterncy, the scoring problem reduces to a nearest neighbor search in the dot product space, and Locality Sensitive Hashing is relied on.
## Data Pre-processing
In this example, we moked click log of users as sample data, and its format is as follows:
```
user-id \t province \t city \t history-clicked-video-info-sequence \t phone
history-clicked-video-info-sequence is formated as
video-info1;video-info2;...;video-infoK
video-info is formated as
video-id:category:tag1_tag2_tag3_...tagM
For example:
USER_ID_15 Shanghai Shanghai VIDEO_42:CATEGORY_9:TAG115;VIDEO_43:CATEGORY_9:TAG116_TAG115;VIDEO_44:CATEGORY_2:TAG117_TAG71 GO T5
```
Run this code to download the sample data.
```
sh download.sh
```
Then, run `data_preprocess.py` for data pre-processiong. Refer to the following instructions:
```
usage: data_processor.py [-h] --train_set_path TRAIN_SET_PATH --output_dir
OUTPUT_DIR [--feat_appear_limit FEAT_APPEAR_LIMIT]
PaddlePaddle Deep Candidate Generation Example
optional arguments:
-h, --help show this help message and exit
--train_set_path TRAIN_SET_PATH
path of the train set
--output_dir OUTPUT_DIR
directory to output
--feat_appear_limit FEAT_APPEAR_LIMIT
the minimum number of feature values appears (default:
20)
```
The fucntion of this script is as follows:
- Filter low-frequency features\[[2](#References)\], which appears less than `feat_appear_limit` times.
- Encode features, and generate dictionary `feature_dict.pkl`.
- Count the probability of each video appears and write into `item_freq.pkl`, and provide it to NCE layer.
```
For example, run the following command to accomplish data pre-processing:
```shell
python data_processor.py --train_set_path=./data/train.txt \
--output_dir=./output \
--feat_appear_limit=20
```
## Model Implementaion
The details of model implementation is illustrated as follows. The code is in `./network_conf.py`.
### Input layer
```python
def _build_input_layer(self):
"""
build input layer
"""
self._history_clicked_items = paddle.layer.data(
name="history_clicked_items", type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_items'])))
self._history_clicked_categories = paddle.layer.data(
name="history_clicked_categories", type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_categories'])))
self._history_clicked_tags = paddle.layer.data(
name="history_clicked_tags", type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_tags'])))
self._user_id = paddle.layer.data(
name="user_id", type=paddle.data_type.integer_value(
len(self._feature_dict['user_id'])))
self._province = paddle.layer.data(
name="province", type=paddle.data_type.integer_value(
len(self._feature_dict['province'])))
self._city = paddle.layer.data(
name="city", type=paddle.data_type.integer_value(len(self._feature_dict['city'])))
self._phone = paddle.layer.data(
name="phone", type=paddle.data_type.integer_value(len(self._feature_dict['phone'])))
self._target_item = paddle.layer.data(
name="target_item", type=paddle.data_type.integer_value(
len(self._feature_dict['history_clicked_items'])))
```
### Embedding layer
The each of input features is mapped to a fixed-sized dense vector representation
```python
def _create_emb_attr(self, name):
"""
create embedding parameter
"""
return paddle.attr.Param(
name=name, initial_std=0.001, learning_rate=1, l2_rate=0, sparse_update=True)
def _build_embedding_layer(self):
"""
build embedding layer
"""
self._user_id_emb = paddle.layer.embedding(input=self._user_id,
size=64,
param_attr=self._create_emb_attr(
'_proj_user_id'))
self._province_emb = paddle.layer.embedding(input=self._province,
size=8,
param_attr=self._create_emb_attr(
'_proj_province'))
self._city_emb = paddle.layer.embedding(input=self._city,
size=16,
param_attr=self._create_emb_attr('_proj_city'))
self._phone_emb = paddle.layer.embedding(input=self._phone,
size=16,
param_attr=self._create_emb_attr('_proj_phone'))
self._history_clicked_items_emb = paddle.layer.embedding(
input=self._history_clicked_items,
size=64,
param_attr=self._create_emb_attr('_proj_history_clicked_items'))
self._history_clicked_categories_emb = paddle.layer.embedding(
input=self._history_clicked_categories,
size=8,
param_attr=self._create_emb_attr('_proj_history_clicked_categories'))
self._history_clicked_tags_emb = paddle.layer.embedding(
input=self._history_clicked_tags,
size=64,
param_attr=self._create_emb_attr('_proj_history_clicked_tags'))
```
### Hiddern layer
We improves the original networks in \[[1](#References)\] by modifying that the embeddings of video watches are not simply averaged but are connected to a LSTM layer with max temporal pooling instead, so that the deep sequential information related to user interests can be learned well. Considering data scale and efficiency of training, we only apply two ReLU layers, which also leads to good performance.
```python
self._rnn_cell = paddle.networks.simple_lstm(input=self._history_clicked_items_emb, size=64)
self._lstm_last = paddle.layer.pooling(
input=self._rnn_cell, pooling_type=paddle.pooling.Max())
self._avg_emb_cats = paddle.layer.pooling(input=self._history_clicked_categories_emb,
pooling_type=paddle.pooling.Avg())
self._avg_emb_tags = paddle.layer.pooling(input=self._history_clicked_tags_emb,
pooling_type=paddle.pooling.Avg())
self._fc_0 = paddle.layer.fc(
name="Relu1",
input=[self._lstm_last, self._user_id_emb,
self._city_emb, self._phone_emb],
size=self._dnn_layer_dims[0],
act=paddle.activation.Relu())
self._fc_1 = paddle.layer.fc(
name="Relu2",
input=self._fc_0,
size=self._dnn_layer_dims[1],
act=paddle.activation.Relu())
```
### Output layer
To speed up training process, Noise-contrastive estimation, NCE\[[3](#references)\] is applied to sample negative classes from background distribution with importance weighting. The previous mentioned `item_freq.pkl`[data pre-processing](#data pre-processing) is used as neg_distribution.
```python
return paddle.layer.nce(
input=self._fc_1,
label=self._target_item,
num_classes=len(self._feature_dict['history_clicked_items']),
param_attr=paddle.attr.Param(name="nce_w"),
bias_attr=paddle.attr.Param(name="nce_b"),
act=paddle.activation.Sigmoid(),
num_neg_samples=5,
neg_distribution=self._item_freq)
```
## Train
First of all, prepare `reader.py`, the function of which is to convert raw features into encoding id. One piece of train data generates several data instances according to `window_size`, and then is fed into trainer.
```
window_size=2
train data:
user-id \t province \t city \t video-info1;video-info2;...;video-infoK \t phone
several data instances:
user-id,province,city,[<unk>,video-id1],[<unk>,category1],[<unk>,tags1],phone,video-id2
user-id,province,city,[video-id1,video-id2],[category1,category2],[tags1,tags2],phone,video-id3
user-id,province,city,[video-id2,video-id3],[category2,category3],[tags2,tags3],phone,video-id4
......
```
The relevant code is as follows:
```python
for i in range(1, len(history_clicked_items_all)):
start = max(0, i - self._window_size)
history_clicked_items = history_clicked_items_all[start:i]
history_clicked_categories = history_clicked_categories_all[start:i]
history_clicked_tags_str = history_clicked_tags_all[start:i]
history_clicked_tags = []
for tags_a in history_clicked_tags_str:
for tag in tags_a.split("_"):
history_clicked_tags.append(int(tag))
target_item = history_clicked_items_all[i]
yield user_id, province, city, \
history_clicked_items, history_clicked_categories, \
history_clicked_tags, phone, target_item
```
```python
reader = Reader(feature_dict, args.window_size)
trainer.train(
paddle.batch(
paddle.reader.shuffle(
lambda: reader.train(args.train_set_path),
buf_size=7000), args.batch_size),
num_passes=args.num_passes,
feeding=feeding,
event_handler=event_handler)
```
Then start training.
```shell
python train.py --train_set_path='./data/train.txt' \
--test_set_path='./data/test.txt' \
--model_output_dir='./output/model/' \
--feature_dict='./output/feature_dict.pkl' \
--item_freq='./output/item_freq.pkl'
```
## Use the model for prediction
Input user related features, and then get the most likely N videos for user.
```shell
python infer.py --infer_set_path='./data/infer.txt' \
--model_path='./output/model/model_pass_00000.tar.gz' \
--feature_dict='./output/feature_dict.pkl' \
--batch_size=50
```
## References
1. Covington, Paul, Jay Adams, and Emre Sargin. "Deep neural networks for youtube recommendations." Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016.
2. https://code.google.com/archive/p/word2vec/
3. http://paddlepaddle.org/docs/develop/models/nce_cost/README.html
#!/bin/bash
wget https://yun/dac.tar.gz
tar -zxvf dac.tar.gz
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import argparse
import os
import cPickle
from utils import logger
"""
This script will output 2 files:
1. feature_dict.pkl
2. item_freq.pkl
"""
class FeatureGenerator(object):
"""
Encode feature values with low-frequency filtering.
"""
def __init__(self, feat_appear_limit=20):
"""
@feat_appear_limit: int
"""
self._dic = None # feature value --> id
self._count = None # numbers of appearances of feature values
self._feat_appear_limit = feat_appear_limit
def add_feat_val(self, feat_val):
"""
Add feature values and count numbers of its appearance.
"""
if self._count is None:
self._count = {'<unk>': 0}
if feat_val == "NULL":
feat_val = '<unk>'
if feat_val not in self._count:
self._count[feat_val] = 1
else:
self._count[feat_val] += 1
self._count['<unk>'] += 1
def _filter_feat(self):
"""
Filter low-frequency feature values.
"""
self._items = filter(lambda x: x[1] > self._feat_appear_limit,
self._count.items())
self._items.sort(key=lambda x: x[1], reverse=True)
def _build_dict(self):
"""
Build feature values --> ids dict.
"""
self._dic = {}
self._filter_feat()
for i in xrange(len(self._items)):
self._dic[self._items[i][0]] = i
self.dim = len(self._dic)
def get_feat_id(self, feat_val):
"""
Get id of feature value after encoding.
"""
# build dict
if self._dic is None:
self._build_dict()
# find id
if feat_val in self._dic:
return self._dic[feat_val]
else:
return self._dic['<unk>']
def get_dim(self):
"""
Get dim.
"""
# build dict
if self._dic is None:
self._build_dict()
return len(self._dic)
def get_dict(self):
"""
Get dict.
"""
# build dict
if self._dic is None:
self._build_dict()
return self._dic
def get_total_count(self):
"""
Compute total num of count.
"""
total_count = 0
for i in xrange(len(self._items)):
feat_val = self._items[i][0]
c = self._items[i][1]
total_count += c
return total_count
def count_iterator(self):
"""
Iterate feature values and its num of appearance.
"""
for i in xrange(len(self._items)):
yield self._items[i][0], self._items[i][1]
def __repr__(self):
"""
"""
return '<FeatureGenerator %d>' % self._dim
def scan_build_dict(data_path, features_dict):
"""
Scan the raw data and add all feature values.
"""
logger.info('scan data set')
with open(data_path, 'r') as f:
for (line_id, line) in enumerate(f):
fields = line.strip('\n').split('\t')
user_id = fields[0]
province = fields[1]
features_dict['province'].add_feat_val(province)
city = fields[2]
features_dict['city'].add_feat_val(city)
item_infos = fields[3]
phone = fields[4]
features_dict['phone'].add_feat_val(phone)
for item_info in item_infos.split(";"):
item_info_array = item_info.split(":")
item = item_info_array[0]
features_dict['history_clicked_items'].add_feat_val(item)
features_dict['user_id'].add_feat_val(user_id)
category = item_info_array[1]
features_dict['history_clicked_categories'].add_feat_val(
category)
tags = item_info_array[2]
for tag in tags.split("_"):
features_dict['history_clicked_tags'].add_feat_val(tag)
def parse_args():
"""
parse arguments
"""
parser = argparse.ArgumentParser(
description="PaddlePaddle Youtube Recall Model Example")
parser.add_argument(
'--train_set_path',
type=str,
required=True,
help="path of the train set")
parser.add_argument(
'--output_dir', type=str, required=True, help="directory to output")
parser.add_argument(
'--feat_appear_limit',
type=int,
default=20,
help="the minimum number of feature values appears (default: 20)")
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
# check argument
assert os.path.exists(
args.train_set_path), 'The train set path does not exist.'
# features used
features = [
'user_id', 'province', 'city', 'phone', 'history_clicked_items',
'history_clicked_tags', 'history_clicked_categories'
]
# init feature generators
features_dict = {}
for feature in features:
features_dict[feature] = FeatureGenerator(
feat_appear_limit=args.feat_appear_limit)
# scan data for building dict
scan_build_dict(args.train_set_path, features_dict)
# generate feature_dict.pkl
feature_encoding_dict = {}
for feature in features:
d = features_dict[feature].get_dict()
feature_encoding_dict[feature] = d
logger.info('Feature:%s, dimension is %d' % (feature, len(d)))
output_dict_path = os.path.join(args.output_dir, 'feature_dict.pkl')
with open(output_dict_path, "w") as f:
cPickle.dump(feature_encoding_dict, f, -1)
# generate item_freq.pkl
item_freq_list = []
g = features_dict['history_clicked_items']
total_count = g.get_total_count()
for feat_val, feat_count in g.count_iterator():
item_freq_list.append(float(feat_count) / total_count)
logger.info('item_freq, dimension is %d' % (len(item_freq_list)))
output_item_freq_path = os.path.join(args.output_dir, 'item_freq.pkl')
with open(output_item_freq_path, "w") as f:
cPickle.dump(item_freq_list, f, -1)
logger.info('Complete!')
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import gzip
import paddle.v2 as paddle
import argparse
import cPickle
from reader import Reader
from network_conf import DNNmodel
from utils import logger
def parse_args():
"""
parse arguments
:return:
"""
parser = argparse.ArgumentParser(
description="PaddlePaddle Youtube Recall Model Example")
parser.add_argument(
'--infer_set_path',
type=str,
required=True,
help="path of the infer set")
parser.add_argument(
'--model_path', type=str, required=True, help="path of the model")
parser.add_argument(
'--feature_dict',
type=str,
required=True,
help="path of feature_dict.pkl")
parser.add_argument(
'--batch_size',
type=int,
default=50,
help="size of mini-batch (default:50)")
return parser.parse_args()
def infer():
"""
infer
"""
args = parse_args()
# check argument
assert os.path.exists(
args.infer_set_path), 'The train_set_path path does not exist.'
assert os.path.exists(
args.model_path), 'The model_path path does not exist.'
assert os.path.exists(
args.feature_dict), 'The feature_dict path does not exist.'
paddle.init(use_gpu=False, trainer_count=1)
with open(args.feature_dict) as f:
feature_dict = cPickle.load(f)
nid_dict = feature_dict['history_clicked_items']
nid_to_word = dict((v, k) for k, v in nid_dict.items())
# load the trained model.
with gzip.open(args.model_path) as f:
parameters = paddle.parameters.Parameters.from_tar(f)
# build model
prediction_layer, fc = DNNmodel(
dnn_layer_dims=[256, 31], feature_dict=feature_dict,
is_infer=True).model_cost
inferer = paddle.inference.Inference(
output_layer=[prediction_layer, fc], parameters=parameters)
reader = Reader(feature_dict)
test_batch = []
for idx, item in enumerate(reader.infer(args.infer_set_path)):
test_batch.append(item)
if len(test_batch) == args.batch_size:
infer_a_batch(inferer, test_batch, nid_to_word)
test_batch = []
if len(test_batch):
infer_a_batch(inferer, test_batch, nid_to_word)
def infer_a_batch(inferer, test_batch, nid_to_word):
"""
input a batch of data and infer
"""
feeding = {
'user_id': 0,
'province': 1,
'city': 2,
'history_clicked_items': 3,
'history_clicked_categories': 4,
'history_clicked_tags': 5,
'phone': 6
}
probs = inferer.infer(
input=test_batch,
feeding=feeding,
field=["value"],
flatten_result=False)
for i, res in enumerate(zip(test_batch, probs[0], probs[1])):
print "Sample %s:" % str(i)
softmax_output = res[1]
sort_nid = res[1].argsort()
# 输出top 30的推荐视频id,title,分数
for j in range(1, 30):
item_id = sort_nid[-1 * j]
item_id_to_word = nid_to_word[item_id]
print "%s\t%.6f" \
% (item_id_to_word, softmax_output[item_id])
if __name__ == "__main__":
infer()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import paddle.v2 as paddle
import cPickle
class DNNmodel(object):
"""
Deep Neural Networks for YouTube candidate generation
"""
def __init__(self,
dnn_layer_dims=None,
feature_dict=None,
item_freq=None,
is_infer=False):
"""
initialize model
"""
self._dnn_layer_dims = dnn_layer_dims
self._feature_dict = feature_dict
self._item_freq = item_freq
self._is_infer = is_infer
# build model
self._build_input_layer()
self._build_embedding_layer()
self.model_cost = self._build_dnn_model()
def _build_input_layer(self):
"""
build input layer
"""
self._history_clicked_items = paddle.layer.data(
name="history_clicked_items",
type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_items'])))
self._history_clicked_categories = paddle.layer.data(
name="history_clicked_categories",
type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_categories'])))
self._history_clicked_tags = paddle.layer.data(
name="history_clicked_tags",
type=paddle.data_type.integer_value_sequence(
len(self._feature_dict['history_clicked_tags'])))
self._user_id = paddle.layer.data(
name="user_id",
type=paddle.data_type.integer_value(
len(self._feature_dict['user_id'])))
self._province = paddle.layer.data(
name="province",
type=paddle.data_type.integer_value(
len(self._feature_dict['province'])))
self._city = paddle.layer.data(
name="city",
type=paddle.data_type.integer_value(
len(self._feature_dict['city'])))
self._phone = paddle.layer.data(
name="phone",
type=paddle.data_type.integer_value(
len(self._feature_dict['phone'])))
self._target_item = paddle.layer.data(
name="target_item",
type=paddle.data_type.integer_value(
len(self._feature_dict['history_clicked_items'])))
def _create_emb_attr(self, name):
"""
create embedding parameter
"""
return paddle.attr.Param(
name=name,
initial_std=0.001,
learning_rate=1,
l2_rate=0,
sparse_update=True)
def _build_embedding_layer(self):
"""
build embedding layer
"""
self._user_id_emb = paddle.layer.embedding(
input=self._user_id,
size=64,
param_attr=self._create_emb_attr('_proj_user_id'))
self._province_emb = paddle.layer.embedding(
input=self._province,
size=8,
param_attr=self._create_emb_attr('_proj_province'))
self._city_emb = paddle.layer.embedding(
input=self._city,
size=16,
param_attr=self._create_emb_attr('_proj_city'))
self._phone_emb = paddle.layer.embedding(
input=self._phone,
size=16,
param_attr=self._create_emb_attr('_proj_phone'))
self._history_clicked_items_emb = paddle.layer.embedding(
input=self._history_clicked_items,
size=64,
param_attr=self._create_emb_attr('_proj_history_clicked_items'))
self._history_clicked_categories_emb = paddle.layer.embedding(
input=self._history_clicked_categories,
size=8,
param_attr=self._create_emb_attr(
'_proj_history_clicked_categories'))
self._history_clicked_tags_emb = paddle.layer.embedding(
input=self._history_clicked_tags,
size=64,
param_attr=self._create_emb_attr('_proj_history_clicked_tags'))
def _build_dnn_model(self):
"""
build dnn model
"""
self._rnn_cell = paddle.networks.simple_lstm(
input=self._history_clicked_items_emb, size=64)
self._lstm_last = paddle.layer.pooling(
input=self._rnn_cell, pooling_type=paddle.pooling.Max())
self._avg_emb_cats = paddle.layer.pooling(
input=self._history_clicked_categories_emb,
pooling_type=paddle.pooling.Avg())
self._avg_emb_tags = paddle.layer.pooling(
input=self._history_clicked_tags_emb,
pooling_type=paddle.pooling.Avg())
self._fc_0 = paddle.layer.fc(
name="Relu1",
input=[
self._lstm_last, self._user_id_emb, self._city_emb,
self._phone_emb
],
size=self._dnn_layer_dims[0],
act=paddle.activation.Relu())
self._fc_1 = paddle.layer.fc(
name="Relu2",
input=self._fc_0,
size=self._dnn_layer_dims[1],
act=paddle.activation.Relu())
if not self._is_infer:
return paddle.layer.nce(
input=self._fc_1,
label=self._target_item,
num_classes=len(self._feature_dict['history_clicked_items']),
param_attr=paddle.attr.Param(name="nce_w"),
bias_attr=paddle.attr.Param(name="nce_b"),
act=paddle.activation.Sigmoid(),
num_neg_samples=5,
neg_distribution=self._item_freq)
else:
self.prediction_layer = paddle.layer.mixed(
size=len(self._feature_dict['history_clicked_items']),
input=paddle.layer.trans_full_matrix_projection(
self._fc_1, param_attr=paddle.attr.Param(name="nce_w")),
act=paddle.activation.Softmax(),
bias_attr=paddle.attr.Param(name="nce_b"))
return self.prediction_layer, self._fc_1
if __name__ == "__main__":
# this is to test and debug the network topology defination.
# please set the hyper-parameters as needed.
item_freq_path = "./output/item_freq.pkl"
with open(item_freq_path) as f:
item_freq = cPickle.load(f)
feature_dict_path = "./output/feature_dict.pkl"
with open(feature_dict_path) as f:
feature_dict = cPickle.load(f)
a = DNNmodel(
dnn_layer_dims=[256, 31],
feature_dict=feature_dict,
item_freq=item_freq,
is_infer=False)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
from utils import logger
from utils import TaskMode
class Reader(object):
"""
Reader
"""
def __init__(self, feature_dict=None, window_size=20):
"""
init
@window_size: window_size
"""
self._feature_dict = feature_dict
self._window_size = window_size
def train(self, path):
"""
load train set
@path: train set path
"""
logger.info("start train reader from %s" % path)
mode = TaskMode.create_train()
return self._reader(path, mode)
def test(self, path):
"""
load test set
@path: test set path
"""
logger.info("start test reader from %s" % path)
mode = TaskMode.create_test()
return self._reader(path, mode)
def infer(self, path):
"""
load infer set
@path: infer set path
"""
logger.info("start infer reader from %s" % path)
mode = TaskMode.create_infer()
return self._reader(path, mode)
def _reader(self, path, mode):
"""
parse data set
"""
USER_ID_UNK = self._feature_dict['user_id'].get('<unk>')
PROVINCE_UNK = self._feature_dict['province'].get('<unk>')
CITY_UNK = self._feature_dict['city'].get('<unk>')
ITEM_UNK = self._feature_dict['history_clicked_items'].get('<unk>')
CATEGORY_UNK = self._feature_dict['history_clicked_categories'].get(
'<unk>')
TAG_UNK = self._feature_dict['history_clicked_tags'].get('<unk>')
PHONE_UNK = self._feature_dict['phone'].get('<unk>')
with open(path) as f:
for line in f:
fields = line.strip('\n').split('\t')
user_id = self._feature_dict['user_id'].get(fields[0],
USER_ID_UNK)
province = self._feature_dict['province'].get(fields[1],
PROVINCE_UNK)
city = self._feature_dict['city'].get(fields[2], CITY_UNK)
item_infos = fields[3]
phone = self._feature_dict['phone'].get(fields[4], PHONE_UNK)
history_clicked_items_all = []
history_clicked_tags_all = []
history_clicked_categories_all = []
for item_info in item_infos.split(';'):
item_info_array = item_info.split(':')
item = item_info_array[0]
item_encoded_id = self._feature_dict['history_clicked_items'].get(\
item, ITEM_UNK)
if item_encoded_id != ITEM_UNK:
history_clicked_items_all.append(item_encoded_id)
category = item_info_array[1]
history_clicked_categories_all.append(
self._feature_dict['history_clicked_categories'].get(\
category, CATEGORY_UNK))
tags = item_info_array[2]
tag_split = map(str, [self._feature_dict['history_clicked_tags'].get(\
tag, TAG_UNK) \
for tag in tags.strip().split("_")])
history_clicked_tags_all.append("_".join(tag_split))
if not mode.is_infer():
history_clicked_items_all.insert(0, 0)
history_clicked_tags_all.insert(0, "0")
history_clicked_categories_all.insert(0, 0)
for i in range(1, len(history_clicked_items_all)):
start = max(0, i - self._window_size)
history_clicked_items = history_clicked_items_all[start:
i]
history_clicked_categories = history_clicked_categories_all[
start:i]
history_clicked_tags_str = history_clicked_tags_all[
start:i]
history_clicked_tags = []
for tags_a in history_clicked_tags_str:
for tag in tags_a.split("_"):
history_clicked_tags.append(int(tag))
target_item = history_clicked_items_all[i]
yield user_id, province, city, \
history_clicked_items, history_clicked_categories, \
history_clicked_tags, phone, target_item
else:
history_clicked_items = history_clicked_items_all
history_clicked_categories = history_clicked_categories_all
history_clicked_tags_str = history_clicked_tags_all
history_clicked_tags = []
for tags_a in history_clicked_tags_str:
for tag in tags_a.split("_"):
history_clicked_tags.append(int(tag))
yield user_id, province, city, \
history_clicked_items, history_clicked_categories, \
history_clicked_tags, phone
if __name__ == "__main__":
# this is to test and debug reader function
train_data = sys.argv[1]
feature_dict = sys.argv[2]
window_size = int(sys.argv[3])
import cPickle
with open(feature_dict) as f:
feature_dict = cPickle.load(f)
r = Reader(feature_dict, window_size)
for dat in r.train(train_data):
print dat
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import gzip
import paddle.v2 as paddle
import argparse
import cPickle
from reader import Reader
from network_conf import DNNmodel
from utils import logger
def parse_args():
"""
parse arguments
"""
parser = argparse.ArgumentParser(
description="PaddlePaddle Youtube Recall Model Example")
parser.add_argument(
'--train_set_path',
type=str,
required=True,
help="path of the train set")
parser.add_argument(
'--test_set_path', type=str, required=True, help="path of the test set")
parser.add_argument(
'--model_output_dir',
type=str,
required=True,
help="directory to output")
parser.add_argument(
'--feature_dict',
type=str,
required=True,
help="path of feature_dict.pkl")
parser.add_argument(
'--item_freq', type=str, required=True, help="path of item_freq.pkl ")
parser.add_argument(
'--window_size', type=int, default=20, help="window size(default: 20)")
parser.add_argument(
'--num_passes', type=int, default=1, help="number of passes to train")
parser.add_argument(
'--batch_size',
type=int,
default=50,
help="size of mini-batch (default:50)")
return parser.parse_args()
def train():
"""
train
"""
args = parse_args()
# check argument
assert os.path.exists(
args.train_set_path), 'The train_set_path path does not exist.'
assert os.path.exists(
args.test_set_path), 'The test_set_path path does not exist.'
assert os.path.exists(
args.feature_dict), 'The feature_dict path does not exist.'
assert os.path.exists(args.item_freq), 'The item_freq path does not exist.'
assert os.path.exists(
args.model_output_dir), 'The model_output_dir path does not exist.'
paddle.init(use_gpu=False, trainer_count=1)
with open(args.feature_dict) as f:
feature_dict = cPickle.load(f)
with open(args.item_freq) as f:
item_freq = cPickle.load(f)
feeding = {
'user_id': 0,
'province': 1,
'city': 2,
'history_clicked_items': 3,
'history_clicked_categories': 4,
'history_clicked_tags': 5,
'phone': 6,
'target_item': 7
}
optimizer = paddle.optimizer.AdaGrad(
learning_rate=1e-1,
regularization=paddle.optimizer.L2Regularization(rate=1e-3))
cost = DNNmodel(
dnn_layer_dims=[256, 31],
feature_dict=feature_dict,
item_freq=item_freq,
is_infer=False).model_cost
parameters = paddle.parameters.create(cost)
trainer = paddle.trainer.SGD(cost, parameters, optimizer)
def event_handler(event):
"""
event handler
"""
if isinstance(event, paddle.event.EndIteration):
if event.batch_id and not event.batch_id % 10:
logger.info("Pass %d, Batch %d, Cost %f" %
(event.pass_id, event.batch_id, event.cost))
elif isinstance(event, paddle.event.EndPass):
save_path = os.path.join(args.model_output_dir,
"model_pass_%05d.tar.gz" % event.pass_id)
logger.info("Save model into %s ..." % save_path)
with gzip.open(save_path, "w") as f:
trainer.save_parameter_to_tar(f)
reader = Reader(feature_dict, args.window_size)
trainer.train(
paddle.batch(
paddle.reader.shuffle(
lambda: reader.train(args.train_set_path), buf_size=7000),
args.batch_size),
num_passes=args.num_passes,
feeding=feeding,
event_handler=event_handler)
if __name__ == "__main__":
train()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
class TaskMode(object):
"""
TaskMode
"""
TRAIN_MODE = 0
TEST_MODE = 1
INFER_MODE = 2
def __init__(self, mode):
"""
:param mode:
"""
self.mode = mode
def is_train(self):
"""
:return:
"""
return self.mode == self.TRAIN_MODE
def is_test(self):
"""
:return:
"""
return self.mode == self.TEST_MODE
def is_infer(self):
"""
:return:
"""
return self.mode == self.INFER_MODE
@staticmethod
def create_train():
"""
:return:
"""
return TaskMode(TaskMode.TRAIN_MODE)
@staticmethod
def create_test():
"""
:return:
"""
return TaskMode(TaskMode.TEST_MODE)
@staticmethod
def create_infer():
"""
:return:
"""
return TaskMode(TaskMode.INFER_MODE)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册