提交 65001741 编写于 作者: zhaoyijin666's avatar zhaoyijin666

vector

上级 2880e92f
...@@ -13,12 +13,14 @@ ...@@ -13,12 +13,14 @@
├── train.py # 训练脚本 ├── train.py # 训练脚本
└── utils.py # 工具 └── utils.py # 工具
└── data_processer.py # 数据预处理脚本 └── data_processer.py # 数据预处理脚本
└── user_vector.py # 获取用户向量脚本
└── item_vector.py # 获取视频向量脚本
``` ```
## 背景介绍 ## 背景介绍\[[1](#参考文献)\]
Youtube是世界最大的视频网站之一, 其推荐系统帮助10亿以上的用户,从海量视频中,发现个性化的内容。该推荐系统主要面临以下三个挑战: Youtube是世界最大的视频网站之一, 其推荐系统帮助10亿以上的用户,从海量视频中,发现个性化的内容。该推荐系统主要面临以下三个挑战:
- 规模: 许多现有的推荐算法证明在小数据量下运行良好,但不能满足YouTube这样庞大的用户群和内容库的场景,因此需要高度专业化的分布式学习算法和高效的线上服务。 - 规模: 许多现有的推荐算法证明在小数据量下运行良好,但不能满足YouTube这样庞大的用户群和内容库的场景,因此需要高度专业化的分布式学习算法和高效的线上服务。
- 新鲜度: YouTube内容库更新频率极高,每秒上传小时级别视频。 系统应及时追踪新上传的视频和用户的实时的行为,并且模型在推荐新/旧视频上有好的平衡能力。 - 新鲜度: YouTube内容库更新频率极高,每秒上传小时级别视频。系统应及时追踪新上传的视频和用户的实时的行为,并且模型在推荐新/旧视频上有良好平衡能力。
- 噪音: 噪音来自于两方面,其一,用户历史行为稀疏,且有各种不可观测的外部因素,以及用户满意度不明确。其二,内容本身的数据是非结构化的。因此算法应更具有鲁棒性。 - 噪音: 噪音来自于两方面,其一,用户历史行为稀疏,且有各种不可观测的外部因素,以及用户满意度不明确。其二,内容本身的数据是非结构化的。因此算法应更具有鲁棒性。
下图展示了整个推荐系统框图: 下图展示了整个推荐系统框图:
...@@ -27,14 +29,14 @@ Youtube是世界最大的视频网站之一, 其推荐系统帮助10亿以上的 ...@@ -27,14 +29,14 @@ Youtube是世界最大的视频网站之一, 其推荐系统帮助10亿以上的
Figure 1. 推荐系统框图 Figure 1. 推荐系统框图
</p> </p>
整个推荐系统有两部分组成: 召回(candidate generation)和排序(ranking)。 整个推荐系统有两部分组成: 召回(candidate generation/recall)和排序(ranking)。
- 召回模型: 输入用户的历史行为, 从大规模的内容库中获得一个小集合(百级别)。召回出的视频与用户高度相关。一个用户是用其历史点击过的视频,搜索过的关键词,和人口统计相关的特征来表征。 - 召回模型: 输入用户的历史行为, 从大规模的内容库中获得一个小集合(百级别)。召回出的视频与用户高度相关。一个用户是用其历史点击过的视频,搜索过的关键词,和人口统计相关的特征来表征。
- 排序模型: 采用更精细的特征计算得到排序分,对召回得到的候选集合中的视频排序。 - 排序模型: 采用更精细的特征计算得到排序分,对召回得到的候选集合中的视频进行排序。
## 召回模型简介 ## 召回模型简介
该推荐问题可以被建模成一个"超大规模多分类"问题。即在时刻$$t$$,为用户$$U$$(已知上下文信息$$C$$)在视频库$$V$$中预测出观看视频i的类别, 该推荐问题可以被建模成一个"超大规模多分类"问题。即在时刻$$t$$,为用户$$U$$(已知上下文信息$$C$$)在视频库$$V$$中预测出观看视频i的类别,
$$P(\omega_t=i|U,C)=\frac{e^{v_iu}}{\sum_{j\in V}^{ }e^{v_ju}}$$ $$P(\omega_t=i|U,C)=\frac{e^{v_iu}}{\sum_{j\in V}^{ }e^{v_ju}}$$
其中$$u\in \mathbb{R}^N$$,是<用户,上下文信息>的高维向量表示。$$v_j\in \mathbb{R}^N$$是视频`j`的高维向量表示。DNN模型的目标是以用户信息和上下文信息为输入条件下,学习用户的高维向量表示,以此输入softmax分类器,来预测视频库中各个视频(类别)的观看概率。 其中$$\mathbf{u}\in \mathbb{R}^N$$,是<用户,上下文信息>的高维向量表示。$$\mathbf{v_j}\in \mathbb{R}^N$$是视频`j`的高维向量表示。DNN模型的目标是以用户信息和上下文信息为输入条件下,学习用户的高维向量表示,以此输入softmax分类器,来预测视频库中各个视频(类别)的观看概率。
下图展示了召回模型的网络结构: 下图展示了召回模型的网络结构:
<p align="center"> <p align="center">
...@@ -66,7 +68,7 @@ sh download.sh ...@@ -66,7 +68,7 @@ sh download.sh
usage: data_processor.py [-h] --train_set_path TRAIN_SET_PATH --output_dir usage: data_processor.py [-h] --train_set_path TRAIN_SET_PATH --output_dir
OUTPUT_DIR [--feat_appear_limit FEAT_APPEAR_LIMIT] OUTPUT_DIR [--feat_appear_limit FEAT_APPEAR_LIMIT]
PaddlePaddle Deep Candidate Generation Example PaddlePaddle Youtube Recall Model Example
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
...@@ -82,6 +84,7 @@ optional arguments: ...@@ -82,6 +84,7 @@ optional arguments:
- 借鉴\[[2](#参考文献)\]中对特征的处理,过滤低频特征(样本中出现次数低于`feat_appear_limit`)。 - 借鉴\[[2](#参考文献)\]中对特征的处理,过滤低频特征(样本中出现次数低于`feat_appear_limit`)。
- 对特征进行编码,生成字典`feature_dict.pkl` - 对特征进行编码,生成字典`feature_dict.pkl`
- 统计每个视频出现的概率,保存至`item_freq.pkl`,提供给nce层使用。 - 统计每个视频出现的概率,保存至`item_freq.pkl`,提供给nce层使用。
例如可执行下列命令, 完成数据预处理: 例如可执行下列命令, 完成数据预处理:
```shell ```shell
python data_processor.py --train_set_path=./data/train.txt \ python data_processor.py --train_set_path=./data/train.txt \
...@@ -123,7 +126,7 @@ def _build_input_layer(self): ...@@ -123,7 +126,7 @@ def _build_input_layer(self):
``` ```
### Embedding层 ### Embedding层
每个输入特征都被embedding到固定维度的向量中。 每个输入特征通过embedding到固定维度的向量中。
```python ```python
def _create_emb_attr(self, name): def _create_emb_attr(self, name):
""" """
...@@ -167,25 +170,31 @@ def _build_embedding_layer(self): ...@@ -167,25 +170,31 @@ def _build_embedding_layer(self):
### 隐层 ### 隐层
我们对原paper中做了改进,历史用户点击视频序列,经过embedding后,不再是加权求平均。而是连接lstm层,将用户点击的先后次序纳入模型,再在时间序列上做最大池化,得到定长的向量表示,从而使模型学习到与点击时序相关的隐藏信息。考虑到数据规模与训练性能,我们只用了两个Relu层,也有不错的效果。 我们对原paper中做了改进,历史用户点击视频序列,经过embedding后,不再是加权求平均。而是连接lstm层,将用户点击的先后次序纳入模型,再在时间序列上做最大池化,得到定长的向量表示,从而使模型学习到与点击时序相关的隐藏信息。考虑到数据规模与训练性能,我们只用了两个Relu层,也有不错的效果。
```python ```python
self._rnn_cell = paddle.networks.simple_lstm(input=self._history_clicked_items_emb, size=64) self._rnn_cell = paddle.networks.simple_lstm(
self._lstm_last = paddle.layer.pooling( input=self._history_clicked_items_emb, size=64)
input=self._rnn_cell, pooling_type=paddle.pooling.Max()) self._lstm_last = paddle.layer.pooling(
self._avg_emb_cats = paddle.layer.pooling(input=self._history_clicked_categories_emb, input=self._rnn_cell, pooling_type=paddle.pooling.Max())
pooling_type=paddle.pooling.Avg()) self._avg_emb_cats = paddle.layer.pooling(
self._avg_emb_tags = paddle.layer.pooling(input=self._history_clicked_tags_emb, input=self._history_clicked_categories_emb,
pooling_type=paddle.pooling.Avg()) pooling_type=paddle.pooling.Avg())
self._fc_0 = paddle.layer.fc( self._avg_emb_tags = paddle.layer.pooling(
name="Relu1", input=self._history_clicked_tags_emb,
input=[self._lstm_last, self._user_id_emb, pooling_type=paddle.pooling.Avg())
self._city_emb, self._phone_emb], self._fc_0 = paddle.layer.fc(
size=self._dnn_layer_dims[0], name="Relu1",
act=paddle.activation.Relu()) input=[
self._lstm_last, self._user_id_emb, self._province_emb,
self._city_emb, self._avg_emb_cats, self._avg_emb_tags,
self._phone_emb
],
size=self._dnn_layer_dims[0],
act=paddle.activation.Relu())
self._fc_1 = paddle.layer.fc( self._fc_1 = paddle.layer.fc(
name="Relu2", name="Relu2",
input=self._fc_0, input=self._fc_0,
size=self._dnn_layer_dims[1], size=self._dnn_layer_dims[1],
act=paddle.activation.Relu()) act=paddle.activation.Relu())
``` ```
### 输出层 ### 输出层
...@@ -250,7 +259,7 @@ python train.py --train_set_path='./data/train.txt' \ ...@@ -250,7 +259,7 @@ python train.py --train_set_path='./data/train.txt' \
--item_freq='./output/item_freq.pkl' --item_freq='./output/item_freq.pkl'
``` ```
## 预测 ## 离线预测
输入用户相关的特征,输出topN个最可能观看的视频,可执行以下命令: 输入用户相关的特征,输出topN个最可能观看的视频,可执行以下命令:
```shell ```shell
python infer.py --infer_set_path='./data/infer.txt' \ python infer.py --infer_set_path='./data/infer.txt' \
...@@ -259,7 +268,18 @@ python infer.py --infer_set_path='./data/infer.txt' \ ...@@ -259,7 +268,18 @@ python infer.py --infer_set_path='./data/infer.txt' \
--batch_size=50 --batch_size=50
``` ```
## 在线预测
在线预测的时候,我们采用近似最近邻(approximate nearest neighbor-ANN)算法直接用用户向量查询最相关的topN个视频内容。由于我们的ANN暂时只支持cosine,而模型是根据内积排序的,两者效果差异太大。
为此,我们的解决方案是,对用户和视频向量,作SIMPLE-LSH变换\[[4](#参考文献)\],使内积排序与cosin排序等价。具体如下:
对于视频向量$$\mathbf{v}\in \mathbb{R}^N$$,有$$\left \| \mathbf{v} \right \|\leqslant m$$,变换后的$$\tilde{\mathbf{v}}\in \mathbb{R}^{N+1}$$,
$$\tilde{\mathbf{v}} = [\frac{\mathbf{v}}{m}; \sqrt{1 -\left \| \mathbf{\frac{\mathbf{v}}{m}{}} \right \|^2}]$$
对于用户向量$$\mathbf{u}\in \mathbb{R}^N$$,变换后的$$\tilde{\mathbf{u}}\in \mathbb{R}^{N+1}$$,
$$\tilde{\mathbf{u}} = [\mathbf{u}_{norm}; 0]$$,其中$$\mathbf{u}_{norm}$$是模长归一化后的$$\mathbf{u}$$,
线上对于一个$$\mathbf{u}$$用内积召回$$\mathbf{v},作上述变换$$\mathbf{u}\rightarrow \tilde{\mathbf{u}}, \mathbf{v}\rightarrow \tilde{\mathbf{v}}$$后,不改变内积排序的顺序。又因为$$\left \| \tilde{\mathbf{u}} \right \|$$和$$\left \| \tilde{\mathbf{v}} \right \|$$都为1,因此$$cos(\tilde{\mathbf{u}} ,\tilde{\mathbf{v}}) = \tilde{\mathbf{u}}\cdot \tilde{\mathbf{v}}$$,就可以兼容ANN用cosin的方式召回了,结果等价。线上使用时,为保留精度,可以不除以$$$m$$,也就变成$\tilde{\mathbf{v}} = [\mathbf{v}; \sqrt{m^2 -\left \| \mathbf{\mathbf{v}} \right \|^2}]$$,排序依然等价。
## 参考文献 ## 参考文献
1. Covington, Paul, Jay Adams, and Emre Sargin. "Deep neural networks for youtube recommendations." Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016. 1. Covington, Paul, Jay Adams, and Emre Sargin. "Deep neural networks for youtube recommendations." Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016.
2. https://code.google.com/archive/p/word2vec/ 2. https://code.google.com/archive/p/word2vec/
3. http://paddlepaddle.org/docs/develop/models/nce_cost/README.html 3. http://paddlepaddle.org/docs/develop/models/nce_cost/README.html
4. Neyshabur, Behnam, and Nathan Srebro. "On symmetric and asymmetric LSHs for inner product search." arXiv preprint arXiv:1410.5518 (2014).
...@@ -244,7 +244,7 @@ python train.py --train_set_path='./data/train.txt' \ ...@@ -244,7 +244,7 @@ python train.py --train_set_path='./data/train.txt' \
--item_freq='./output/item_freq.pkl' --item_freq='./output/item_freq.pkl'
``` ```
## Use the model for prediction ## Offline prediction
Input user related features, and then get the most likely N videos for user. Input user related features, and then get the most likely N videos for user.
```shell ```shell
python infer.py --infer_set_path='./data/infer.txt' \ python infer.py --infer_set_path='./data/infer.txt' \
...@@ -253,7 +253,18 @@ python infer.py --infer_set_path='./data/infer.txt' \ ...@@ -253,7 +253,18 @@ python infer.py --infer_set_path='./data/infer.txt' \
--batch_size=50 --batch_size=50
``` ```
## Online prediction
For online prediction,we adopt Approximate Nearest Neighbor(ANN) to directly recall top N mostly likely watch video. However, our ANN system currently only supports cosin sorting, not by inner product sorting, which leads to big effect difference.
As a result, we sliently modify user and video vectors by a SIMPLE-LSH conversion\[[4](#References)\], so that inner sorting is equivalent to cosin sorting after conversion.
Details as follows:
For video vector, $$\mathbf{v}\in \mathbb{R}^N$$, we have $$\left \| \mathbf{v} \right \|\leqslant m$$. The modified video vector $$\tilde{\mathbf{v}}\in \mathbb{R}^{N+1}$$,
$$\tilde{\mathbf{v}} = [\frac{\mathbf{v}}{m}; \sqrt{1 -\left \| \mathbf{\frac{\mathbf{v}}{m}{}} \right \|^2}]$$
For user vector, $$\mathbf{u}\in \mathbb{R}^N$$, The modified user vector $$\tilde{\mathbf{u}}\in \mathbb{R}^{N+1}$$,
$$\tilde{\mathbf{u}} = [\mathbf{u}_{norm}; 0]$$,in which $$\mathbf{u}_{norm}$$ is normalized $$\mathbf{u}$$,
When online predicting, For a $$\mathbf{u}$$, we need recall $$\mathbf{v} by inner product sorting. After conversion, $$\mathbf{u}\rightarrow \tilde{\mathbf{u}}, \mathbf{v}\rightarrow \tilde{\mathbf{v}}$$, the order of inner prodct sorting is unchanged. Since $$\left \| \tilde{\mathbf{u}} \right \|$$ and $$\left \| \tilde{\mathbf{v}} \right \|$$ are both equal to 1, $$cos(\tilde{\mathbf{u}} ,\tilde{\mathbf{v}}) = \tilde{\mathbf{u}}\cdot \tilde{\mathbf{v}}$$, which makes cosin-supported-only ANN system works. And in order to retain precision, we find that $\tilde{\mathbf{v}} = [\mathbf{v}; \sqrt{m^2 -\left \| \mathbf{\mathbf{v}} \right \|^2}]$$ is also equivalent.
## References ## References
1. Covington, Paul, Jay Adams, and Emre Sargin. "Deep neural networks for youtube recommendations." Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016. 1. Covington, Paul, Jay Adams, and Emre Sargin. "Deep neural networks for youtube recommendations." Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016.
2. https://code.google.com/archive/p/word2vec/ 2. https://code.google.com/archive/p/word2vec/
3. http://paddlepaddle.org/docs/develop/models/nce_cost/README.html 3. http://paddlepaddle.org/docs/develop/models/nce_cost/README.html
4. Neyshabur, Behnam, and Nathan Srebro. "On symmetric and asymmetric LSHs for inner product search." arXiv preprint arXiv:1410.5518 (2014).
...@@ -127,7 +127,8 @@ class DNNmodel(object): ...@@ -127,7 +127,8 @@ class DNNmodel(object):
self._fc_0 = paddle.layer.fc( self._fc_0 = paddle.layer.fc(
name="Relu1", name="Relu1",
input=[ input=[
self._lstm_last, self._user_id_emb, self._city_emb, self._lstm_last, self._user_id_emb, self._province_emb,
self._city_emb, self._avg_emb_cats, self._avg_emb_tags,
self._phone_emb self._phone_emb
], ],
size=self._dnn_layer_dims[0], size=self._dnn_layer_dims[0],
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import gzip
import paddle.v2 as paddle
import argparse
import cPickle
from reader import Reader
from network_conf import DNNmodel
from utils import logger
def parse_args():
"""
parse arguments
:return:
"""
parser = argparse.ArgumentParser(
description="PaddlePaddle Youtube Recall Model Example")
parser.add_argument(
'--infer_set_path',
type=str,
required=True,
help="path of the infer set")
parser.add_argument(
'--model_path', type=str, required=True, help="path of the model")
parser.add_argument(
'--feature_dict',
type=str,
required=True,
help="path of feature_dict.pkl")
parser.add_argument(
'--batch_size',
type=int,
default=50,
help="size of mini-batch (default:50)")
return parser.parse_args()
def vector():
"""
print user vector and item vector
"""
args = parse_args()
# check argument
assert os.path.exists(
args.infer_set_path), 'The infer_set_path path does not exist.'
assert os.path.exists(
args.model_path), 'The model_path path does not exist.'
assert os.path.exists(
args.feature_dict), 'The feature_dict path does not exist.'
paddle.init(use_gpu=False, trainer_count=1)
with open(args.feature_dict) as f:
feature_dict = cPickle.load(f)
nid_dict = feature_dict['history_clicked_items']
nid_to_word = dict((v, k) for k, v in nid_dict.items())
# load the trained model.
with gzip.open(args.model_path) as f:
parameters = paddle.parameters.Parameters.from_tar(f)
# build model
prediction_layer, fc = DNNmodel(
dnn_layer_dims=[256, 31], feature_dict=feature_dict,
is_infer=True).model_cost
inferer = paddle.inference.Inference(
output_layer=[prediction_layer, fc], parameters=parameters)
reader = Reader(feature_dict)
test_batch = []
for idx, item in enumerate(reader.infer(args.infer_set_path)):
test_batch.append(item)
if len(test_batch) == args.batch_size:
infer_a_batch(inferer, test_batch, nid_to_word)
test_batch = []
if len(test_batch):
infer_a_batch(inferer, test_batch, nid_to_word)
def infer_a_batch(inferer, test_batch, nid_to_word):
"""
input a batch of data and infer
"""
feeding = {
'user_id': 0,
'province': 1,
'city': 2,
'history_clicked_items': 3,
'history_clicked_categories': 4,
'history_clicked_tags': 5,
'phone': 6
}
probs = inferer.infer(
input=test_batch,
feeding=feeding,
field=["value"],
flatten_result=False)
for i, res in enumerate(zip(test_batch, probs[0], probs[1])):
print "Sample %s:" % str(i)
user_vector = [1.000]
for i in res[2]:
user_vector.append(i)
user_vector.append(0.000)
norm = np.linalg.norm(user_vector)
user_vector_norm = [_ / norm for _ in user_vector]
print user_vector_norm
if __name__ == "__main__":
vector()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册