Commit 12801b6a authored by Qiao Longfei

init ctr model

Parent 59adc0d6
The program examples in this directory require PaddlePaddle v0.11.0. If your installed version of PaddlePaddle is lower than this requirement, please follow the instructions in the [installation document](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html) to update it.
---
# Deep Factorization Machine for Click-Through Rate Prediction
## Introduction
This model implements the DeepFM model proposed in the following paper:
```text
@inproceedings{guo2017deepfm,
  title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
  author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
  booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
  pages={1725--1731},
  year={2017}
}
```
The DeepFM model combines factorization machines and deep neural networks to model both low-order and high-order feature interactions. For details of factorization machines, please refer to the paper [factorization machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf).
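For reference, the standard second-order factorization machine (Rendle, 2010) scores an input x with n features as follows; this formula is general background, not an excerpt from the paper above:
```latex
% Second-order FM: w_0 is a global bias, w_i are first-order weights,
% and v_i is the k-dimensional latent factor vector of feature i
% (k corresponds to factor_size in the code below).
\hat{y}(x) = w_0 + \sum_{i=1}^{n} w_i x_i
                 + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle \, x_i x_j
```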
## Dataset
This example uses the Criteo dataset from the [Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/) hosted by Kaggle.
Each row contains the features of one ad impression, and the first column is a label indicating whether the ad was clicked. There are 39 features in total: 13 take integer values and the other 26 are categorical. The test set has no labels.
Download the dataset:
```bash
cd data && ./download.sh && cd ..
```
## Model
The DeepFM model is composed of a factorization machine (FM) and a deep neural network (DNN). All input features are fed to both the FM and the DNN, and their outputs are combined to form the final output. The embedding layer for the sparse features in the DNN shares its parameters with the latent vectors (factors) of the FM layer.
The factorization machine layer in PaddlePaddle computes the second-order feature interactions. The following code example combines the factorization machine layer with a fully connected layer to form the complete factorization machine:
```python
def fm_layer(input, factor_size):
    first_order = paddle.layer.fc(
        input=input, size=1, act=paddle.activation.Linear())
    second_order = paddle.layer.factorization_machine(
        input=input, factor_size=factor_size)
    fm = paddle.layer.addto(
        input=[first_order, second_order],
        act=paddle.activation.Linear(),
        bias_attr=False)
    return fm
```
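For illustration only, here is a minimal sketch of how the `fm_layer` output could be joined with a DNN branch in the paddle.v2 API; `embeddings` and the layer sizes are hypothetical placeholders, not the exact wiring of this example:
```python
# Hypothetical sketch: `embeddings` stands for the concatenated
# sparse-feature embeddings; the sizes are illustrative only.
fm = fm_layer(input, factor_size=10)
hidden = paddle.layer.fc(input=embeddings, size=400, act=paddle.activation.Relu())
predict = paddle.layer.fc(
    input=[fm, hidden], size=1, act=paddle.activation.Sigmoid())
```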
## Data preparation
To preprocess the raw dataset, the integer features are clipped and then min-max normalized to [0, 1], and the categorical features are one-hot encoded. The raw training set is split into two parts: 90% for training and the remaining 10% for validation during training.
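The per-feature scaling amounts to the following (a hedged sketch; `cap` stands for the per-feature 95%-quantile clip value used in `preprocess.py`):
```python
# Illustrative sketch of the integer-feature scaling: clip the long
# tail at a per-feature cap, then min-max normalize into [0, 1].
def scale(val, lo, hi, cap):
    val = min(float(val), cap)
    return (val - lo) / (hi - lo)
```
Run the full preprocessing with: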
```bash
python preprocess.py --datadir ./data/raw --outdir ./data
```
## Training
The command line options for training can be listed by `python train.py -h`.
To train the model:
```bash
python train.py \
    --train_data_path data/train.txt \
    --test_data_path data/valid.txt \
    2>&1 | tee train.log
```
After pass 9, batch 40000, the testing AUC is 0.807178 and the testing cost is 0.445196.
## Inference
The command line options for inference can be listed by `python infer.py -h`.
To run inference on the test set:
```bash
python infer.py \
    --model_gz_path models/model-pass-9-batch-10000.tar.gz \
    --data_path data/test.txt \
    --prediction_output_path ./predict.txt
```
The minimum PaddlePaddle version needed for the code sample in this directory is v0.11.0. If you are on a version of PaddlePaddle earlier than v0.11.0, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
# Deep Factorization Machine for Click-Through Rate Prediction
## Introduction
This model implements the DeepFM model proposed in the following paper:
```text
@inproceedings{guo2017deepfm,
  title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
  author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
  booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
  pages={1725--1731},
  year={2017}
}
```
The DeepFM model combines factorization machines and deep neural networks to
model both low-order and high-order feature interactions. For details of
factorization machines, please refer to the paper [factorization
machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf).
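For reference, the standard second-order factorization machine (Rendle, 2010) scores an input x with n features as follows; this formula is general background, not an excerpt from the paper above:
```latex
% Second-order FM: w_0 is a global bias, w_i are first-order weights,
% and v_i is the k-dimensional latent factor vector of feature i
% (k corresponds to factor_size in the code below).
\hat{y}(x) = w_0 + \sum_{i=1}^{n} w_i x_i
                 + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle \, x_i x_j
```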
## Dataset
This example uses the Criteo dataset, which was used for the [Display
Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/)
hosted by Kaggle.
Each row contains the features of one ad impression, and the first column is
a label indicating whether the ad was clicked. There are 39 features in
total: 13 take integer values and the other 26 are categorical. For the test
dataset, the labels are omitted.
Download the dataset:
```bash
cd data && ./download.sh && cd ..
```
## Model
The DeepFM model is composed of a factorization machine layer (FM) and a deep
neural network (DNN). All input features are fed to both the FM and the DNN,
and their outputs are combined to form the final output. The embedding layer
for the sparse features in the DNN shares its parameters with the latent
vectors (factors) of the FM layer.
The factorization machine layer in PaddlePaddle computes the second-order
feature interactions. The following code example combines the factorization
machine layer with a fully connected layer to form the complete factorization
machine:
```python
def fm_layer(input, factor_size):
    first_order = paddle.layer.fc(
        input=input, size=1, act=paddle.activation.Linear())
    second_order = paddle.layer.factorization_machine(
        input=input, factor_size=factor_size)
    fm = paddle.layer.addto(
        input=[first_order, second_order],
        act=paddle.activation.Linear(),
        bias_attr=False)
    return fm
```
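For illustration only, here is a minimal sketch of how the `fm_layer` output could be joined with a DNN branch in the paddle.v2 API; `embeddings` and the layer sizes are hypothetical placeholders, not the exact wiring of this example:
```python
# Hypothetical sketch: `embeddings` stands for the concatenated
# sparse-feature embeddings; the sizes are illustrative only.
fm = fm_layer(input, factor_size=10)
hidden = paddle.layer.fc(input=embeddings, size=400, act=paddle.activation.Relu())
predict = paddle.layer.fc(
    input=[fm, hidden], size=1, act=paddle.activation.Sigmoid())
```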
## Data preparation
To preprocess the raw dataset, the integer features are clipped and then
min-max normalized to [0, 1], and the categorical features are one-hot
encoded. The raw training dataset is split so that 90% is used for training
and the remaining 10% for validation during training.
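The per-feature scaling amounts to the following (a hedged sketch; `cap` stands for the per-feature 95%-quantile clip value used in `preprocess.py`):
```python
# Illustrative sketch of the integer-feature scaling: clip the long
# tail at a per-feature cap, then min-max normalize into [0, 1].
def scale(val, lo, hi, cap):
    val = min(float(val), cap)
    return (val - lo) / (hi - lo)
```
Run the full preprocessing with: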
```bash
python preprocess.py --datadir ./data/raw --outdir ./data
```
## Train
The command line options for training can be listed by `python train.py -h`.
To train the model:
```bash
python train.py \
    --train_data_path data/train.txt \
    --test_data_path data/valid.txt \
    2>&1 | tee train.log
```
After pass 9, batch 40000, the testing AUC is `0.807178` and the testing
cost is `0.445196`.
## Infer
The command line options for inference can be listed by `python infer.py -h`.
To run inference on the test dataset:
```bash
python infer.py \
    --model_gz_path models/model-pass-9-batch-10000.tar.gz \
    --data_path data/test.txt \
    --prediction_output_path ./predict.txt
```
#!/bin/bash
# Download the Criteo dataset released for the Kaggle Display Advertising
# Challenge and move the raw train/test text files into ./raw.
wget --no-check-certificate https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
tar zxf dac.tar.gz
rm -f dac.tar.gz
mkdir raw
mv ./*.txt raw/
import os
import gzip
import argparse
import itertools

import paddle.v2 as paddle

from network_conf import DeepFM
import reader


def parse_args():
    parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
    parser.add_argument(
        '--model_gz_path',
        type=str,
        required=True,
        help="The path of the model parameters gz file")
    parser.add_argument(
        '--data_path',
        type=str,
        required=True,
        help="The path of the dataset to infer")
    parser.add_argument(
        '--prediction_output_path',
        type=str,
        required=True,
        help="The path to output the prediction")
    parser.add_argument(
        '--factor_size',
        type=int,
        default=10,
        help="The factor size for the factorization machine (default:10)")
    return parser.parse_args()


def infer():
    # Inference with the paddle.v2 API: load gzipped parameters and run
    # batched prediction over the test dataset.
    args = parse_args()

    paddle.init(use_gpu=False, trainer_count=1)

    model = DeepFM(args.factor_size, infer=True)

    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(args.model_gz_path, 'r'))

    inferer = paddle.inference.Inference(
        output_layer=model, parameters=parameters)

    dataset = reader.Dataset()

    infer_reader = paddle.batch(dataset.infer(args.data_path), batch_size=1000)

    with open(args.prediction_output_path, 'w') as out:
        for batch_id, batch in enumerate(infer_reader()):
            res = inferer.infer(input=batch)
            predictions = [x for x in itertools.chain.from_iterable(res)]
            out.write('\n'.join(map(str, predictions)) + '\n')


if __name__ == '__main__':
    infer()
import paddle.fluid as fluid

dense_feature_dim = 13
sparse_feature_dim = 117568


def DeepFM(factor_size, infer=False):
    # The 13 dense (integer) features go in as one float vector.
    dense_input = fluid.layers.data(
        name="dense_input", shape=[dense_feature_dim], dtype='float32')
    # The 26 sparse (categorical) features, one input slot per feature.
    sparse_input_ids = [
        fluid.layers.data(
            name="C" + str(i), shape=[1], lod_level=1, dtype='int64')
        for i in range(1, 27)
    ]

    def embedding_layer(input):
        # All sparse slots share one embedding table, which also plays the
        # role of the FM latent factors.
        return fluid.layers.embedding(
            input=input,
            size=[sparse_feature_dim, factor_size],
            param_attr=fluid.ParamAttr(name="SparseFeatFactors"))

    sparse_embed_seq = list(map(embedding_layer, sparse_input_ids))
    concated = fluid.layers.concat(sparse_embed_seq + [dense_input], axis=1)

    fc1 = fluid.layers.fc(input=concated, size=400, act='relu')
    fc2 = fluid.layers.fc(input=fc1, size=400, act='relu')
    fc3 = fluid.layers.fc(input=fc2, size=400, act='relu')
    # Two-way classification output (clicked / not clicked); softmax so the
    # output is a proper distribution for the cross entropy below.
    predict = fluid.layers.fc(input=fc3, size=2, act='softmax')

    data_list = [dense_input] + sparse_input_ids

    if not infer:
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        cost = fluid.layers.cross_entropy(input=predict, label=label)
        avg_cost = fluid.layers.reduce_mean(cost)
        accuracy = fluid.layers.accuracy(input=predict, label=label)
        auc_var, cur_auc_var, auc_states = fluid.layers.auc(
            input=predict, label=label, num_thresholds=2**12)
        data_list.append(label)
        return avg_cost, data_list
    else:
        return predict, data_list
"""
Preprocess Criteo dataset. This dataset was used for the Display Advertising
Challenge (https://www.kaggle.com/c/criteo-display-ad-challenge).
"""
import os
import sys
import click
import random
import collections
# There are 13 integer features and 26 categorical features
continous_features = range(1, 14)
categorial_features = range(14, 40)
# Clip integer features. The clip point for each integer feature
# is derived from the 95% quantile of the total values in each feature
continous_clip = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
class CategoryDictGenerator:
"""
Generate dictionary for each of the categorical features
"""
def __init__(self, num_feature):
self.dicts = []
self.num_feature = num_feature
for i in range(0, num_feature):
self.dicts.append(collections.defaultdict(int))
def build(self, datafile, categorial_features, cutoff=0):
with open(datafile, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
for i in range(0, self.num_feature):
if features[categorial_features[i]] != '':
self.dicts[i][features[categorial_features[i]]] += 1
for i in range(0, self.num_feature):
self.dicts[i] = filter(lambda x: x[1] >= cutoff,
self.dicts[i].items())
self.dicts[i] = sorted(self.dicts[i], key=lambda x: (-x[1], x[0]))
vocabs, _ = list(zip(*self.dicts[i]))
self.dicts[i] = dict(zip(vocabs, range(1, len(vocabs) + 1)))
self.dicts[i]['<unk>'] = 0
def gen(self, idx, key):
if key not in self.dicts[idx]:
res = self.dicts[idx]['<unk>']
else:
res = self.dicts[idx][key]
return res
def dicts_sizes(self):
return map(len, self.dicts)
class ContinuousFeatureGenerator:
"""
Normalize the integer features to [0, 1] by min-max normalization
"""
def __init__(self, num_feature):
self.num_feature = num_feature
self.min = [sys.maxint] * num_feature
self.max = [-sys.maxint] * num_feature
def build(self, datafile, continous_features):
with open(datafile, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
for i in range(0, self.num_feature):
val = features[continous_features[i]]
if val != '':
val = int(val)
if val > continous_clip[i]:
val = continous_clip[i]
self.min[i] = min(self.min[i], val)
self.max[i] = max(self.max[i], val)
def gen(self, idx, val):
if val == '':
return 0.0
val = float(val)
return (val - self.min[idx]) / (self.max[idx] - self.min[idx])
@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw criteo dataset")
@click.option("--outdir", type=str, help="Path to save the processed data")
def preprocess(datadir, outdir):
"""
All the 13 integer features are normalzied to continous values and these
continous features are combined into one vecotr with dimension 13.
Each of the 26 categorical features are one-hot encoded and all the one-hot
vectors are combined into one sparse binary vector.
"""
dists = ContinuousFeatureGenerator(len(continous_features))
dists.build(os.path.join(datadir, 'train.txt'), continous_features)
dicts = CategoryDictGenerator(len(categorial_features))
dicts.build(
os.path.join(datadir, 'train.txt'), categorial_features, cutoff=200)
dict_sizes = dicts.dicts_sizes()
categorial_feature_offset = [0]
for i in range(1, len(categorial_features)):
offset = categorial_feature_offset[i - 1] + dict_sizes[i - 1]
categorial_feature_offset.append(offset)
random.seed(0)
# 90% of the data are used for training, and 10% of the data are used
# for validation.
with open(os.path.join(outdir, 'train.txt'), 'w') as out_train:
with open(os.path.join(outdir, 'valid.txt'), 'w') as out_valid:
with open(os.path.join(datadir, 'train.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i]])
continous_vals.append("{0:.6f}".format(val).rstrip('0')
.rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i]]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
label = features[0]
if random.randint(0, 9999) % 10 != 0:
out_train.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
else:
out_valid.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
with open(os.path.join(outdir, 'test.txt'), 'w') as out:
with open(os.path.join(datadir, 'test.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i] - 1])
continous_vals.append("{0:.6f}".format(val).rstrip('0')
.rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i] - 1]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
out.write('\t'.join([continous_vals, categorial_vals]) + '\n')
if __name__ == "__main__":
preprocess()
class Dataset:
    def _reader_creator(self, path, is_infer):
        def reader():
            with open(path, 'r') as f:
                for line in f:
                    features = line.rstrip('\n').split('\t')
                    dense_feature = map(float, features[0].split(','))
                    sparse_feature = map(lambda x: [int(x)],
                                         features[1].split(','))
                    if not is_infer:
                        label = [float(features[2])]
                        yield [dense_feature] + sparse_feature + [label]
                    else:
                        yield [dense_feature] + sparse_feature

        return reader

    def train(self, path):
        return self._reader_creator(path, False)

    def test(self, path):
        return self._reader_creator(path, False)

    def infer(self, path):
        return self._reader_creator(path, True)
feeding = {
'dense_input': 0,
'sparse_input': 1,
'C1': 2,
'C2': 3,
'C3': 4,
'C4': 5,
'C5': 6,
'C6': 7,
'C7': 8,
'C8': 9,
'C9': 10,
'C10': 11,
'C11': 12,
'C12': 13,
'C13': 14,
'C14': 15,
'C15': 16,
'C16': 17,
'C17': 18,
'C18': 19,
'C19': 20,
'C20': 21,
'C21': 22,
'C22': 23,
'C23': 24,
'C24': 25,
'C25': 26,
'C26': 27,
'label': 28
}
import os
import logging
import argparse

import paddle
import paddle.fluid as fluid

from network_conf import DeepFM
import reader

logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
    parser.add_argument(
        '--train_data_path',
        type=str,
        default='./data/train.txt',
        help="The path of the training dataset")
    parser.add_argument(
        '--test_data_path',
        type=str,
        default='./data/test.txt',
        help="The path of the testing dataset")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help="The size of mini-batch (default:1000)")
    parser.add_argument(
        '--factor_size',
        type=int,
        default=10,
        help="The factor size for the factorization machine (default:10)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="The number of passes to train (default: 10)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for the model to store (default: models)')
    return parser.parse_args()


def train():
    args = parse_args()

    if not os.path.isdir(args.model_output_dir):
        os.mkdir(args.model_output_dir)

    loss, data_list = DeepFM(args.factor_size)
    optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
    optimize_ops, params_grads = optimizer.minimize(loss)

    dataset = reader.Dataset()
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            dataset.train(args.train_data_path),
            buf_size=args.batch_size * 100),
        batch_size=args.batch_size)

    place = fluid.CPUPlace()
    feeder = fluid.DataFeeder(feed_list=data_list, place=place)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Run the requested number of passes over the training data.
    for pass_id in range(args.num_passes):
        for batch_id, data in enumerate(train_reader()):
            loss_var = exe.run(fluid.default_main_program(),
                               feed=feeder.feed(data),
                               fetch_list=[loss])
            print('pass: {}, batch: {}, loss: {}'.format(pass_id, batch_id,
                                                         loss_var[0]))


if __name__ == '__main__':
    train()