提交 6e44fd66 编写于 作者: W wangmeng28

Implement DeepFM for CTR prediction

上级 5d4166ab
# DeepFM 基于深度因子分解机的点击率预测模型
# Deep Factorization Machines (DeepFM) for Click-Through Rate prediction
## 简介
## Introduction
This model implements the DeepFM proposed in the following paper:
[TBD]
```text
Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He. DeepFM:
A Factorization-Machine based Neural Network for CTR Prediction.
Proceedings of the Twenty-Sixth International Joint Conference on
Artificial Intelligence (IJCAI-17), 2017
```
The DeepFm combines factorization machines and deep neural networks to model
both low order and high order feature interactions. For details of the
factorization machines, please refer to the paper [factorization
machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
## Dataset
This example uses Criteo dataset which was used for the [Display Advertising
Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/)
hosted by Kaggle.
Each row is the features for an ad display and the first column is a label
indicating whether this ad has been clicked or not. There are 39 features in
total. 13 features take integer values and the other 26 features are
categorical features. For the test dataset, the labels are omitted.
Download dataset:
```bash
cd data && ./download.sh && cd ..
```
## Model
The DeepFM model is composed of the factorization machine layer (FM) and deep
neural networks (DNN). All the input features are feeded to both FM and DNN.
The output from FM and DNN are combined to form the final output. The embedding
layer for sparse features in the DNN shares the parameters with the latent
vectors (factors) of the FM layer.
The factorization machine layer in PaddlePaddle computes the second order
interactions. The following code example combines the factorization machine
layer and fully connected layer to form the full version of factorization
machine:
```python
def fm_layer(input, factor_size):
first_order = paddle.layer.fc(input=input, size=1, act=paddle.activation.Linear())
second_order = paddle.layer.factorization_machine(input=input, factor_size=factor_size)
fm = paddle.layer.addto(input=[first_order, second_order],
act=paddle.activation.Linear(),
ias_attr=False)
return fm
```
## Data preparation
To preprocess the raw dataset, the integer features are clipped then min-max
normalized to [0, 1] and the categorical features are one-hot encoded. The raw
training dataset are splited such that 90% are used for training and the other
10% are used for validation during training.
```bash
python preprocess.py --datadir ./data/raw --outdir ./data
```
## Train
The command line options for training can be listed by `python train.py -h`.
To train the model:
```bash
python train.py \
--train_data_path data/train.txt \
--test_data_path data/valid.txt \
2>&1 | train.log
```
## Infer
The command line options for infering can be listed by `python infer.py -h`.
To make inference for the test dataset:
```bash
python infer.py \
--model_gz_path models/model-pass-9-batch-10000.tar.gz \
--data_path data/test.txt \
--prediction_output_path ./predict.txt
```
#!/bin/bash
wget https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
wget --no-check-certificate https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
tar zxf dac.tar.gz
rm -f dac.tar.gz
mkdir raw
mv ./*.txt raw/
......@@ -14,7 +14,7 @@ def fm_layer(input, factor_size, fm_param_attr):
param_attr=fm_param_attr)
out = paddle.layer.addto(
input=[first_order, second_order],
act=paddle.activation.Sigmoid(),
act=paddle.activation.Linear(),
bias_attr=False)
return out
......@@ -68,6 +68,9 @@ def DeepFM(factor_size, infer=False):
name="label", type=paddle.data_type.dense_vector(1))
cost = paddle.layer.multi_binary_label_cross_entropy_cost(
input=predict, label=label)
paddle.evaluator.classification_error(
name="classification_error", input=predict, label=label)
paddle.evaluator.auc(name="auc", input=predict, label=label)
return cost
else:
return predict
......@@ -5,12 +5,17 @@ Challenge (https://www.kaggle.com/c/criteo-display-ad-challenge).
import os
import sys
import click
import random
import collections
# There are 13 integer features and 26 categorical features
continous_features = range(1, 14)
categorial_features = range(14, 40)
# Clip integer features. The clip point for each integer feature
# is derived from the 95% quantile of the total values in each feature
continous_clip = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
class CategoryDictGenerator:
"""
......@@ -67,12 +72,14 @@ class ContinuousFeatureGenerator:
val = features[continous_features[i]]
if val != '':
val = int(val)
if val > continous_clip[i]:
val = continous_clip[i]
self.min[i] = min(self.min[i], val)
self.max[i] = max(self.max[i], val)
def gen(self, idx, val):
if val == '':
return 0
return 0.0
val = float(val)
return (val - self.min[idx]) / (self.max[idx] - self.min[idx])
......@@ -101,26 +108,36 @@ def preprocess(datadir, outdir):
offset = categorial_feature_offset[i - 1] + dict_sizes[i - 1]
categorial_feature_offset.append(offset)
with open(os.path.join(outdir, 'train.txt'), 'w') as out:
with open(os.path.join(datadir, 'train.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i]])
continous_vals.append(str(val))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i]]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
label = features[0]
out.write('\t'.join([continous_vals, categorial_vals, label]) +
'\n')
random.seed(0)
# 90% of the data are used for training, and 10% of the data are used
# for validation.
with open(os.path.join(outdir, 'train.txt'), 'w') as out_train:
with open(os.path.join(outdir, 'valid.txt'), 'w') as out_valid:
with open(os.path.join(datadir, 'train.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i]])
continous_vals.append(
"{0:.6f}".format(val).rstrip('0').rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i]]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
label = features[0]
if random.randint(0, 9999) % 10 != 0:
out_train.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
else:
out_valid.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
with open(os.path.join(outdir, 'test.txt'), 'w') as out:
with open(os.path.join(datadir, 'test.txt'), 'r') as f:
......@@ -130,7 +147,8 @@ def preprocess(datadir, outdir):
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i] - 1])
continous_vals.append(str(val))
continous_vals.append(
"{0:.6f}".format(val).rstrip('0').rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i,
......
......@@ -18,6 +18,9 @@ class Dataset:
def train(self, path):
return self._reader_creator(path, False)
def test(self, path):
return self._reader_creator(path, False)
def infer(self, path):
return self._reader_creator(path, True)
......
......@@ -20,11 +20,16 @@ def parse_args():
type=str,
required=True,
help="path of training dataset")
parser.add_argument(
'--test_data_path',
type=str,
required=True,
help="path of testing dataset")
parser.add_argument(
'--batch_size',
type=int,
default=10000,
help="size of mini-batch (default:10000)")
default=1000,
help="size of mini-batch (default:1000)")
parser.add_argument(
'--num_passes',
type=int,
......@@ -52,7 +57,7 @@ def train():
paddle.init(use_gpu=False, trainer_count=1)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
model = DeepFM(args.factor_size)
......@@ -66,11 +71,22 @@ def train():
def __event_handler__(event):
if isinstance(event, paddle.event.EndIteration):
num_samples = event.batch_id * args.batch_size
if event.batch_id % 10 == 0:
logger.warning("Pass %d, Batch %d, Samples %d, Cost %f" % (
event.pass_id, event.batch_id, num_samples, event.cost))
if event.batch_id % 100 == 0:
logger.warning("Pass %d, Batch %d, Samples %d, Cost %f, %s" %
(event.pass_id, event.batch_id, num_samples,
event.cost, event.metrics))
if event.batch_id % 10000 == 0:
if args.test_data_path:
result = trainer.test(
reader=paddle.batch(
dataset.test(args.test_data_path),
batch_size=args.batch_size),
feeding=reader.feeding)
logger.warning("Test %d-%d, Cost %f, %s" %
(event.pass_id, event.batch_id, result.cost,
result.metrics))
if event.batch_id % 1000 == 0:
path = "{}/model-pass-{}-batch-{}.tar.gz".format(
args.model_output_dir, event.pass_id, event.batch_id)
with gzip.open(path, 'w') as f:
......@@ -80,7 +96,7 @@ def train():
reader=paddle.batch(
paddle.reader.shuffle(
dataset.train(args.train_data_path),
buf_size=args.batch_size * 100),
buf_size=args.batch_size * 10000),
batch_size=args.batch_size),
feeding=reader.feeding,
event_handler=__event_handler__,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册