未验证 提交 a03fc02a 编写于 作者: O overlordmax 提交者: GitHub

Youtube 05131100 (#4614)

* fix bugs

* fix bugs

* add wide_deep

* fix code style

* fix code style

* fix some bugs

* fix filename

* add ncf

* add download data

* add download data

* add youtube dnn

* edit README.md
上级 b0d375a8
# youtube dnn
以下是本例的简要目录结构及说明:
```
├── README.md # 文档
├── youtubednn.py # youtubednn.py网络文件
├── args.py # 参数脚本
├── train.py # 训练文件
├── infer.py # 预测文件
├── train_gpu.sh # gpu训练shell脚本
├── train_cpu.sh # cpu训练shell脚本
├── infer_gpu.sh # gpu预测shell脚本
├── infer_cpu.sh # cpu预测shell脚本
├── get_topk.py # 获取user最有可能点击的k个video
├── rec_topk.sh # 推荐shell脚本
```
## 简介
[《Deep Neural Networks for YouTube Recommendations》](https://link.zhihu.com/?target=https%3A//static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/45530.pdf) 这篇论文是google的YouTube团队在推荐系统上DNN方面的尝试,是经典的向量化召回模型,主要通过模型来学习用户和物品的兴趣向量,并通过内积来计算用户和物品之间的相似性,从而得到最终的候选集。YouTube采取了两层深度网络完成整个推荐过程:
1.第一层是**Candidate Generation Model**完成候选视频的快速筛选,这一步候选视频集合由百万降低到了百的量级。
2.第二层是用**Ranking Model**完成几百个候选视频的精排。
本项目在paddlepaddle上完成YouTube dnn的召回部分Candidate Generation Model,分别获得用户和物品的向量表示,从而后续可以通过其他方法(如用户和物品的余弦相似度)给用户推荐物品。
由于原论文没有开源数据集,本项目随机构造数据验证网络的正确性。
## 环境
PaddlePaddle 1.7.0
python3.7
## 单机训练
GPU环境
在train_gpu.sh脚本文件中设置好数据路径、参数。
```sh
CUDA_VISIBLE_DEVICES=0 python train.py --use_gpu 1\ #使用gpu
--batch_size 32\
--epochs 20\
--watch_vec_size 64\ #特征维度
--search_vec_size 64\
--other_feat_size 64\
--output_size 100\
--model_dir 'model_dir'\ #模型保存路径
--test_epoch 19\
--base_lr 0.01\
--video_vec_path './video_vec.csv' #得到物品向量文件路径
```
执行脚本
```sh
sh train_gpu.sh
```
CPU环境
在train_cpu.sh脚本文件中设置好数据路径、参数。
```sh
python train.py --use_gpu 0\ #使用cpu
--batch_size 32\
--epochs 20\
--watch_vec_size 64\ #特征维度
--search_vec_size 64\
--other_feat_size 64\
--output_size 100\
--model_dir 'model_dir'\ #模型保存路径
--test_epoch 19\
--base_lr 0.01\
--video_vec_path './video_vec.csv' #得到物品向量文件路径
```
执行脚本
```
sh train_cpu.sh
```
## 单机预测
GPU环境
在infer_gpu.sh脚本文件中设置好数据路径、参数。
```sh
CUDA_VISIBLE_DEVICES=0 python infer.py --use_gpu 1 \ #使用gpu
--test_epoch 19 \ #采用哪一轮模型来预测
--model_dir './model_dir' \ #模型路径
--user_vec_path './user_vec.csv' #用户向量路径
```
执行脚本
```sh
sh infer_gpu.sh
```
CPU环境
在infer_cpu.sh脚本文件中设置好数据路径、参数。
```sh
python infer.py --use_gpu 0 \ #使用cpu
--test_epoch 19 \ #采用哪一轮模型来预测
--model_dir './model_dir' \ #模型路径
--user_vec_path './user_vec.csv' #用户向量路径
```
执行脚本
```sh
sh infer_cpu.sh
```
## 模型效果
构造数据集进行训练:
```
W0512 23:12:36.044643 2124 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0512 23:12:36.050058 2124 device_context.cc:245] device: 0, cuDNN Version: 7.3.
2020-05-12 23:12:37,681-INFO: epoch_id: 0, batch_time: 0.00719s, loss: 4.68754, acc: 0.00000
2020-05-12 23:12:37,686-INFO: epoch_id: 0, batch_time: 0.00503s, loss: 4.54141, acc: 0.03125
2020-05-12 23:12:37,691-INFO: epoch_id: 0, batch_time: 0.00419s, loss: 4.92227, acc: 0.00000
```
通过计算每个用户和每个物品的余弦相似度,给每个用户推荐topk视频:
```
user:0, top K videos:[93, 73, 6, 20, 84]
user:1, top K videos:[58, 0, 46, 86, 71]
user:2, top K videos:[52, 51, 47, 82, 19]
......
user:96, top K videos:[0, 52, 86, 45, 11]
user:97, top K videos:[0, 52, 45, 58, 28]
user:98, top K videos:[58, 24, 49, 36, 46]
user:99, top K videos:[0, 47, 44, 72, 51]
```
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
import sys
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--epochs", type=int, default=20, help="epochs")
parser.add_argument("--batch_size", type=int, default=32, help="batch_size")
parser.add_argument("--test_epoch", type=int, default=19, help="test_epoch")
parser.add_argument('--use_gpu', type=int, default=0, help='whether using gpu')
parser.add_argument('--model_dir', type=str, default='./model_dir', help='model_dir')
parser.add_argument('--watch_vec_size', type=int, default=64, help='watch_vec_size')
parser.add_argument('--search_vec_size', type=int, default=64, help='search_vec_size')
parser.add_argument('--other_feat_size', type=int, default=64, help='other_feat_size')
parser.add_argument('--output_size', type=int, default=100, help='output_size')
parser.add_argument('--base_lr', type=float, default=0.01, help='base_lr')
parser.add_argument('--video_vec_path', type=str, default='./video_vec.csv', help='video_vec_path')
parser.add_argument('--user_vec_path', type=str, default='./user_vec.csv', help='user_vec_path')
parser.add_argument('--topk', type=int, default=5, help='topk')
args = parser.parse_args()
return args
import numpy as np
import pandas as pd
import args
import copy
def cos_sim(vector_a, vector_b):
vector_a = np.mat(vector_a)
vector_b = np.mat(vector_b)
num = float(vector_a * vector_b.T)
denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
cos = num / denom
sim = 0.5 + 0.5 * cos
return sim
def get_topK(args, K):
video_vec = pd.read_csv(args.video_vec_path, header=None)
user_vec = pd.read_csv(args.user_vec_path, header=None)
user_video_sim_list = []
for i in range(user_vec.shape[0]):
for j in range(video_vec.shape[1]):
user_video_sim = cos_sim(np.array(user_vec.loc[i]), np.array(video_vec[j]))
user_video_sim_list.append(user_video_sim)
tmp_list=copy.deepcopy(user_video_sim_list)
tmp_list.sort()
max_sim_index=[user_video_sim_list.index(one) for one in tmp_list[::-1][:K]]
print("user:{0}, top K videos:{1}".format(i, max_sim_index))
user_video_sim_list = []
if __name__ == "__main__":
args = args.parse_args()
get_topK(args, 5)
\ No newline at end of file
import paddle.fluid as fluid
import numpy as np
import pandas as pd
import time
import sys
import os
import args
import logging
from youtubednn import YoutubeDNN
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def infer(args):
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
cur_model_path = os.path.join(args.model_dir, 'epoch_' + str(args.test_epoch), "checkpoint")
with fluid.scope_guard(fluid.Scope()):
infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(cur_model_path, exe)
# Build a random data set.
sample_size = 100
watch_vecs = []
search_vecs = []
other_feats = []
for i in range(sample_size):
watch_vec = np.random.rand(1, args.watch_vec_size)
search_vec = np.random.rand(1, args.search_vec_size)
other_feat = np.random.rand(1, args.other_feat_size)
watch_vecs.append(watch_vec)
search_vecs.append(search_vec)
other_feats.append(other_feat)
for i in range(sample_size):
l3 = exe.run(infer_program,
feed={
"watch_vec": watch_vecs[i].astype('float32'),
"search_vec": search_vecs[i].astype('float32'),
"other_feat": other_feats[i].astype('float32'),
},
return_numpy=True,
fetch_list=fetch_vars)
user_vec = pd.DataFrame(l3[0])
user_vec.to_csv(args.user_vec_path, mode="a", index=False, header=0)
if __name__ == "__main__":
args = args.parse_args()
if(os.path.exists(args.user_vec_path)):
os.system("rm " + args.user_vec_path)
infer(args)
\ No newline at end of file
python infer.py --use_gpu 0 --test_epoch 19 --model_dir './model_dir' --user_vec_path './user_vec.csv'
\ No newline at end of file
CUDA_VISIBLE_DEVICES=0 python infer.py --use_gpu 1 --test_epoch 19 --model_dir './model_dir' --user_vec_path './user_vec.csv'
\ No newline at end of file
python get_topk.py --video_vec_path './video_vec.csv' --user_vec_path './user_vec.csv' --topk 5
\ No newline at end of file
import numpy as np
import pandas as pd
import os
import random
import paddle.fluid as fluid
from youtubednn import YoutubeDNN
import paddle
import args
import logging
import time
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def train(args):
youtube_model = YoutubeDNN()
inputs = youtube_model.input_data(args.watch_vec_size, args.search_vec_size, args.other_feat_size)
loss, acc, l3 = youtube_model.net(inputs, args.output_size, layers=[128, 64, 32])
sgd = fluid.optimizer.SGD(learning_rate=args.base_lr)
sgd.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# Build a random data set.
sample_size = 100
watch_vecs = []
search_vecs = []
other_feats = []
labels = []
for i in range(sample_size):
watch_vec = np.random.rand(args.batch_size, args.watch_vec_size)
search_vec = np.random.rand(args.batch_size, args.search_vec_size)
other_feat = np.random.rand(args.batch_size, args.other_feat_size)
watch_vecs.append(watch_vec)
search_vecs.append(search_vec)
other_feats.append(other_feat)
label = np.random.randint(args.output_size, size=(args.batch_size, 1))
labels.append(label)
for epoch in range(args.epochs):
for i in range(sample_size):
begin = time.time()
loss_data, acc_val = exe.run(fluid.default_main_program(),
feed={
"watch_vec": watch_vecs[i].astype('float32'),
"search_vec": search_vecs[i].astype('float32'),
"other_feat": other_feats[i].astype('float32'),
"label": np.array(labels[i]).reshape(args.batch_size, 1)
},
return_numpy=True,
fetch_list=[loss.name, acc.name])
end = time.time()
logger.info("epoch_id: {}, batch_time: {:.5f}s, loss: {:.5f}, acc: {:.5f}".format(
epoch, end-begin, float(np.array(loss_data)), np.array(acc_val)[0]))
#save model
model_dir = os.path.join(args.model_dir, 'epoch_' + str(epoch + 1), "checkpoint")
feed_var_names = ["watch_vec", "search_vec", "other_feat"]
fetch_vars = [l3]
fluid.io.save_inference_model(model_dir, feed_var_names, fetch_vars, exe)
#save all video vector
video_array = np.array(fluid.global_scope().find_var('l4_weight').get_tensor())
video_vec = pd.DataFrame(video_array)
video_vec.to_csv(args.video_vec_path, mode="a", index=False, header=0)
if __name__ == "__main__":
args = args.parse_args()
if(os.path.exists(args.video_vec_path)):
os.system("rm " + args.video_vec_path)
train(args)
python train.py --use_gpu 0 --batch_size 32 --epochs 20 --watch_vec_size 64 --search_vec_size 64 --other_feat_size 64 --output_size 100 --model_dir 'model_dir' --test_epoch 19 --base_lr 0.01 --video_vec_path './video_vec.csv'
\ No newline at end of file
CUDA_VISIBLE_DEVICES=0 python train.py --use_gpu 1 --batch_size 32 --epochs 20 --watch_vec_size 64 --search_vec_size 64 --other_feat_size 64 --output_size 100 --model_dir 'model_dir' --test_epoch 19 --base_lr 0.01 --video_vec_path './video_vec.csv'
\ No newline at end of file
import paddle
import io
import math
import numpy as np
import paddle.fluid as fluid
class YoutubeDNN(object):
def input_data(self, watch_vec_size, search_vec_size, other_feat_size):
watch_vec = fluid.data(name="watch_vec", shape=[None, watch_vec_size], dtype="float32")
search_vec = fluid.data(name="search_vec", shape=[None, search_vec_size], dtype="float32")
other_feat = fluid.data(name="other_feat", shape=[None, other_feat_size], dtype="float32")
label = fluid.data(name="label", shape=[None, 1], dtype="int64")
inputs = [watch_vec] + [search_vec] + [other_feat] + [label]
return inputs
def fc(self, tag, data, out_dim, active='relu'):
init_stddev = 1.0
scales = 1.0 / np.sqrt(data.shape[1])
if tag == 'l4':
p_attr = fluid.param_attr.ParamAttr(name='%s_weight' % tag,
initializer=fluid.initializer.NormalInitializer(loc=0.0, scale=init_stddev * scales))
else:
p_attr = None
b_attr = fluid.ParamAttr(name='%s_bias' % tag, initializer=fluid.initializer.Constant(0.1))
out = fluid.layers.fc(input=data,
size=out_dim,
act=active,
param_attr=p_attr,
bias_attr =b_attr,
name=tag)
return out
def net(self, inputs, output_size, layers=[128, 64, 32]):
concat_feats = fluid.layers.concat(input=inputs[:-1], axis=-1)
l1 = self.fc('l1', concat_feats, layers[0], 'relu')
l2 = self.fc('l2', l1, layers[1], 'relu')
l3 = self.fc('l3', l2, layers[2], 'relu')
l4 = self.fc('l4', l3, output_size, 'softmax')
num_seqs = fluid.layers.create_tensor(dtype='int64')
acc = fluid.layers.accuracy(input=l4, label=inputs[-1], total=num_seqs)
cost = fluid.layers.cross_entropy(input=l4, label=inputs[-1])
avg_cost = fluid.layers.mean(cost)
return avg_cost, acc, l3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册