Commit b26f5e4e authored by zhaoyijin666

Add calculation methods for item_vector and user_vector

Parent 0175386a
......@@ -280,9 +280,24 @@ python infer.py --infer_set_path='./data/infer.txt' \
```
## Online prediction
For online prediction, an approximate nearest neighbor (ANN) algorithm is used to query the top-N most relevant videos directly with the user vector. Many ANN algorithms only support cosine distance, while the model ranks by inner product, and the two can differ considerably in effectiveness.
For online prediction, an approximate nearest neighbor (ANN) algorithm is used to query the top-N most relevant video vectors directly with the user vector, and the corresponding videos are recommended to the user. The following describes how to obtain the user vector and the video vectors.
To address this, the solution here is to apply a SIMPLE-LSH conversion \[[4](#参考文献)\] to the user and video vectors, so that ranking by inner product is equivalent to ranking by cosine similarity.
### User Vector
The user vector is the output of the last ReLU layer with a constant 1 prepended. Here the last ReLU layer has 31 dimensions, so the concatenated user vector has 32 dimensions:
![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bu%7D%3D%5B1%2Cu_1%2Cu_2%2C...%2Cu_%7B31%7D%5D)
### Video Vector
The video vectors are extracted from the parameters of the softmax layer learned during training. Suppose there are M distinct videos; the softmax layer then outputs the click probability of each of these M videos:
![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bo%7D%3D%5Bs_1%2Cs_2%2C...%2Cs_%7BM%7D%5D)
From the user vector ![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bu%7D) output by the last ReLU layer to the probabilities ![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bo%7D) of the M videos output by the softmax layer, the mapping is a multiplication by a ![](https://www.zhihu.com/equation?tex=32%5Ctimes%20M) matrix formed from the softmax layer's parameters w and b. Each column of this matrix is a 32-dimensional video vector, corresponding one-to-one to the videos in dictionary order.
![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bu%7D%5Ccdot%20%5Cbegin%7Bbmatrix%7D%0A%20b_1%20%20%26%20b_2%20%26%20%20%5Ccdots%20%26%20b_M%20%5C%5C%20%0A%20w_%7B11%7D%20%26%20w_%7B21%7D%20%26%20%20%5Ccdots%20%20%26%20w_%7BM1%7D%20%5C%5C%20%0A%20w_%7B12%7D%20%26%20w_%7B22%7D%20%26%20%20%20%5Ccdots%20%26%20w_%7BM2%7D%20%20%5C%5C%20%0A%5Cvdots%20%26%20%5Cvdots%20%26%20%20%5Cvdots%20%26%20%5Cvdots%20%5C%5C%20%0Aw_%7B131%7D%20%26%20%20w_%7B231%7D%20%26%20%20%5Ccdots%20%20%26%20w_%7BM31%7D%20%20%0A%5Cend%7Bbmatrix%7D_%7B32%5Ctimes%20M%7D%20%3D%20%5Cmathbf%7Bu%7D%20%5Ccdot%20%20%5Cbegin%7Bbmatrix%7D%20%0A%5Cmathbf%7Bv_1%7D%2C%20%5Cmathbf%7Bv_2%7D%2C%20%5Ccdots%2C%20%5Cmathbf%7Bv_M%7D%20%0A%5Cend%7Bbmatrix%7D_%7B1%5Ctimes%20M%7D%3D%5Cmathbf%7Bo%7D)
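To make the correspondence concrete, here is a minimal NumPy sketch with randomly generated stand-ins for the trained softmax parameters (the names `w`, `b`, `relu_out` and the sizes are illustrative, not taken from the model code): stacking the bias row on top of the transposed weight matrix gives a matrix whose i-th column is the 32-dimensional video vector, and multiplying the user vector by it yields the M scores whose ranking matches the click probabilities.

```python
import numpy as np

# Hypothetical shapes: M candidate videos, a 31-dimensional last ReLU layer.
M, D = 1000, 31
w = np.random.rand(M, D)      # stand-in for the trained softmax weights (nce_w)
b = np.random.rand(1, M)      # stand-in for the trained softmax bias (nce_b)
relu_out = np.random.rand(D)  # stand-in for the last ReLU layer output

# User vector: prepend a constant 1 to the ReLU output (32 dimensions).
u = np.concatenate(([1.0], relu_out))

# Video-vector matrix: the bias row stacked on top of the transposed weights
# (32 x M). Column i is the video vector v_i = [b_i, w_i1, ..., w_i31].
V = np.vstack([b, w.T])

# u @ V yields the M pre-softmax scores; their ranking matches the ranking
# of the click probabilities produced by the softmax layer.
scores = u @ V
print(scores.shape)  # (1000,)
```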
### SIMPLE-LSH conversion
Many ANN algorithms only support cosine distance, while the model ranks by inner product, and the two can differ considerably in effectiveness. To address this, the solution here is to apply a SIMPLE-LSH conversion \[[4](#参考文献)\] to the user and video vectors obtained above, so that ranking by inner product is equivalent to ranking by cosine similarity.
Specifically:
- For a video vector ![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bv%7D%5Cin%20%5Cmathbb%7BR%7D%5EN) with ![](https://www.zhihu.com/equation?tex=%5Cleft%20%5C%7C%20%5Cmathbf%7Bv%7D%20%5Cright%20%5C%7C%5Cleqslant%20m), the converted vector is ![](https://www.zhihu.com/equation?tex=%5Ctilde%7B%5Cmathbf%7Bv%7D%7D%5Cin%20%5Cmathbb%7BR%7D%5E%7BN%2B1%7D), where ![](https://www.zhihu.com/equation?tex=%5Ctilde%7B%5Cmathbf%7Bv%7D%7D%20%3D%20%5B%5Cfrac%7B%5Cmathbf%7Bv%7D%7D%7Bm%7D%3B%20%5Csqrt%7B1%20-%5Cleft%20%5C%7C%20%5Cmathbf%7B%5Cfrac%7B%5Cmathbf%7Bv%7D%7D%7Bm%7D%7B%7D%7D%20%5Cright%20%5C%7C%5E2%7D%5D)
......@@ -293,7 +308,72 @@ python infer.py --infer_set_path='./data/infer.txt' \
When serving online, to retain precision, you can skip dividing by ![](https://www.zhihu.com/equation?tex=m) and instead use ![](https://www.zhihu.com/equation?tex=%5Ctilde%7B%5Cmathbf%7Bv%7D%7D%3D%5B%5Cmathbf%7Bv%7D%3B%5Csqrt%7Bm%5E2-%5Cleft%5C%7C%20%5Cmathbf%7B%5Cmathbf%7Bv%7D%7D%5Cright%5C%7C%5E2%7D%5D); the ranking remains equivalent.
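The following minimal NumPy sketch (with randomly generated vectors as stand-ins) illustrates why the conversion preserves the ranking: after the transformation every item vector has the same norm m, so cosine similarity on the converted vectors is a monotone function of the original inner product.

```python
import numpy as np

np.random.seed(0)
items = np.random.rand(5, 32)  # hypothetical item (video) vectors
u = np.random.rand(32)         # hypothetical user vector

# SIMPLE-LSH: append sqrt(m^2 - ||v||^2) to every item vector and a 0 to the
# user vector, where m is the largest item-vector norm.
norms = np.linalg.norm(items, axis=1, keepdims=True)
m = norms.max()
items_lsh = np.hstack([items, np.sqrt(m ** 2 - norms ** 2)])
u_lsh = np.append(u, 0.0)

# Ranking by inner product on the original vectors ...
rank_ip = np.argsort(-(items @ u))
# ... equals ranking by cosine similarity on the converted vectors, because
# every converted item vector now has the same norm m.
cos = (items_lsh @ u_lsh) / (np.linalg.norm(items_lsh, axis=1) * np.linalg.norm(u_lsh))
rank_cos = np.argsort(-cos)
assert (rank_ip == rank_cos).all()
```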
Use `user_vector.py` and `vector.py` to obtain the user and video vectors respectively. For example, run the following commands:
### Implementation
Use `user_vector.py` to obtain the user vectors. The user features are fed into the network for inference; probs[1] stores the output of the last ReLU layer. A constant 1 is prepended, and then the SIMPLE-LSH conversion is applied (a 0 is appended, followed by normalization):
```python
# Excerpt from user_vector.py: `inferer`, `test_batch` and `feeding` are
# prepared earlier in the script, and numpy is imported as np.
probs = inferer.infer(
    input=test_batch,
    feeding=feeding,
    field=["value"],
    flatten_result=False)
for res in zip(probs[1]):
    # res[0] is the last ReLU layer's output for one sample.
    # SIMPLE-LSH conversion: prepend a constant 1, append a 0, then normalize.
    user_vector = [1.000]
    for value in res[0]:
        user_vector.append(value)
    user_vector.append(0.000)
    norm = np.linalg.norm(user_vector)
    user_vector_norm = [str(value / norm) for value in user_vector]
    print(",".join(user_vector_norm))
```
Use `item_vector.py` to obtain the video vectors. Load the model and extract the parameters nce_w and nce_b, then assemble the M video vectors: the first dimension of the i-th video vector is the corresponding nce_b[0][i], followed by the 31 values of nce_w[i]. Then apply the SIMPLE-LSH conversion: find the largest norm among all vectors and process each vector according to ![](https://www.zhihu.com/equation?tex=%5Ctilde%7B%5Cmathbf%7Bv%7D%7D%3D%5B%5Cmathbf%7Bv%7D%3B%5Csqrt%7Bm%5E2-%5Cleft%5C%7C%20%5Cmathbf%7B%5Cmathbf%7Bv%7D%7D%5Cright%5C%7C%5E2%7D%5D).
```python
# Excerpt from item_vector.py: args.model_path is parsed from the command
# line earlier in the script.
import gzip
import math

import numpy as np
import paddle.v2 as paddle


def get_item_vec_from_softmax(nce_w, nce_b):
    """
    Get the item vectors from the softmax (NCE) parameters: the i-th vector
    is nce_b[0][i] followed by the row nce_w[i].
    """
    if nce_w is None or nce_b is None:
        return None
    vector = []
    total_items_num = nce_w.shape[0]
    if total_items_num != nce_b.shape[1]:
        return None
    dim_vector = nce_w.shape[1] + 1
    for i in range(0, total_items_num):
        vector.append([])
        vector[i].append(nce_b[0][i])
        for j in range(1, dim_vector):
            vector[i].append(nce_w[i][j - 1])
    return vector


def convt_simple_lsh(vector):
    """
    Do the SIMPLE-LSH conversion: find the largest norm m among all vectors
    and append sqrt(m^2 - ||v||^2) to every vector v.
    """
    max_norm = 0
    num_of_vec = len(vector)
    for i in range(0, num_of_vec):
        norm = np.linalg.norm(vector[i])
        if norm > max_norm:
            max_norm = norm
    for i in range(0, num_of_vec):
        vector[i].append(
            math.sqrt(
                math.pow(max_norm, 2) - math.pow(np.linalg.norm(vector[i]), 2)))
    return vector


# Load the trained model and extract the softmax parameters.
with gzip.open(args.model_path) as f:
    parameters = paddle.parameters.Parameters.from_tar(f)
nce_w = parameters.get("nce_w")
nce_b = parameters.get("nce_b")
item_vector = convt_simple_lsh(get_item_vec_from_softmax(nce_w, nce_b))
```
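As a rough sketch of the online recall step, the converted user and item vectors can be combined as follows to pick the top-N items by cosine similarity. This is a brute-force stand-in for a real ANN service, and the helper name `recall_top_n` and the file names in the usage comment are hypothetical, not part of the released scripts.

```python
import numpy as np

def recall_top_n(user_vec, item_vecs, n=10):
    """Return the indices of the top-n items by cosine similarity.

    Brute-force stand-in for an ANN service; both arguments are assumed to
    be the SIMPLE-LSH converted vectors produced above.
    """
    user_vec = np.asarray(user_vec, dtype=float)
    item_vecs = np.asarray(item_vecs, dtype=float)
    scores = item_vecs @ user_vec
    scores = scores / (np.linalg.norm(item_vecs, axis=1) * np.linalg.norm(user_vec))
    return np.argsort(-scores)[:n]

# Hypothetical usage, assuming the outputs of user_vector.py and
# item_vector.py were saved as comma-separated rows in plain text files:
# users = np.loadtxt("user_vector.txt", delimiter=",")
# items = np.loadtxt("item_vector.txt", delimiter=",")
# print(recall_top_n(users[0], items, n=10))
```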
You can run the scripts with the following commands:
```shell
python user_vector.py --infer_set_path='./data/infer.txt' \
--model_path='./output/model/model_pass_00000.tar.gz' \
......
......@@ -261,7 +261,25 @@ python infer.py --infer_set_path='./data/infer.txt' \
```
## Online prediction
For online prediction, approximate nearest neighbor (ANN) search is adopted to directly recall the top-N videos most likely to be watched. However, many ANN systems only support ranking by cosine similarity, not by inner product, which leads to a large difference in effectiveness.
For online prediction, approximate nearest neighbor (ANN) search is adopted to directly recall the top-N video vectors that best match the user vector, and the corresponding videos are recommended to the user. The following shows how to obtain the user vector and the video vectors.
### User Vector
The user vector is the output of the last ReLU layer with a constant 1 prepended. Here the last ReLU layer has 31 dimensions, so the user vector has 32 dimensions:
![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bu%7D%3D%5B1%2Cu_1%2Cu_2%2C...%2Cu_%7B31%7D%5D)
### Video Vector
The video vectors are extracted from the parameters of the softmax layer. If there are M distinct videos, the softmax layer outputs the click probability of each of these M videos:
![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bo%7D%3D%5Bs_1%2Cs_2%2C...%2Cs_%7BM%7D%5D)
To get ![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bo%7D) from the user vector ![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bu%7D), the user vector is multiplied by a ![](https://www.zhihu.com/equation?tex=32%5Ctimes%20M) matrix formed from the parameters w and b of the softmax layer. Each column of this matrix is a 32-dimensional video vector, corresponding one-to-one to the videos in dictionary order.
![](https://www.zhihu.com/equation?tex=%5Cmathbf%7Bu%7D%5Ccdot%20%5Cbegin%7Bbmatrix%7D%0A%20b_1%20%20%26%20b_2%20%26%20%20%5Ccdots%20%26%20b_M%20%5C%5C%20%0A%20w_%7B11%7D%20%26%20w_%7B21%7D%20%26%20%20%5Ccdots%20%20%26%20w_%7BM1%7D%20%5C%5C%20%0A%20w_%7B12%7D%20%26%20w_%7B22%7D%20%26%20%20%20%5Ccdots%20%26%20w_%7BM2%7D%20%20%5C%5C%20%0A%5Cvdots%20%26%20%5Cvdots%20%26%20%20%5Cvdots%20%26%20%5Cvdots%20%5C%5C%20%0Aw_%7B131%7D%20%26%20%20w_%7B231%7D%20%26%20%20%5Ccdots%20%20%26%20w_%7BM31%7D%20%20%0A%5Cend%7Bbmatrix%7D_%7B32%5Ctimes%20M%7D%20%3D%20%5Cmathbf%7Bu%7D%20%5Ccdot%20%20%5Cbegin%7Bbmatrix%7D%20%0A%5Cmathbf%7Bv_1%7D%2C%20%5Cmathbf%7Bv_2%7D%2C%20%5Ccdots%2C%20%5Cmathbf%7Bv_M%7D%20%0A%5Cend%7Bbmatrix%7D_%7B1%5Ctimes%20M%7D%3D%5Cmathbf%7Bo%7D)
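As a concrete illustration, the following minimal NumPy sketch uses random stand-ins for the trained parameters (the names `w`, `b`, `relu_out` and the sizes are illustrative, not taken from the model code): the bias row is stacked on top of the transposed weight matrix, so each column is a 32-dimensional video vector and the product with the user vector gives the M ranking scores.

```python
import numpy as np

# Hypothetical shapes: M candidate videos, a 31-dimensional last ReLU layer.
M, D = 1000, 31
w = np.random.rand(M, D)      # stand-in for the trained softmax weights (nce_w)
b = np.random.rand(1, M)      # stand-in for the trained softmax bias (nce_b)
relu_out = np.random.rand(D)  # stand-in for the last ReLU layer output

# User vector: prepend a constant 1 to the ReLU output (32 dimensions).
u = np.concatenate(([1.0], relu_out))

# Video-vector matrix (32 x M): column i is v_i = [b_i, w_i1, ..., w_i31].
V = np.vstack([b, w.T])

# u @ V yields the M pre-softmax scores; their ranking matches the ranking
# of the click probabilities produced by the softmax layer.
scores = u @ V
print(scores.shape)  # (1000,)
```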
### SIMPLE-LSH conversion
However, most ANN systems only support ranking by cosine similarity, not by inner product, which leads to a large difference in effectiveness.
To solve this, the user and video vectors are transformed by a SIMPLE-LSH conversion \[[4](#References)\], so that ranking by inner product is equivalent to ranking by cosine similarity after the conversion.
......@@ -274,7 +292,73 @@ When online predicting, for a coming ![](https://www.zhihu.com/equation?tex=%5Cm
To retain precision when serving online, you can skip dividing by m and use ![](https://www.zhihu.com/equation?tex=%5Ctilde%7B%5Cmathbf%7Bv%7D%7D%3D%5B%5Cmathbf%7Bv%7D%3B%5Csqrt%7Bm%5E2-%5Cleft%5C%7C%20%5Cmathbf%7B%5Cmathbf%7Bv%7D%7D%5Cright%5C%7C%5E2%7D%5D) instead; the ranking remains equivalent.
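The minimal NumPy sketch below (with randomly generated vectors as stand-ins) illustrates why the conversion preserves the ranking: after the transformation every item vector has the same norm m, so cosine similarity on the converted vectors is a monotone function of the original inner product.

```python
import numpy as np

np.random.seed(0)
items = np.random.rand(5, 32)  # hypothetical item (video) vectors
u = np.random.rand(32)         # hypothetical user vector

# SIMPLE-LSH: append sqrt(m^2 - ||v||^2) to every item vector and a 0 to the
# user vector, where m is the largest item-vector norm.
norms = np.linalg.norm(items, axis=1, keepdims=True)
m = norms.max()
items_lsh = np.hstack([items, np.sqrt(m ** 2 - norms ** 2)])
u_lsh = np.append(u, 0.0)

# Ranking by inner product on the original vectors ...
rank_ip = np.argsort(-(items @ u))
# ... equals ranking by cosine similarity on the converted vectors, because
# every converted item vector now has the same norm m.
cos = (items_lsh @ u_lsh) / (np.linalg.norm(items_lsh, axis=1) * np.linalg.norm(u_lsh))
rank_cos = np.argsort(-cos)
assert (rank_ip == rank_cos).all()
```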
Use `user_vector.py` and `vector.py` to calculate user and item vectors. For example, run the following commands:
### Implementation
Run `user_vector.py` to generate the user vectors. The user features are fed into the network for inference; the output of the last ReLU layer is stored in probs[1]. The user vector is obtained by prepending a constant 1 and applying the SIMPLE-LSH conversion (appending a 0 and normalizing):
```python
# Excerpt from user_vector.py: `inferer`, `test_batch` and `feeding` are
# prepared earlier in the script, and numpy is imported as np.
probs = inferer.infer(
    input=test_batch,
    feeding=feeding,
    field=["value"],
    flatten_result=False)
for res in zip(probs[1]):
    # res[0] is the last ReLU layer's output for one sample.
    # SIMPLE-LSH conversion: prepend a constant 1, append a 0, then normalize.
    user_vector = [1.000]
    for value in res[0]:
        user_vector.append(value)
    user_vector.append(0.000)
    norm = np.linalg.norm(user_vector)
    user_vector_norm = [str(value / norm) for value in user_vector]
    print(",".join(user_vector_norm))
```
Run `item_vector.py` to generate the video vectors. First load the model and extract the parameters nce_w and nce_b. The i-th video vector is then assembled by putting nce_b[0][i] in the first dimension, followed by the 31 values of nce_w[i]. Finally apply the SIMPLE-LSH conversion: find the maximum norm among all vectors and process each vector according to ![](https://www.zhihu.com/equation?tex=%5Ctilde%7B%5Cmathbf%7Bv%7D%7D%3D%5B%5Cmathbf%7Bv%7D%3B%5Csqrt%7Bm%5E2-%5Cleft%5C%7C%20%5Cmathbf%7B%5Cmathbf%7Bv%7D%7D%5Cright%5C%7C%5E2%7D%5D).
```python
# Excerpt from item_vector.py: args.model_path is parsed from the command
# line earlier in the script.
import gzip
import math

import numpy as np
import paddle.v2 as paddle


def get_item_vec_from_softmax(nce_w, nce_b):
    """
    Get the item vectors from the softmax (NCE) parameters: the i-th vector
    is nce_b[0][i] followed by the row nce_w[i].
    """
    if nce_w is None or nce_b is None:
        return None
    vector = []
    total_items_num = nce_w.shape[0]
    if total_items_num != nce_b.shape[1]:
        return None
    dim_vector = nce_w.shape[1] + 1
    for i in range(0, total_items_num):
        vector.append([])
        vector[i].append(nce_b[0][i])
        for j in range(1, dim_vector):
            vector[i].append(nce_w[i][j - 1])
    return vector


def convt_simple_lsh(vector):
    """
    Do the SIMPLE-LSH conversion: find the largest norm m among all vectors
    and append sqrt(m^2 - ||v||^2) to every vector v.
    """
    max_norm = 0
    num_of_vec = len(vector)
    for i in range(0, num_of_vec):
        norm = np.linalg.norm(vector[i])
        if norm > max_norm:
            max_norm = norm
    for i in range(0, num_of_vec):
        vector[i].append(
            math.sqrt(
                math.pow(max_norm, 2) - math.pow(np.linalg.norm(vector[i]), 2)))
    return vector


# Load the trained model and extract the softmax parameters.
with gzip.open(args.model_path) as f:
    parameters = paddle.parameters.Parameters.from_tar(f)
nce_w = parameters.get("nce_w")
nce_b = parameters.get("nce_b")
item_vector = convt_simple_lsh(get_item_vec_from_softmax(nce_w, nce_b))
```
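As a rough sketch of the online recall step, the converted user and item vectors can be combined as below to pick the top-N items by cosine similarity. This is a brute-force stand-in for a real ANN service, and the helper name `recall_top_n` and the file names in the usage comment are hypothetical, not part of the released scripts.

```python
import numpy as np

def recall_top_n(user_vec, item_vecs, n=10):
    """Return the indices of the top-n items by cosine similarity.

    Brute-force stand-in for an ANN service; both arguments are assumed to
    be the SIMPLE-LSH converted vectors produced above.
    """
    user_vec = np.asarray(user_vec, dtype=float)
    item_vecs = np.asarray(item_vecs, dtype=float)
    scores = item_vecs @ user_vec
    scores = scores / (np.linalg.norm(item_vecs, axis=1) * np.linalg.norm(user_vec))
    return np.argsort(-scores)[:n]

# Hypothetical usage, assuming the outputs of user_vector.py and
# item_vector.py were saved as comma-separated rows in plain text files:
# users = np.loadtxt("user_vector.txt", delimiter=",")
# items = np.loadtxt("item_vector.txt", delimiter=",")
# print(recall_top_n(users[0], items, n=10))
```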
Use `user_vector.py` and `item_vector.py` to calculate user and item vectors. For example, run the following commands:
```shell
python user_vector.py --infer_set_path='./data/infer.txt' \
--model_path='./output/model/model_pass_00000.tar.gz' \
......
......@@ -110,6 +110,7 @@ def infer_a_batch(inferer, test_batch, nid_to_word):
item_id_to_word = nid_to_word[item_id]
ret += "%s:%.6f," \
% (item_id_to_word, softmax_output[item_id])
print ret.rstrip(",")
......