提交 159c7e75 编写于 作者: Q Qiao Longfei

update document

上级 6ca686a2
# 基于DNN模型的点击率预估模型
# 基于skip-gram的word2vector模型
## 介绍
本模型实现了下述论文中提出的DNN模型:
```text
@inproceedings{guo2017deepfm,
title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
pages={1725--1731},
year={2017}
}
```
## 运行环境
需要先安装PaddlePaddle Fluid
## 数据集
本文使用的是Kaggle公司举办的[展示广告竞赛](https://www.kaggle.com/c/criteo-display-ad-challenge/)中所使用的Criteo数据集。
每一行是一次广告展示的特征,第一列是一个标签,表示这次广告展示是否被点击。总共有39个特征,其中13个特征采用整型值,另外26个特征是类别类特征。测试集中是没有标签的。
数据集使用的是来自Matt Mahoney(http://mattmahoney.net/dc/textdata.html)的维基百科文章数据集enwiki8.
下载数据集:
```bash
......@@ -28,24 +16,28 @@ cd data && ./download.sh && cd ..
```
## 模型
本例子只实现了DeepFM论文中介绍的模型的DNN部分,DeepFM会在其他例子中给出
本例子实现了一个skip-gram模式的word2vector模型
## 数据准备
处理原始数据集,整型特征使用min-max归一化方法规范到[0, 1],类别类特征使用了one-hot编码。原始数据集分割成两部分:90%用于训练,其他10%用于训练过程中的验证。
对数据进行预处理以生成一个词典。
```bash
python preprocess.py --data_path data/enwik8 --dict_path data/enwik8_dict
```
## 训练
训练的命令行选项可以通过`python train.py -h`列出。
### 单机训练:
```bash
python train.py \
--train_data_path data/raw/train.txt \
--train_data_path data/enwik8 \
--dict_path data/enwik8_dict \
2>&1 | tee train.log
```
训练到第1轮的第40000个batch后,测试的AUC为0.801178,误差(cost)为0.445196。
### 分布式训练
本地启动一个2 trainer 2 pserver的分布式训练任务,分布式场景下训练数据会按照trainer的id进行切分,保证trainer之间的训练数据不会重叠,提高训练效率
......@@ -55,15 +47,7 @@ sh cluster_train.sh
```
## 预测
预测的命令行选项可以通过`python infer.py -h`列出。
对测试集进行预测:
```bash
python infer.py \
--model_path models/pass-0/ \
--data_path data/raw/valid.txt
```
注意:infer.py跑完最后输出的AUC才是整个预测文件的整体AUC。
## 在百度云上运行集群训练
1. 参考文档 [在百度云上启动Fluid分布式训练](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst) 在百度云上部署一个CPU集群。
......
# DNN for Click-Through Rate prediction
# Skip-Gram Word2Vec Model
## Introduction
......@@ -8,14 +8,8 @@
You should install PaddlePaddle Fluid first.
## Dataset
This example uses Criteo dataset which was used for the [Display Advertising
Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/)
hosted by Kaggle.
Each row is the features for an ad display and the first column is a label
indicating whether this ad has been clicked or not. There are 39 features in
total. 13 features take integer values and the other 26 features are
categorical features. For the test dataset, the labels are omitted.
The training data for the Large Text Compression Benchmark is the first 109 bytes
of the English Wikipedia dump on Mar. 3, 2006 from Matt Mahoney(http://mattmahoney.net/dc/textdata.html).
Download dataset:
```bash
......@@ -23,19 +17,15 @@ cd data && ./download.sh && cd ..
```
## Model
This Demo only implement the DNN part of the model described in DeepFM paper.
DeepFM model will be provided in other model.
This model implement a skip-gram model of word2vector.
## Data Preprocessing method
To preprocess the raw dataset, the integer features are clipped then min-max
normalized to [0, 1] and the categorical features are one-hot encoded. The raw
training dataset are splited such that 90% are used for training and the other
10% are used for validation during training. In reader.py, training data is the first
90% of data in train.txt, and validation data is the left.
Preprocess the training data to generate a word dict.
```bash
python preprocess.py --data_path data/enwik9 --dict_path data/enwik9_dict
python preprocess.py --data_path data/enwik8 --dict_path data/enwik8_dict
```
## Train
......@@ -44,12 +34,11 @@ The command line options for training can be listed by `python train.py -h`.
### Local Train:
```bash
python train.py \
--train_data_path data/raw/train.txt \
--train_data_path data/enwik8 \
--dict_path data/enwik8_dict \
2>&1 | tee train.log
```
After training pass 1 batch 40000, the testing AUC is `0.801178` and the testing
cost is `0.445196`.
### Distributed Train
Run a 2 pserver 2 trainer distribute training on a single machine.
......@@ -61,15 +50,7 @@ sh cluster_train.sh
```
## Infer
The command line options for infering can be listed by `python infer.py -h`.
To make inference for the test dataset:
```bash
python infer.py \
--model_path models/ \
--data_path data/raw/train.txt
```
Note: The AUC value in the last log info is the total AUC for all test dataset. Here, train.txt is splited inside the reader.py so that validation data does not have overlap with training data.
## Train on Baidu Cloud
1. Please prepare some CPU machines on Baidu Cloud following the steps in [train_on_baidu_cloud](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/user_guides/howto/training/train_on_baidu_cloud_cn.rst)
......
......@@ -26,6 +26,10 @@ def parse_args():
return parser.parse_args()
def text_strip(text):
return re.sub("[^a-z ]", "", text)
def preprocess(data_path, dict_path, freq):
"""
proprocess the data, generate dictionary and save into dict_path.
......@@ -40,7 +44,7 @@ def preprocess(data_path, dict_path, freq):
with open(data_path) as f:
for line in f:
line = line.lower()
line = re.sub("[^a-z ]", "", line)
line = text_strip(line)
words = line.split()
for item in words:
if item in word_count:
......
# -*- coding: utf-8 -*
import numpy as np
"""
enwik9 dataset
http://mattmahoney.net/dc/enwik9.zip
"""
import preprocess
class Word2VecReader(object):
......@@ -42,6 +38,7 @@ class Word2VecReader(object):
def _reader():
with open(self.data_path_, 'r') as f:
for line in f:
line = preprocess.text_strip(line)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
......
......@@ -23,12 +23,12 @@ def parse_args():
parser.add_argument(
'--train_data_path',
type=str,
default='./data/enwik9',
default='./data/enwik8',
help="The path of training dataset")
parser.add_argument(
'--dict_path',
type=str,
default='./data/enwik9_dict',
default='./data/enwik8_dict',
help="The path of data dict")
parser.add_argument(
'--test_data_path',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册