# 基于DNN模型的点击率预估模型 ## 介绍 本模型实现了下述论文中提出的DNN模型: ```text @inproceedings{guo2017deepfm, title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction}, author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He}, booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)}, pages={1725--1731}, year={2017} } ``` ## 数据集 本文使用的是Kaggle公司举办的[展示广告竞赛](https://www.kaggle.com/c/criteo-display-ad-challenge/)中所使用的Criteo数据集。 每一行是一次广告展示的特征,第一列是一个标签,表示这次广告展示是否被点击。总共有39个特征,其中13个特征采用整型值,另外26个特征是类别类特征。测试集中是没有标签的。 下载数据集: ```bash cd data && ./download.sh && cd .. ``` ## 模型 本例子只实现了DeepFM论文中介绍的模型的DNN部分,DeepFM会在其他例子中给出。 ``` ## 数据准备 处理原始数据集,整型特征使用min-max归一化方法规范到[0, 1],类别类特征使用了one-hot编码。原始数据集分割成两部分:90%用于训练,其他10%用于训练过程中的验证。 ```bash python preprocess.py --datadir ./data/raw --outdir ./data ``` ## 训练 训练的命令行选项可以通过`python train.py -h`列出。 ### 单机训练: ```bash python train.py \ --train_data_path data/train.txt \ 2>&1 | tee train.log ``` 训练到第1轮的第40000个batch后,测试的AUC为0.807178,误差(cost)为0.445196。 ### 本地启动一个2 trainer 2 pserver的分布式训练任务 ```bash # start pserver0 python train.py \ --train_data_path /paddle/data/train.txt \ --is_local 0 \ --role pserver \ --endpoints 127.0.0.1:6000,127.0.0.1:6001 \ --current_endpoint 127.0.0.1:6000 \ --trainers 2 \ > pserver0.log 2>&1 & # start pserver1 python train.py \ --train_data_path /paddle/data/train.txt \ --is_local 0 \ --role pserver \ --endpoints 127.0.0.1:6000,127.0.0.1:6001 \ --current_endpoint 127.0.0.1:6001 \ --trainers 2 \ > pserver1.log 2>&1 & # start trainer0 python train.py \ --train_data_path /paddle/data/train.txt \ --is_local 0 \ --role trainer \ --endpoints 127.0.0.1:6000,127.0.0.1:6001 \ --trainers 2 \ --trainer_id 0 \ > trainer0.log 2>&1 & # start trainer1 python train.py \ --train_data_path /paddle/data/train.txt \ --is_local 0 \ --role trainer \ --endpoints 127.0.0.1:6000,127.0.0.1:6001 \ --trainers 2 \ --trainer_id 1 \ > trainer1.log 2>&1 & ``` ## 预测 预测的命令行选项可以通过`python infer.py -h`列出。 对测试集进行预测: ```bash python infer.py \ --model_path models/pass-0/ \ --data_path data/valid.txt ```