Commit b1787fe2, authored by Xin Pan

clean v2

Parent 31e2fd99
# Introduction to models
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://github.com/PaddlePaddle/models)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](https://github.com/PaddlePaddle/models)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
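PaddlePaddle provides a rich set of computational units that help users build a wide variety of deep learning models in a modular way to solve different application problems. Here we provide neural network models for common machine learning tasks for everyone to learn from and use.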
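## 1. Word Embedding
A word embedding represents a word as a real-valued vector in which each dimension captures some latent syntactic or semantic feature of the text. It is one of the most successful concepts and results of applying deep learning to natural language processing. More generally, word embeddings can also be applied to ordinary discrete features. Learning word embeddings is usually unsupervised, so it can make full use of massive unlabeled data to capture relationships between features, and it can effectively address problems such as sparse features, missing labels, and noisy data. However, in common word embedding methods the last layer of the model often faces a very large classification problem, which is the bottleneck of computational performance.
In the word embedding task, we show how to use Hierarchical-Sigmoid and Noise Contrastive Estimation (NCE) to accelerate word embedding learning.
- 1.1 [Hsigmoid Accelerated Word Vector Training](https://github.com/PaddlePaddle/models/tree/develop/hsigmoid)
- 1.2 [Noise Contrastive Estimation Accelerated Word Vector Training](https://github.com/PaddlePaddle/models/tree/develop/nce_cost)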
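## 2. RNN Language Model
The language model is an important foundational model in natural language processing. Besides producing word embeddings (a by-product of language model training), it can also help us generate text. Given several words, a language model predicts the most likely next word.
In the task of generating text with a language model, we focus on the recurrent neural network language model. By following the usage instructions in the document, you can quickly adapt it to your own training corpus and build interesting models such as automatic poetry or prose writing.
- 2.1 [Generate text using the RNN language model](https://github.com/PaddlePaddle/models/tree/develop/generate_sequence_by_rnn_lm)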
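## 3. Click-Through Rate Prediction
A click-through rate (CTR) prediction model estimates the probability that a user clicks on an ad and predicts the outcome of each ad impression. It is one of the core algorithms of advertising technology. Logistic regression learns well from large-scale sparse features and dominated the early stages of CTR prediction. In recent years, DNN models have gradually taken the lead in this task because of their strong learning ability.
In the CTR prediction task, we first present the Wide & Deep model proposed by Google. This model combines the advantages of a DNN, which is suited to learning abstract features, with those of logistic regression, which is suited to large-scale sparse features. It can be used as a relatively mature model framework and has seen some adoption in industry. We also provide a deep neural network model based on factorization machines. It combines a factorization machine with a deep neural network, which model the low-order and high-order interactions between input attributes respectively.
- 3.1 [Wide & Deep Click-Through Rate Model](https://github.com/PaddlePaddle/models/tree/develop/ctr/README.cn.md)
- 3.2 [Deep Factorization Machine for Click-Through Rate prediction](https://github.com/PaddlePaddle/models/tree/develop/deep_fm)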
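## 4. Text Classification
Text classification is one of the most basic tasks in natural language processing. Deep learning methods remove the need for complex feature engineering: they take the raw text directly as input and optimize classification accuracy in a data-driven way.
In the text classification task, we use sentiment classification as the example and provide a non-sequential text classification model based on DNN and a sequential model based on CNN (for the LSTM-based model, see the [Sentiment Analysis](http://www.paddlepaddle.org/docs/develop/book/06.understand_sentiment/index.cn.html) chapter of PaddleBook).
- 4.1 [Sentiment analysis based on DNN / CNN](https://github.com/PaddlePaddle/models/tree/develop/text_classification)
- 4.2 [Text classification based on nested sequences](https://github.com/PaddlePaddle/models/tree/develop/nested_sequence/text_classification)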
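## 5. Learning to Rank
Learning to rank (LTR) is one of the core problems in information retrieval and search engine research. A machine learning method learns a scoring function that scores the candidates to be ranked, and the order is then determined by the scores. Deep neural networks can be used to model the scoring function, giving rise to various LTR models based on deep learning.
In the learning to rank task, we introduce a pairwise ranking model based on the RankLoss loss function and a listwise ranking model based on the LambdaRank loss function (for the pointwise strategy, see the [Recommender System](http://www.paddlepaddle.org/docs/develop/book/05.recommender_system/index.cn.html) chapter of PaddleBook).
- 5.1 [Learning to rank based on Pairwise and Listwise approaches](https://github.com/PaddlePaddle/models/tree/develop/ltr)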
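## 6. Structured Semantic Model
The deep structured semantic model (DSSM) is a neural network based framework for semantic matching. It can be used to learn the semantic similarity between two information entities or two pieces of text. DSSM uses a DNN, CNN, or RNN to map both inputs into the same continuous, low-dimensional semantic space. In this space both entities or texts can be represented at the same time, and a distance metric and a matching function are then defined to characterize and learn their semantic similarity within that shared space.
In the structured semantic model task, we demonstrate how to model the semantic similarity between two strings. The model supports different network structures such as DNN (fully connected feed-forward network), CNN (convolutional network), and RNN (recurrent neural network), as well as different loss functions such as classification, regression, and ranking. This example uses the simplest text data as input. By substituting your own training and prediction data, you can use it in real scenarios.
- 6.1 [Deep structured semantic model](https://github.com/PaddlePaddle/models/tree/develop/dssm/README.cn.md)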
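## 7. Named Entity Recognition
Given an input sequence, a sequence tagging model assigns a category label to each element of the sequence. Sequence tagging is one of the most basic tasks in natural language processing. As deep learning methods have developed, using a recurrent neural network to learn feature representations of the input sequence and a Conditional Random Field (CRF) to perform tagging on top of those features has gradually become the standard solution to sequence tagging problems.
In the sequence tagging task, we use Named Entity Recognition (NER) as the example and describe how to train an end-to-end sequence tagging model.
- 7.1 [Named Entity Recognition](https://github.com/PaddlePaddle/models/tree/develop/sequence_tagging_for_ner)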
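## 8. Sequence to Sequence Learning
Sequence-to-sequence learning maps between two or even more variable-length sequences. It has a wide range of applications, including machine translation, intelligent dialogue and question answering, generation of advertising creative text, automatic encoding (such as financial profile encoding), and judging the semantic relevance between multiple text strings.
In the sequence-to-sequence learning task, we first take machine translation as the example and provide several improved models to learn from and use. These include: the sequence-to-sequence mapping model without an attention mechanism, which is the basis for all sequence-to-sequence learning models; Scheduled Sampling, which mitigates error accumulation in RNN models during generation; and neural machine translation with an external memory mechanism, which strengthens the memory capacity of the network to handle complex sequence-to-sequence tasks. Beyond machine translation, we also provide a model that generates classical Chinese poetry with a deep LSTM network, as an example of same-language generation.
- 8.1 [Neural machine translation without attention](https://github.com/PaddlePaddle/models/tree/develop/nmt_without_attention/README.cn.md)
- 8.2 [Improving translation quality with Scheduled Sampling](https://github.com/PaddlePaddle/models/tree/develop/scheduled_sampling)
- 8.3 [Neural machine translation with external memory](https://github.com/PaddlePaddle/models/tree/develop/mt_with_external_memory)
- 8.4 [Generating classical Chinese poetry](https://github.com/PaddlePaddle/models/tree/develop/generate_chinese_poetry)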
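## 9. Reading Comprehension
As deep learning and other new techniques keep pushing natural language processing forward, a natural question arises: how can we confirm that a model truly understands human language and has some capacity for comprehension and reasoning? Looking across the classic NLP problems, such as lexical analysis, syntactic parsing, sentiment classification, and poetry generation, their classic solutions are, in terms of technical principle, still some distance from "language understanding". To measure the gap between existing NLP techniques and this ultimate goal, we need a task that is sufficiently difficult, quantifiable, and reproducible. This is the original motivation for the reading comprehension problem. Although current research shows that models performing well on existing reading comprehension datasets still do not achieve true language understanding, machine reading comprehension is still regarded as an important task for testing progress toward it.
Reading comprehension is essentially a form of question answering: the model "reads" a passage and then answers a given question. In this task, we introduce the Learning to Search method, which turns reading comprehension into a multi-step decision process: finding the sentence in the passage that contains the answer, then the start position of the answer in that sentence, and then its end position.
- 9.1 [Globally Normalized Reader](https://github.com/PaddlePaddle/models/tree/develop/globally_normalized_reader)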
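## 10. Question Answering
A Question Answering system uses a computer to automatically answer questions posed by users. It is one of the important tasks for verifying whether a machine can understand natural language, and its research history goes back to the origins of artificial intelligence. Compared with a retrieval system, a question answering system is a more advanced form of information service: it no longer returns a ranked list of keyword-matched results, but a precise answer in natural language.
In the question answering task, we introduce an end-to-end question answering system based on deep learning, which turns question answering into a sequence tagging problem. The end-to-end system learns from high-quality "question-evidence-answer" data to build a joint learning model that simultaneously learns the semantic mappings among the corpus, the knowledge base, and the semantic representation of the question. This turns the traditional, complex steps of question parsing, text retrieval, and answer extraction and generation into a single learnable process.
- 10.1 [Factoid question answering based on sequence tagging](https://github.com/PaddlePaddle/models/tree/develop/neural_qa)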
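## 11. Image Classification
Compared with text, images convey information that is more vivid, easier to understand, and more artistic, and they are an important means for people to pass on and exchange information. Image classification distinguishes images of different categories according to their semantic content. It is an important basic problem in computer vision and the foundation of higher-level vision tasks such as object detection, image segmentation, object tracking, and behavior analysis. It is widely applied in many fields, for example: face recognition and intelligent video analysis in security, traffic scene recognition in transportation, content-based image retrieval and automatic album organization on the Internet, and image recognition in medicine.
In the image classification task, we show how to train AlexNet, VGG, GoogLeNet, ResNet, Inception-v4, Inception-Resnet-V2 and Xception models. We also provide model conversion tools that convert model files trained with Caffe or TensorFlow into PaddlePaddle model files.
- 11.1 [Convert Caffe model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle)
- 11.2 [Convert TensorFlow model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/tf2paddle)
- 11.3 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification)
- 11.4 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification)
- 11.5 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification)
- 11.6 [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/image_classification)
- 11.7 [Inception-Resnet-V2](https://github.com/PaddlePaddle/models/tree/develop/image_classification)
- 11.8 [Xception](https://github.com/PaddlePaddle/models/tree/develop/image_classification)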
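## 12. Object Detection
The goal of object detection is, given an image or a video frame, to have the computer find the locations of all objects in it and give the specific category of each one. For humans, object detection is a very simple task. A computer, however, can only "see" matrices of values between 0 and 255. It is hard for it to understand high-level semantic concepts such as a person or an object appearing in the image or video frame, and even harder to locate the region in which the object appears. In addition, objects may appear anywhere in the image or frame, their shapes vary endlessly, and backgrounds differ widely. All of these factors make object detection a challenging problem for computers.
In the object detection task, we introduce object detection with the SSD method. SSD stands for Single Shot MultiBox Detector. It is one of the newer and better-performing detection algorithms in the field, combining fast detection with high accuracy.
- 12.1 [Single Shot MultiBox Detector](https://github.com/PaddlePaddle/models/tree/develop/ssd/README.cn.md)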
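## 13. Scene Text Recognition
Many scene images contain rich textual information that plays an important role in understanding the image and can greatly help people perceive and understand its content. Scene text recognition converts image information into a character sequence under conditions such as complex backgrounds, low resolution, diverse fonts, and arbitrary layout. It can be viewed as a special kind of translation: translating image input into natural language output. Advances in scene text recognition have also enabled new applications, such as automatically recognizing the text on street signs to help street-view applications obtain more accurate address information.
In the scene text recognition task, we show how to combine CNN-based image feature extraction with RNN-based sequence translation. This removes the need for hand-crafted features and character segmentation, and uses automatically learned image features to perform end-to-end, unconstrained character localization and recognition.
- 13.1 [Scene text recognition](https://github.com/PaddlePaddle/models/tree/develop/scene_text_recognition)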
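## 14. Speech Recognition
Automatic Speech Recognition (ASR) converts the lexical content of human speech into computer-readable input, allowing machines to "understand" human speech. It plays an important role in applications such as voice assistants, voice input, and voice interaction. Deep learning has achieved remarkable results in speech recognition. End-to-end deep learning methods merge the traditional acoustic model, lexicon, language model, and other modules into a single whole. They no longer rely on the conditional independence assumptions of hidden Markov models, which makes the model simpler: one neural network takes speech features as input and directly outputs the recognized text. This has become the most important approach to speech recognition.
In the speech recognition task, we provide a complete pipeline based on the DeepSpeech2 model, including feature extraction, data augmentation, model training, language model, and decoding modules. We also provide a trained model and a demo, so that you can try speech recognition with your own voice.
- 14.1 [Speech recognition: DeepSpeech2](https://github.com/PaddlePaddle/DeepSpeech)

This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) and licensed under the [Apache-2.0 license](LICENSE).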
# Introduction to models
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://github.com/PaddlePaddle/models)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](https://github.com/PaddlePaddle/models)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
PaddlePaddle provides a rich set of computational units that let users take a modular approach to solving various learning problems. In this repo, we demonstrate how to use PaddlePaddle to solve common machine learning tasks, providing several neural network models that anyone can easily learn from and use.
## 1. Word Embedding
A word embedding represents each word as a real-valued vector. Each dimension of the vector captures some latent syntactic or semantic feature of the text, and word embeddings are one of the most successful concepts in natural language processing. More generally, the same technique can be applied to ordinary discrete features. Word embeddings are usually learned without supervision, so training can take full advantage of massive unlabeled data to capture relationships between features and to mitigate sparse features, missing labels, and noisy data. However, in common word embedding methods the last layer of the model faces a very large classification problem, which becomes the bottleneck of computing performance.
In the word embedding examples, we show how to use Hierarchical-Sigmoid and Noise Contrastive Estimation (NCE) to accelerate word embedding training.
- 1.1 [Hsigmoid Accelerated Word Vector Training](https://github.com/PaddlePaddle/models/tree/develop/v2/hsigmoid)
- 1.2 [Noise Contrastive Estimation Accelerated Word Vector Training](https://github.com/PaddlePaddle/models/tree/develop/v2/nce_cost)
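To make the NCE idea concrete, here is a minimal sketch of the textbook NCE objective for a single training example. It is not the PaddlePaddle `nce_cost` implementation: the function names, the argument layout, and the use of plain Python floats are all assumptions made for illustration. The point is that the loss touches only the target word and `k` sampled noise words rather than the whole vocabulary.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def nce_loss(target_score, target_noise_prob, noise_scores, noise_probs):
    """NCE loss for one example (hypothetical helper, textbook formulation).

    target_score: unnormalized model score of the true next word.
    target_noise_prob: probability of the true word under the noise distribution.
    noise_scores / noise_probs: the same two quantities for k sampled noise words.
    """
    k = len(noise_scores)
    # The true word should be classified as "data".
    loss = -log_sigmoid(target_score - math.log(k * target_noise_prob))
    # Each sampled word should be classified as "noise".
    for score, prob in zip(noise_scores, noise_probs):
        loss -= log_sigmoid(-(score - math.log(k * prob)))
    return loss
```

The cost of one update grows with `k`, which is chosen by the user, instead of with the vocabulary size.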
## 2. RNN language model
A language model is an important building block in natural language processing. Besides producing word embeddings (a by-product of language model training), it can also help us generate text: given a number of words, the language model predicts the most likely next word. In the text generation example, we focus on the recurrent neural network language model. By following the instructions in the document, you can quickly adapt it to your own training corpus and build interesting models such as automatic poetry or prose writing.
- 2.1 [Generate text using the RNN language model](https://github.com/PaddlePaddle/models/tree/develop/v2/generate_sequence_by_rnn_lm)
## 3. Click-Through Rate prediction
A click-through rate (CTR) model predicts the probability that a user will click on an ad, and is widely used in advertising technology. Logistic regression learns well from large-scale sparse features and dominated the early stages of CTR prediction. In recent years, DNN models have gradually taken the lead in this task because of their strong learning ability.
In the CTR prediction examples, we first present Google's Wide & Deep model, which combines the advantages of a DNN, suited to learning abstract features, with those of logistic regression, suited to large-scale sparse features. We then provide a deep factorization machine for CTR prediction. It combines a factorization machine with a deep neural network to model both low-order and high-order interactions of input features.
- 3.1 [Click-Through Rate Model](https://github.com/PaddlePaddle/models/tree/develop/v2/ctr)
- 3.2 [Deep Factorization Machine for Click-Through Rate prediction](https://github.com/PaddlePaddle/models/tree/develop/v2/deep_fm)
## 4. Text classification
Text classification is one of the most basic tasks in natural language processing. Deep learning methods remove the need for complex feature engineering and take the raw text as input to optimize classification accuracy in a data-driven way.
For text classification, we use sentiment classification as the example and provide a non-sequential text classification model based on DNN and a sequential model based on CNN. (For the LSTM-based model, please refer to PaddleBook [Sentiment Analysis](http://www.paddlepaddle.org/docs/develop/book/06.understand_sentiment/index.html)).
- 4.1 [Sentiment analysis based on DNN / CNN](https://github.com/PaddlePaddle/models/tree/develop/v2/text_classification)
## 5. Learning to rank
Learning to rank (LTR) is one of the core problems in information retrieval and search engine research. A learning algorithm uses training data to produce a ranking model that computes the relevance of documents to actual queries.
A deep neural network can be used to model the scoring function, giving rise to various LTR models based on deep learning.
Algorithms for learning to rank are usually categorized into three groups by their input representation and loss function: pointwise, pairwise, and listwise approaches. Here we demonstrate the RankLoss loss function (a pairwise approach) and the LambdaRank loss function (a listwise approach). (For pointwise approaches, please refer to [Recommender System](http://www.paddlepaddle.org/docs/develop/book/05.recommender_system/index.html)).
- 5.1 [Learning to rank based on Pairwise and Listwise approaches](https://github.com/PaddlePaddle/models/tree/develop/v2/ltr)
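As an illustration of the pairwise approach, the following sketch computes a RankNet-style cross-entropy loss on the score difference of two candidates. The function name and the scalar interface are assumptions for illustration. The example in the repository builds this loss from PaddlePaddle layers instead.

```python
import math


def pairwise_rank_loss(score_left, score_right, label):
    """Cross-entropy on the score difference of a candidate pair.

    label is 1.0 if the left candidate should rank higher, 0.0 if the right
    one should, and 0.5 if the pair is tied (hypothetical convention).
    """
    diff = score_left - score_right
    # Stable evaluation of log(1 + exp(diff)).
    softplus = max(diff, 0.0) + math.log1p(math.exp(-abs(diff)))
    return softplus - label * diff
```

The loss is small when the scoring function orders the pair the way the label says, and grows roughly linearly with the margin when it orders them the wrong way.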
## 6. Semantic model
The deep structured semantic model (DSSM) uses a DNN to learn low-dimensional vector representations in a continuous semantic space, and then models the semantic similarity between two sentences.
In this example, we demonstrate how to use PaddlePaddle to implement a generic deep structured semantic model that models the semantic similarity between two strings. The model supports different network structures such as CNN (Convolutional Network), FC (Fully Connected Network), and RNN (Recurrent Neural Network), and different loss functions such as classification, regression, and ranking.
- 6.1 [Deep structured semantic model](https://github.com/PaddlePaddle/models/tree/develop/v2/dssm)
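Once both strings are mapped into the same semantic space, their similarity is typically measured with a simple function of the two vectors. The sketch below uses cosine similarity, a common choice in the DSSM literature. Treat it as an assumption about the metric rather than a description of the linked example, and note that the helper name and the list-of-floats interface are hypothetical.

```python
import math


def cosine_similarity(left_vec, right_vec):
    """Cosine of the angle between two non-zero semantic vectors."""
    dot = sum(x * y for x, y in zip(left_vec, right_vec))
    left_norm = math.sqrt(sum(x * x for x in left_vec))
    right_norm = math.sqrt(sum(y * y for y in right_vec))
    return dot / (left_norm * right_norm)
```

Vectors pointing in the same direction score close to 1, orthogonal vectors score 0, and the result does not depend on vector length, which makes it convenient as input to a classification, regression, or ranking loss.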
## 7. Sequence tagging
Given an input sequence, a sequence tagging model assigns a category tag to each element of the sequence. Sequence tagging is one of the most basic tasks in natural language processing. Recurrent neural network models combined with a Conditional Random Field (CRF) are commonly used for sequence tagging tasks.
In the sequence tagging example, we describe how to train an end-to-end sequence tagging model, using the Named Entity Recognition (NER) task as an example.
- 7.1 [Named Entity Recognition](https://github.com/PaddlePaddle/models/tree/develop/v2/sequence_tagging_for_ner)
## 8. Sequence to sequence learning
Sequence-to-sequence models have a wide range of applications, including machine translation, dialogue systems, and parse tree generation.
We take machine translation as the example for sequence-to-sequence learning. We demonstrate the sequence-to-sequence mapping model without an attention mechanism, which is the basis for all sequence-to-sequence learning models. We also use scheduled sampling to mitigate error accumulation in the RNN model, and cover machine translation with an external memory mechanism.
- 8.1 [Basic Sequence-to-sequence model](https://github.com/PaddlePaddle/models/tree/develop/v2/nmt_without_attention)
## 9. Image classification
In the image classification examples, we show how to train AlexNet, VGG, GoogLeNet, ResNet, Inception-v4, Inception-Resnet-V2 and Xception models in PaddlePaddle. We also provide model conversion tools that convert model files trained with Caffe or TensorFlow into PaddlePaddle model files.
- 9.1 [Convert Caffe model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification/caffe2paddle)
- 9.2 [Convert TensorFlow model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification/tf2paddle)
- 9.3 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification)
- 9.4 [VGG](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification)
- 9.5 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification)
- 9.6 [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification)
- 9.7 [Inception-Resnet-V2](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification)
- 9.8 [Xception](https://github.com/PaddlePaddle/models/tree/develop/v2/image_classification)
This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) and licensed under the [Apache-2.0 license](LICENSE).
The minimum PaddlePaddle version needed for the code samples in this directory is v0.11.0. If you are on a version of PaddlePaddle earlier than v0.11.0, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
# Convolutional Sequence to Sequence Learning
This model implements the work in the following paper:
Jonas Gehring, Michael Auli, David Grangier, et al. Convolutional Sequence to Sequence Learning. International Conference on Machine Learning (ICML), 2017
# Data Preparation
- The data used in this tutorial can be downloaded by running:
```bash
sh download.sh
```
- Each line in the data file contains one sample, and each sample consists of a source sentence and a target sentence separated by '\t'. To use your own data, organize it as follows:
```
<source sentence>\t<target sentence>
```
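As a quick check of your own data, a line in this format can be split as sketched below. The helper name `parse_sample` and the whitespace tokenization are assumptions for illustration. The actual loading logic lives in the repository's reader module, which is not shown here.

```python
def parse_sample(line):
    """Split one data line into source and target token lists.

    Assumes exactly one tab separates the two sentences and that tokens
    within a sentence are separated by whitespace (hypothetical convention).
    """
    source, target = line.rstrip("\n").split("\t")
    return source.split(), target.split()


src_tokens, trg_tokens = parse_sample("guten morgen\tgood morning\n")
```

A line with a missing or extra tab raises a `ValueError` at the unpacking step, which makes malformed samples easy to spot before training.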
# Training a Model
- Modify the following script if needed and then run:
```bash
python train.py \
--train_data_path ./data/train \
--test_data_path ./data/test \
--src_dict_path ./data/src_dict \
--trg_dict_path ./data/trg_dict \
--enc_blocks "[(256, 3)] * 5" \
--dec_blocks "[(256, 3)] * 3" \
--emb_size 256 \
--pos_size 200 \
--drop_rate 0.2 \
--use_bn False \
--use_gpu False \
--trainer_count 1 \
--batch_size 32 \
--num_passes 20 \
>train.log 2>&1
```
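The `--enc_blocks` and `--dec_blocks` values are Python expressions describing a list of `(output dimension, context length)` pairs, one per convolution block. The inference script evaluates them with `eval`, and `train.py` presumably does the same, although it is not shown here. The sketch below expands the defaults used above and derives the padding count the same way the inference script does, by summing `context_len - 1` over the decoder blocks. The helper names are hypothetical.

```python
enc_blocks = [(256, 3)] * 5  # five encoder blocks, 256 channels, context length 3
dec_blocks = [(256, 3)] * 3  # three decoder blocks with the same shape


def total_padding(blocks):
    """Number of padding positions needed in front of the target sequence."""
    return sum(context_len - 1 for _, context_len in blocks)


def window_length(blocks):
    """How many trailing target tokens can influence the next prediction."""
    return total_padding(blocks) + 1
```

With the default decoder, `total_padding(dec_blocks)` is 6 and `window_length(dec_blocks)` is 7. Deeper stacks or longer contexts widen this window accordingly.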
# Inferring by a Trained Model
- Run inference with a trained model:
```bash
python infer.py \
--infer_data_path ./data/dev \
--src_dict_path ./data/src_dict \
--trg_dict_path ./data/trg_dict \
--enc_blocks "[(256, 3)] * 5" \
--dec_blocks "[(256, 3)] * 3" \
--emb_size 256 \
--pos_size 200 \
--drop_rate 0.2 \
--use_bn False \
--use_gpu False \
--trainer_count 1 \
--max_len 100 \
--batch_size 256 \
--beam_size 1 \
--is_show_attention False \
--model_path ./params.pass-0.tar.gz \
1>infer_result 2>infer.log
```
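The `--beam_size` option controls how many partial translations are kept at each decoding step. The following is a simplified sketch of one expand-and-prune step, written in Python 3 on toy probability tables. It is not the repository's `BeamSearch` class: it leaves out end-of-sentence handling, padding, and batching, and the names `beam_step`, `beams`, and `next_token_probs` are hypothetical.

```python
import math


def beam_step(beams, next_token_probs, beam_size):
    """Extend every partial sequence by one token and keep the best ones.

    beams: list of (token_list, log_probability) pairs.
    next_token_probs: one dict per beam, mapping token -> probability.
    """
    candidates = []
    for (seq, score), probs in zip(beams, next_token_probs):
        for token, prob in probs.items():
            candidates.append((seq + [token], score + math.log(prob)))
    # Prune: keep only the beam_size highest-scoring candidates.
    candidates.sort(key=lambda item: item[1], reverse=True)
    return candidates[:beam_size]


beams = [([], 0.0)]
beams = beam_step(beams, [{"a": 0.6, "b": 0.3, "c": 0.1}], beam_size=2)
beams = beam_step(beams, [{"a": 0.1, "b": 0.9}, {"a": 0.4, "b": 0.6}], beam_size=2)
```

With `beam_size=1`, as in the command above, this reduces to greedy decoding. Larger values explore more hypotheses at a proportionally higher cost per step.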
# Notes
Since the current version of PaddlePaddle doesn't support weight normalization, we use batch normalization instead to ensure convergence when the network is deep.
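For context, weight normalization reparameterizes each weight vector as a learned scale times a unit direction, `w = g * v / ||v||`. The sketch below shows that reparameterization for a single weight vector. It is an illustration of the general technique only: the helper name is hypothetical and nothing like it is part of this model's code.

```python
import math


def weight_normalize(direction, scale):
    """Return scale * direction / ||direction|| for a non-zero direction."""
    norm = math.sqrt(sum(x * x for x in direction))
    return [scale * x / norm for x in direction]


weights = weight_normalize([3.0, 4.0], scale=2.0)
```

The resulting vector always has length `scale`, regardless of the magnitude of `direction`. Batch normalization, which this implementation uses instead, normalizes the layer's activations over a mini-batch rather than its weights.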
#coding=utf-8
import sys
import time
import math
import numpy as np
import reader
class BeamSearch(object):
    """
    Generate sequence by beam search
    """

    def __init__(self,
                 inferer,
                 trg_dict,
                 pos_size,
                 padding_num,
                 batch_size=1,
                 beam_size=1,
                 max_len=100):
        self.inferer = inferer
        self.trg_dict = trg_dict
        self.reverse_trg_dict = reader.get_reverse_dict(trg_dict)
        self.word_padding = trg_dict.__len__()
        self.pos_size = pos_size
        self.pos_padding = pos_size
        self.padding_num = padding_num
        self.win_len = padding_num + 1
        self.max_len = max_len
        self.batch_size = batch_size
        self.beam_size = beam_size
    def get_beam_input(self, batch, sample_list):
        """
        Get input for generation at the current iteration.
        """
        beam_input = []

        for sample_id in sample_list:
            for path in self.candidate_path[sample_id]:
                if len(path['seq']) < self.win_len:
                    cur_trg = [self.word_padding] * (
                        self.win_len - len(path['seq']) - 1
                    ) + [self.trg_dict['<s>']] + path['seq']
                    cur_trg_pos = [self.pos_padding] * (
                        self.win_len - len(path['seq']) - 1) + [0] + range(
                            1, len(path['seq']) + 1)
                else:
                    cur_trg = path['seq'][-self.win_len:]
                    cur_trg_pos = range(
                        len(path['seq']) + 1 - self.win_len,
                        len(path['seq']) + 1)

                beam_input.append(batch[sample_id] + [cur_trg] + [cur_trg_pos])
        return beam_input
    def get_prob(self, beam_input):
        """
        Get the probabilities of all possible tokens.
        """
        row_list = [j * self.win_len for j in range(len(beam_input))]
        prob = self.inferer.infer(beam_input, field='value')[row_list, :]
        return prob
    def _top_k(self, prob, k):
        """
        Get indices of the words with k highest probabilities.
        """
        return prob.argsort()[-k:][::-1]
    def beam_expand(self, prob, sample_list):
        """
        In every iteration step, the model predicts the possible next words.
        For each input sentence, the top beam_size words are selected as candidates.
        """
        top_words = np.apply_along_axis(self._top_k, 1, prob, self.beam_size)

        candidate_words = [[]] * len(self.candidate_path)
        idx = 0

        for sample_id in sample_list:
            for seq_id, path in enumerate(self.candidate_path[sample_id]):
                for w in top_words[idx, :]:
                    score = path['score'] + math.log(prob[idx, w])
                    candidate_words[sample_id] = candidate_words[sample_id] + [{
                        'word': w,
                        'score': score,
                        'seq_id': seq_id
                    }]
                idx = idx + 1

        return candidate_words
    def beam_shrink(self, candidate_words, sample_list):
        """
        Pruning process of the beam search. During the process, the beam_size most
        probable sequences are selected for the beam in the next generation.
        """
        new_path = [[]] * len(self.candidate_path)
        for sample_id in sample_list:
            beam_words = sorted(
                candidate_words[sample_id],
                key=lambda x: x['score'],
                reverse=True)[:self.beam_size]

            complete_seq_min_score = None
            complete_path_num = len(self.complete_path[sample_id])

            if complete_path_num > 0:
                complete_seq_min_score = min(self.complete_path[sample_id],
                                             key=lambda x: x['score'])['score']
                if complete_path_num >= self.beam_size:
                    beam_words_max_score = beam_words[0]['score']
                    if beam_words_max_score < complete_seq_min_score:
                        continue

            for w in beam_words:
                if w['word'] == self.trg_dict['<e>']:
                    if complete_path_num < self.beam_size or complete_seq_min_score <= w[
                            'score']:

                        seq = self.candidate_path[sample_id][w['seq_id']]['seq']
                        self.complete_path[sample_id] = self.complete_path[
                            sample_id] + [{
                                'seq': seq,
                                'score': w['score']
                            }]

                        if complete_seq_min_score is None or complete_seq_min_score > w[
                                'score']:
                            complete_seq_min_score = w['score']
                else:
                    seq = self.candidate_path[sample_id][w['seq_id']]['seq'] + [
                        w['word']
                    ]
                    new_path[sample_id] = new_path[sample_id] + [{
                        'seq': seq,
                        'score': w['score']
                    }]

        return new_path
    def search_one_batch(self, batch):
        """
        Perform beam search on one mini-batch.
        """
        real_size = len(batch)
        self.candidate_path = [[{'seq': [], 'score': 0.}]] * real_size
        self.complete_path = [[]] * real_size
        sample_list = range(real_size)

        for i in xrange(self.max_len):
            beam_input = self.get_beam_input(batch, sample_list)
            prob = self.get_prob(beam_input)

            candidate_words = self.beam_expand(prob, sample_list)
            new_path = self.beam_shrink(candidate_words, sample_list)
            self.candidate_path = new_path
            sample_list = [
                sample_id for sample_id in sample_list
                if len(new_path[sample_id]) > 0
            ]

            if len(sample_list) == 0:
                break

        final_path = []
        for i in xrange(real_size):
            top_path = sorted(
                self.complete_path[i] + self.candidate_path[i],
                key=lambda x: x['score'],
                reverse=True)[:self.beam_size]
            final_path.append(top_path)
        return final_path
    def search(self, infer_data):
        """
        Perform beam search on all data.
        """

        def _to_sentence(seq):
            raw_sentence = [self.reverse_trg_dict[id] for id in seq]
            sentence = " ".join(raw_sentence)
            return sentence

        for pos in xrange(0, len(infer_data), self.batch_size):
            batch = infer_data[pos:min(pos + self.batch_size, len(infer_data))]
            self.final_path = self.search_one_batch(batch)
            for top_path in self.final_path:
                print _to_sentence(top_path[0]['seq'])
            sys.stdout.flush()
#!/usr/bin/env bash
CUR_PATH=`pwd`
git clone https://github.com/moses-smt/mosesdecoder.git
git clone https://github.com/rizar/actor-critic-public
export MOSES=`pwd`/mosesdecoder
export LVSR=`pwd`/actor-critic-public
cd actor-critic-public/exp/ted
sh create_dataset.sh
cd $CUR_PATH
mkdir data
cp actor-critic-public/exp/ted/prep/*-* data/
cp actor-critic-public/exp/ted/vocab.* data/
cd data
python ../preprocess.py
cd ..
rm -rf actor-critic-public mosesdecoder
#coding=utf-8
import sys
import argparse
import distutils.util
import gzip
import paddle.v2 as paddle
from model import conv_seq2seq
from beamsearch import BeamSearch
import reader
def parse_args():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Convolutional Seq2Seq")
    parser.add_argument(
        '--infer_data_path',
        type=str,
        required=True,
        help="Path of the dataset for inference")
    parser.add_argument(
        '--src_dict_path',
        type=str,
        required=True,
        help='Path of the source dictionary')
    parser.add_argument(
        '--trg_dict_path',
        type=str,
        required=True,
        help='path of the target dictionary')
    parser.add_argument(
        '--enc_blocks', type=str, help='Convolution blocks of the encoder')
    parser.add_argument(
        '--dec_blocks', type=str, help='Convolution blocks of the decoder')
    parser.add_argument(
        '--emb_size',
        type=int,
        default=256,
        help='Dimension of word embedding. (default: %(default)s)')
    parser.add_argument(
        '--pos_size',
        type=int,
        default=200,
        help='Total number of the position indexes. (default: %(default)s)')
    parser.add_argument(
        '--drop_rate',
        type=float,
        default=0.,
        help='Dropout rate. (default: %(default)s)')
    parser.add_argument(
        "--use_bn",
        default=False,
        type=distutils.util.strtobool,
        help="Use batch normalization or not. (default: %(default)s)")
    parser.add_argument(
        "--use_gpu",
        default=False,
        type=distutils.util.strtobool,
        help="Use gpu or not. (default: %(default)s)")
    parser.add_argument(
        "--trainer_count",
        default=1,
        type=int,
        help="Trainer number. (default: %(default)s)")
    parser.add_argument(
        '--max_len',
        type=int,
        default=100,
        help="The maximum length of the sentence to be generated. (default: %(default)s)"
    )
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="Size of a mini-batch. (default: %(default)s)")
    parser.add_argument(
        "--beam_size",
        default=1,
        type=int,
        help="The width of beam expansion. (default: %(default)s)")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="The path of trained model. (default: %(default)s)")
    parser.add_argument(
        "--is_show_attention",
        default=False,
        type=distutils.util.strtobool,
        help="Whether to show attention weight or not. (default: %(default)s)")
    return parser.parse_args()
def infer(infer_data_path,
          src_dict_path,
          trg_dict_path,
          model_path,
          enc_conv_blocks,
          dec_conv_blocks,
          emb_dim=256,
          pos_size=200,
          drop_rate=0.,
          use_bn=False,
          max_len=100,
          batch_size=1,
          beam_size=1,
          is_show_attention=False):
    """
    Inference.

    :param infer_data_path: The path of the data for inference.
    :type infer_data_path: str
    :param src_dict_path: The path of the source dictionary.
    :type src_dict_path: str
    :param trg_dict_path: The path of the target dictionary.
    :type trg_dict_path: str
    :param model_path: The path of a trained model.
    :type model_path: str
    :param enc_conv_blocks: The scale list of the encoder's convolution blocks. And each element of
                            the list contains output dimension and context length of the corresponding
                            convolution block.
    :type enc_conv_blocks: list of tuple
    :param dec_conv_blocks: The scale list of the decoder's convolution blocks. And each element of
                            the list contains output dimension and context length of the corresponding
                            convolution block.
    :type dec_conv_blocks: list of tuple
    :param emb_dim: The dimension of the embedding vector.
    :type emb_dim: int
    :param pos_size: The total number of the position indexes, which means
                     the maximum value of the index is pos_size - 1.
    :type pos_size: int
    :param drop_rate: Dropout rate.
    :type drop_rate: float
    :param use_bn: Whether to use batch normalization or not. False is the default value.
    :type use_bn: bool
    :param max_len: The maximum length of the sentence to be generated.
    :type max_len: int
    :param beam_size: The width of beam expansion.
    :type beam_size: int
    :param is_show_attention: Whether to show attention weight or not. False is the default value.
    :type is_show_attention: bool
    """
    # load dict
    src_dict = reader.load_dict(src_dict_path)
    trg_dict = reader.load_dict(trg_dict_path)
    src_dict_size = src_dict.__len__()
    trg_dict_size = trg_dict.__len__()

    prob, weight = conv_seq2seq(
        src_dict_size=src_dict_size,
        trg_dict_size=trg_dict_size,
        pos_size=pos_size,
        emb_dim=emb_dim,
        enc_conv_blocks=enc_conv_blocks,
        dec_conv_blocks=dec_conv_blocks,
        drop_rate=drop_rate,
        with_bn=use_bn,
        is_infer=True)

    # load parameters
    parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))

    padding_list = [context_len - 1 for (size, context_len) in dec_conv_blocks]
    padding_num = reduce(lambda x, y: x + y, padding_list)
    infer_reader = reader.data_reader(
        data_file=infer_data_path,
        src_dict=src_dict,
        trg_dict=trg_dict,
        pos_size=pos_size,
        padding_num=padding_num)

    if is_show_attention:
        attention_inferer = paddle.inference.Inference(
            output_layer=weight, parameters=parameters)
        for i, data in enumerate(infer_reader()):
            src_len = len(data[0])
            trg_len = len(data[2])
            attention_weight = attention_inferer.infer(
                [data], field='value', flatten_result=False)
            attention_weight = [
                weight.reshape((trg_len, src_len))
                for weight in attention_weight
            ]
            print attention_weight
            break
        return

    infer_data = []
    for i, raw_data in enumerate(infer_reader()):
        infer_data.append([raw_data[0], raw_data[1]])

    inferer = paddle.inference.Inference(
        output_layer=prob, parameters=parameters)

    searcher = BeamSearch(
        inferer=inferer,
        trg_dict=trg_dict,
        pos_size=pos_size,
        padding_num=padding_num,
        max_len=max_len,
        batch_size=batch_size,
        beam_size=beam_size)

    searcher.search(infer_data)
    return
def main():
    args = parse_args()
    enc_conv_blocks = eval(args.enc_blocks)
    dec_conv_blocks = eval(args.dec_blocks)

    sys.setrecursionlimit(10000)

    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)

    infer(
        infer_data_path=args.infer_data_path,
        src_dict_path=args.src_dict_path,
        trg_dict_path=args.trg_dict_path,
        model_path=args.model_path,
        enc_conv_blocks=enc_conv_blocks,
        dec_conv_blocks=dec_conv_blocks,
        emb_dim=args.emb_size,
        pos_size=args.pos_size,
        drop_rate=args.drop_rate,
        use_bn=args.use_bn,
        max_len=args.max_len,
        batch_size=args.batch_size,
        beam_size=args.beam_size,
        is_show_attention=args.is_show_attention)


if __name__ == '__main__':
    main()
#coding=utf-8
import math
import paddle.v2 as paddle
__all__ = ["conv_seq2seq"]
def gated_conv_with_batchnorm(input,
                              size,
                              context_len,
                              context_start=None,
                              learning_rate=1.0,
                              drop_rate=0.,
                              with_bn=False):
    """
    Definition of the convolution block.

    :param input: The input of this block.
    :type input: LayerOutput
    :param size: The dimension of the block's output.
    :type size: int
    :param context_len: The context length of the convolution.
    :type context_len: int
    :param context_start: The start position of the context.
    :type context_start: int
    :param learning_rate: The learning rate factor of the parameters in the block.
                          The actual learning rate is the product of the global
                          learning rate and this factor.
    :type learning_rate: float
    :param drop_rate: Dropout rate.
    :type drop_rate: float
    :param with_bn: Whether to use batch normalization or not. False is the default
                    value.
    :type with_bn: bool
    :return: The output of the convolution block.
    :rtype: LayerOutput
    """
    input = paddle.layer.dropout(input=input, dropout_rate=drop_rate)

    context = paddle.layer.mixed(
        size=input.size * context_len,
        input=paddle.layer.context_projection(
            input=input, context_len=context_len, context_start=context_start))

    raw_conv = paddle.layer.fc(
        input=context,
        size=size * 2,
        act=paddle.activation.Linear(),
        param_attr=paddle.attr.Param(
            initial_mean=0.,
            initial_std=math.sqrt(4.0 * (1.0 - drop_rate) / context.size),
            learning_rate=learning_rate),
        bias_attr=False)

    if with_bn:
        raw_conv = paddle.layer.batch_norm(
            input=raw_conv,
            act=paddle.activation.Linear(),
            param_attr=paddle.attr.Param(learning_rate=learning_rate))

    with paddle.layer.mixed(size=size) as conv:
        conv += paddle.layer.identity_projection(raw_conv, size=size, offset=0)

    with paddle.layer.mixed(size=size, act=paddle.activation.Sigmoid()) as gate:
        gate += paddle.layer.identity_projection(
            raw_conv, size=size, offset=size)

    with paddle.layer.mixed(size=size) as gated_conv:
        gated_conv += paddle.layer.dotmul_operator(conv, gate)

    return gated_conv


def encoder(token_emb,
            pos_emb,
            conv_blocks=[(256, 3)] * 5,
            num_attention=3,
            drop_rate=0.,
            with_bn=False):
    """
    Definition of the encoder.

    :param token_emb: The embedding vector of the input token.
    :type token_emb: LayerOutput
    :param pos_emb: The embedding vector of the input token's position.
    :type pos_emb: LayerOutput
    :param conv_blocks: The scale list of the convolution blocks. Each element of
                        the list contains output dimension and context length of
                        the corresponding convolution block.
    :type conv_blocks: list of tuple
    :param num_attention: The total number of the attention modules used in the decoder.
    :type num_attention: int
    :param drop_rate: Dropout rate.
    :type drop_rate: float
    :param with_bn: Whether to use batch normalization or not. False is the default
                    value.
    :type with_bn: bool
    :return: The input token encoding.
    :rtype: LayerOutput
    """
    embedding = paddle.layer.addto(
        input=[token_emb, pos_emb],
        layer_attr=paddle.attr.Extra(drop_rate=drop_rate))

    proj_size = conv_blocks[0][0]
    block_input = paddle.layer.fc(
        input=embedding,
        size=proj_size,
        act=paddle.activation.Linear(),
        param_attr=paddle.attr.Param(
            initial_mean=0.,
            initial_std=math.sqrt((1.0 - drop_rate) / embedding.size),
            learning_rate=1.0 / (2.0 * num_attention)),
        bias_attr=True, )

    for (size, context_len) in conv_blocks:
        if block_input.size == size:
            residual = block_input
        else:
            residual = paddle.layer.fc(
                input=block_input,
                size=size,
                act=paddle.activation.Linear(),
                param_attr=paddle.attr.Param(learning_rate=1.0 /
                                             (2.0 * num_attention)),
                bias_attr=True)

        gated_conv = gated_conv_with_batchnorm(
            input=block_input,
            size=size,
            context_len=context_len,
            learning_rate=1.0 / (2.0 * num_attention),
            drop_rate=drop_rate,
            with_bn=with_bn)

        with paddle.layer.mixed(size=size) as block_output:
            block_output += paddle.layer.identity_projection(residual)
            block_output += paddle.layer.identity_projection(gated_conv)

        # halve the variance of the sum
        block_output = paddle.layer.slope_intercept(
            input=block_output, slope=math.sqrt(0.5))

        block_input = block_output

    emb_dim = embedding.size
    encoded_vec = paddle.layer.fc(
        input=block_output,
        size=emb_dim,
        act=paddle.activation.Linear(),
        param_attr=paddle.attr.Param(learning_rate=1.0 / (2.0 * num_attention)),
        bias_attr=True)

    encoded_sum = paddle.layer.addto(input=[encoded_vec, embedding])

    # halve the variance of the sum
    encoded_sum = paddle.layer.slope_intercept(
        input=encoded_sum, slope=math.sqrt(0.5))

    return encoded_vec, encoded_sum


def attention(decoder_state, cur_embedding, encoded_vec, encoded_sum):
    """
    Definition of the attention.

    :param decoder_state: The hidden state of the decoder.
    :type decoder_state: LayerOutput
    :param cur_embedding: The embedding vector of the current token.
    :type cur_embedding: LayerOutput
    :param encoded_vec: The source token encoding.
    :type encoded_vec: LayerOutput
    :param encoded_sum: The sum of the source token's encoding and embedding.
    :type encoded_sum: LayerOutput
    :return: A context vector and the attention weight.
    :rtype: LayerOutput
    """
    residual = decoder_state

    state_size = decoder_state.size
    emb_dim = cur_embedding.size
    with paddle.layer.mixed(size=emb_dim, bias_attr=True) as state_summary:
        state_summary += paddle.layer.full_matrix_projection(decoder_state)
        state_summary += paddle.layer.identity_projection(cur_embedding)

    # halve the variance of the sum
    state_summary = paddle.layer.slope_intercept(
        input=state_summary, slope=math.sqrt(0.5))

    expanded = paddle.layer.expand(input=state_summary, expand_as=encoded_vec)

    m = paddle.layer.dot_prod(input1=expanded, input2=encoded_vec)

    attention_weight = paddle.layer.fc(input=m,
                                       size=1,
                                       act=paddle.activation.SequenceSoftmax(),
                                       bias_attr=False)

    scaled = paddle.layer.scaling(weight=attention_weight, input=encoded_sum)

    attended = paddle.layer.pooling(
        input=scaled, pooling_type=paddle.pooling.Sum())

    attended_proj = paddle.layer.fc(input=attended,
                                    size=state_size,
                                    act=paddle.activation.Linear(),
                                    bias_attr=True)

    attention_result = paddle.layer.addto(input=[attended_proj, residual])

    # halve the variance of the sum
    attention_result = paddle.layer.slope_intercept(
        input=attention_result, slope=math.sqrt(0.5))

    return attention_result, attention_weight


def decoder(token_emb,
            pos_emb,
            encoded_vec,
            encoded_sum,
            dict_size,
            conv_blocks=[(256, 3)] * 3,
            drop_rate=0.,
            with_bn=False):
    """
    Definition of the decoder.

    :param token_emb: The embedding vector of the input token.
    :type token_emb: LayerOutput
    :param pos_emb: The embedding vector of the input token's position.
    :type pos_emb: LayerOutput
    :param encoded_vec: The source token encoding.
    :type encoded_vec: LayerOutput
    :param encoded_sum: The sum of the source token's encoding and embedding.
    :type encoded_sum: LayerOutput
    :param dict_size: The size of the target dictionary.
    :type dict_size: int
    :param conv_blocks: The scale list of the convolution blocks. Each element
                        of the list contains output dimension and context length
                        of the corresponding convolution block.
    :type conv_blocks: list of tuple
    :param drop_rate: Dropout rate.
    :type drop_rate: float
    :param with_bn: Whether to use batch normalization or not. False is the default
                    value.
    :type with_bn: bool
    :return: The probability of the predicted token and the attention weights.
    :rtype: LayerOutput
    """

    def attention_step(decoder_state, cur_embedding, encoded_vec, encoded_sum):
        conditional = attention(
            decoder_state=decoder_state,
            cur_embedding=cur_embedding,
            encoded_vec=encoded_vec,
            encoded_sum=encoded_sum)
        return conditional

    embedding = paddle.layer.addto(
        input=[token_emb, pos_emb],
        layer_attr=paddle.attr.Extra(drop_rate=drop_rate))

    proj_size = conv_blocks[0][0]
    block_input = paddle.layer.fc(
        input=embedding,
        size=proj_size,
        act=paddle.activation.Linear(),
        param_attr=paddle.attr.Param(
            initial_mean=0.,
            initial_std=math.sqrt((1.0 - drop_rate) / embedding.size)),
        bias_attr=True, )

    weight = []
    for (size, context_len) in conv_blocks:
        if block_input.size == size:
            residual = block_input
        else:
            residual = paddle.layer.fc(input=block_input,
                                       size=size,
                                       act=paddle.activation.Linear(),
                                       bias_attr=True)

        decoder_state = gated_conv_with_batchnorm(
            input=block_input,
            size=size,
            context_len=context_len,
            context_start=0,
            drop_rate=drop_rate,
            with_bn=with_bn)

        group_inputs = [
            decoder_state,
            embedding,
            paddle.layer.StaticInput(input=encoded_vec),
            paddle.layer.StaticInput(input=encoded_sum),
        ]

        conditional, attention_weight = paddle.layer.recurrent_group(
            step=attention_step, input=group_inputs)
        weight.append(attention_weight)

        block_output = paddle.layer.addto(input=[conditional, residual])

        # halve the variance of the sum
        block_output = paddle.layer.slope_intercept(
            input=block_output, slope=math.sqrt(0.5))

        block_input = block_output

    out_emb_dim = embedding.size
    block_output = paddle.layer.fc(
        input=block_output,
        size=out_emb_dim,
        act=paddle.activation.Linear(),
        layer_attr=paddle.attr.Extra(drop_rate=drop_rate))

    decoder_out = paddle.layer.fc(
        input=block_output,
        size=dict_size,
        act=paddle.activation.Softmax(),
        param_attr=paddle.attr.Param(
            initial_mean=0.,
            initial_std=math.sqrt((1.0 - drop_rate) / block_output.size)),
        bias_attr=True)

    return decoder_out, weight


def conv_seq2seq(src_dict_size,
                 trg_dict_size,
                 pos_size,
                 emb_dim,
                 enc_conv_blocks=[(256, 3)] * 5,
                 dec_conv_blocks=[(256, 3)] * 3,
                 drop_rate=0.,
                 with_bn=False,
                 is_infer=False):
    """
    Definition of convolutional sequence-to-sequence network.

    :param src_dict_size: The size of the source dictionary.
    :type src_dict_size: int
    :param trg_dict_size: The size of the target dictionary.
    :type trg_dict_size: int
    :param pos_size: The total number of the position indexes, which means
                     the maximum value of the index is pos_size - 1.
    :type pos_size: int
    :param emb_dim: The dimension of the embedding vector.
    :type emb_dim: int
    :param enc_conv_blocks: The scale list of the encoder's convolution blocks. Each element
                            of the list contains output dimension and context length of the
                            corresponding convolution block.
    :type enc_conv_blocks: list of tuple
    :param dec_conv_blocks: The scale list of the decoder's convolution blocks. Each element
                            of the list contains output dimension and context length of the
                            corresponding convolution block.
    :type dec_conv_blocks: list of tuple
    :param drop_rate: Dropout rate.
    :type drop_rate: float
    :param with_bn: Whether to use batch normalization or not. False is the default value.
    :type with_bn: bool
    :param is_infer: Whether infer or not.
    :type is_infer: bool
    :return: Cost or output layer.
    :rtype: LayerOutput
    """
    src = paddle.layer.data(
        name='src_word',
        type=paddle.data_type.integer_value_sequence(src_dict_size))
    src_pos = paddle.layer.data(
        name='src_word_pos',
        type=paddle.data_type.integer_value_sequence(pos_size +
                                                     1))  # one for padding

    src_emb = paddle.layer.embedding(
        input=src,
        size=emb_dim,
        name='src_word_emb',
        param_attr=paddle.attr.Param(
            initial_mean=0., initial_std=0.1))
    src_pos_emb = paddle.layer.embedding(
        input=src_pos,
        size=emb_dim,
        name='src_pos_emb',
        param_attr=paddle.attr.Param(
            initial_mean=0., initial_std=0.1))

    num_attention = len(dec_conv_blocks)
    encoded_vec, encoded_sum = encoder(
        token_emb=src_emb,
        pos_emb=src_pos_emb,
        conv_blocks=enc_conv_blocks,
        num_attention=num_attention,
        drop_rate=drop_rate,
        with_bn=with_bn)

    trg = paddle.layer.data(
        name='trg_word',
        type=paddle.data_type.integer_value_sequence(trg_dict_size +
                                                     1))  # one for padding
    trg_pos = paddle.layer.data(
        name='trg_word_pos',
        type=paddle.data_type.integer_value_sequence(pos_size +
                                                     1))  # one for padding

    trg_emb = paddle.layer.embedding(
        input=trg,
        size=emb_dim,
        name='trg_word_emb',
        param_attr=paddle.attr.Param(
            initial_mean=0., initial_std=0.1))
    trg_pos_emb = paddle.layer.embedding(
        input=trg_pos,
        size=emb_dim,
        name='trg_pos_emb',
        param_attr=paddle.attr.Param(
            initial_mean=0., initial_std=0.1))

    decoder_out, weight = decoder(
        token_emb=trg_emb,
        pos_emb=trg_pos_emb,
        encoded_vec=encoded_vec,
        encoded_sum=encoded_sum,
        dict_size=trg_dict_size,
        conv_blocks=dec_conv_blocks,
        drop_rate=drop_rate,
        with_bn=with_bn)

    if is_infer:
        return decoder_out, weight

    trg_next_word = paddle.layer.data(
        name='trg_next_word',
        type=paddle.data_type.integer_value_sequence(trg_dict_size))
    cost = paddle.layer.classification_cost(
        input=decoder_out, label=trg_next_word)

    return cost
#coding=utf-8

import cPickle


def concat_file(file1, file2, dst_file):
    with open(dst_file, 'w') as dst:
        with open(file1) as f1:
            with open(file2) as f2:
                for i, (line1, line2) in enumerate(zip(f1, f2)):
                    line1 = line1.strip()
                    # line2 keeps its trailing newline, which ends the record
                    line = line1 + '\t' + line2
                    dst.write(line)


if __name__ == '__main__':
    concat_file('dev.de-en.de', 'dev.de-en.en', 'dev')
    concat_file('test.de-en.de', 'test.de-en.en', 'test')
    concat_file('train.de-en.de', 'train.de-en.en', 'train')

    src_dict = cPickle.load(open('vocab.de'))
    trg_dict = cPickle.load(open('vocab.en'))

    with open('src_dict', 'w') as f:
        f.write('<s>\n<e>\nUNK\n')
        f.writelines('\n'.join(src_dict.keys()))

    with open('trg_dict', 'w') as f:
        f.write('<s>\n<e>\nUNK\n')
        f.writelines('\n'.join(trg_dict.keys()))
#coding=utf-8

import random


def load_dict(dict_file):
    word_dict = dict()
    with open(dict_file, 'r') as f:
        for i, line in enumerate(f):
            w = line.strip().split()[0]
            word_dict[w] = i
    return word_dict


def get_reverse_dict(dictionary):
    reverse_dict = {dictionary[k]: k for k in dictionary.keys()}
    return reverse_dict


def load_data(data_file, src_dict, trg_dict):
    UNK_IDX = src_dict['UNK']
    with open(data_file, 'r') as f:
        for line in f:
            line_split = line.strip().split('\t')
            if len(line_split) < 2:
                continue
            src, trg = line_split
            src_words = src.strip().split()
            trg_words = trg.strip().split()
            src_seq = [src_dict.get(w, UNK_IDX) for w in src_words]
            trg_seq = [trg_dict.get(w, UNK_IDX) for w in trg_words]
            yield src_seq, trg_seq


def data_reader(data_file, src_dict, trg_dict, pos_size, padding_num):
    def reader():
        UNK_IDX = src_dict['UNK']
        word_padding = trg_dict.__len__()
        pos_padding = pos_size

        def _get_pos(pos_list, pos_size, pos_padding):
            return [pos if pos < pos_size else pos_padding for pos in pos_list]

        with open(data_file, 'r') as f:
            for line in f:
                line_split = line.strip().split('\t')
                if len(line_split) != 2:
                    continue
                src, trg = line_split
                src = src.strip().split()
                src_word = [src_dict.get(w, UNK_IDX) for w in src]
                src_word_pos = range(len(src_word))
                src_word_pos = _get_pos(src_word_pos, pos_size, pos_padding)

                trg = trg.strip().split()
                trg_word = [trg_dict['<s>']
                            ] + [trg_dict.get(w, UNK_IDX) for w in trg]
                trg_word_pos = range(len(trg_word))
                trg_word_pos = _get_pos(trg_word_pos, pos_size, pos_padding)

                trg_next_word = trg_word[1:] + [trg_dict['<e>']]
                trg_word = [word_padding] * padding_num + trg_word
                trg_word_pos = [pos_padding] * padding_num + trg_word_pos
                trg_next_word = trg_next_word + [trg_dict['<e>']] * padding_num
                yield src_word, src_word_pos, trg_word, trg_word_pos, trg_next_word

    return reader
#coding=utf-8

import os
import sys
import time
import argparse
import distutils.util
import gzip
import numpy as np

import paddle.v2 as paddle
from model import conv_seq2seq
import reader


def parse_args():
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Convolutional Seq2Seq")
    parser.add_argument(
        '--train_data_path',
        type=str,
        required=True,
        help="Path of the training set")
    parser.add_argument(
        '--test_data_path', type=str, help='Path of the test set')
    parser.add_argument(
        '--src_dict_path',
        type=str,
        required=True,
        help='Path of source dictionary')
    parser.add_argument(
        '--trg_dict_path',
        type=str,
        required=True,
        help='Path of target dictionary')
    parser.add_argument(
        '--enc_blocks', type=str, help='Convolution blocks of the encoder')
    parser.add_argument(
        '--dec_blocks', type=str, help='Convolution blocks of the decoder')
    parser.add_argument(
        '--emb_size',
        type=int,
        default=256,
        help='Dimension of word embedding. (default: %(default)s)')
    parser.add_argument(
        '--pos_size',
        type=int,
        default=200,
        help='Total number of the position indexes. (default: %(default)s)')
    parser.add_argument(
        '--drop_rate',
        type=float,
        default=0.,
        help='Dropout rate. (default: %(default)s)')
    parser.add_argument(
        "--use_bn",
        default=False,
        type=distutils.util.strtobool,
        help="Use batch normalization or not. (default: %(default)s)")
    parser.add_argument(
        "--use_gpu",
        default=False,
        type=distutils.util.strtobool,
        help="Use gpu or not. (default: %(default)s)")
    parser.add_argument(
        "--trainer_count",
        default=1,
        type=int,
        help="Trainer number. (default: %(default)s)")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help="Size of a mini-batch. (default: %(default)s)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=15,
        help="Number of passes to train. (default: %(default)s)")
    return parser.parse_args()


def create_reader(padding_num,
                  train_data_path,
                  test_data_path=None,
                  src_dict=None,
                  trg_dict=None,
                  pos_size=200,
                  batch_size=32):

    train_reader = paddle.batch(
        reader=paddle.reader.shuffle(
            reader=reader.data_reader(
                data_file=train_data_path,
                src_dict=src_dict,
                trg_dict=trg_dict,
                pos_size=pos_size,
                padding_num=padding_num),
            buf_size=10240),
        batch_size=batch_size)

    test_reader = None
    if test_data_path:
        test_reader = paddle.batch(
            reader=paddle.reader.shuffle(
                reader=reader.data_reader(
                    data_file=test_data_path,
                    src_dict=src_dict,
                    trg_dict=trg_dict,
                    pos_size=pos_size,
                    padding_num=padding_num),
                buf_size=10240),
            batch_size=batch_size)

    return train_reader, test_reader


def train(train_data_path,
          test_data_path,
          src_dict_path,
          trg_dict_path,
          enc_conv_blocks,
          dec_conv_blocks,
          emb_dim=256,
          pos_size=200,
          drop_rate=0.,
          use_bn=False,
          batch_size=32,
          num_passes=15):
    """
    Train the convolution sequence-to-sequence model.

    :param train_data_path: The path of the training set.
    :type train_data_path: str
    :param test_data_path: The path of the test set.
    :type test_data_path: str
    :param src_dict_path: The path of the source dictionary.
    :type src_dict_path: str
    :param trg_dict_path: The path of the target dictionary.
    :type trg_dict_path: str
    :param enc_conv_blocks: The scale list of the encoder's convolution blocks. And each element of
                            the list contains output dimension and context length of the corresponding
                            convolution block.
    :type enc_conv_blocks: list of tuple
    :param dec_conv_blocks: The scale list of the decoder's convolution blocks. And each element of
                            the list contains output dimension and context length of the corresponding
                            convolution block.
    :type dec_conv_blocks: list of tuple
    :param emb_dim: The dimension of the embedding vector.
    :type emb_dim: int
    :param pos_size: The total number of the position indexes, which means
                     the maximum value of the index is pos_size - 1.
    :type pos_size: int
    :param drop_rate: Dropout rate.
    :type drop_rate: float
    :param use_bn: Whether to use batch normalization or not. False is the default value.
    :type use_bn: bool
    :param batch_size: The size of a mini-batch.
    :type batch_size: int
    :param num_passes: The total number of the passes to train.
    :type num_passes: int
    """
    # load dict
    src_dict = reader.load_dict(src_dict_path)
    trg_dict = reader.load_dict(trg_dict_path)
    src_dict_size = src_dict.__len__()
    trg_dict_size = trg_dict.__len__()

    optimizer = paddle.optimizer.Adam(learning_rate=1e-3, )

    cost = conv_seq2seq(
        src_dict_size=src_dict_size,
        trg_dict_size=trg_dict_size,
        pos_size=pos_size,
        emb_dim=emb_dim,
        enc_conv_blocks=enc_conv_blocks,
        dec_conv_blocks=dec_conv_blocks,
        drop_rate=drop_rate,
        with_bn=use_bn,
        is_infer=False)

    # create parameters and trainer
    parameters = paddle.parameters.create(cost)
    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=optimizer)

    # each decoder block contributes (context_len - 1) padding positions
    padding_list = [context_len - 1 for (size, context_len) in dec_conv_blocks]
    padding_num = reduce(lambda x, y: x + y, padding_list)
    train_reader, test_reader = create_reader(
        padding_num=padding_num,
        train_data_path=train_data_path,
        test_data_path=test_data_path,
        src_dict=src_dict,
        trg_dict=trg_dict,
        pos_size=pos_size,
        batch_size=batch_size)

    feeding = {
        'src_word': 0,
        'src_word_pos': 1,
        'trg_word': 2,
        'trg_word_pos': 3,
        'trg_next_word': 4
    }

    # create event handler
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 20 == 0:
                cur_time = time.strftime('%Y.%m.%d %H:%M:%S', time.localtime())
                print "[%s]: Pass: %d, Batch: %d, TrainCost: %f, %s" % (
                    cur_time, event.pass_id, event.batch_id, event.cost,
                    event.metrics)
                sys.stdout.flush()

        if isinstance(event, paddle.event.EndPass):
            if test_reader is not None:
                cur_time = time.strftime('%Y.%m.%d %H:%M:%S', time.localtime())
                result = trainer.test(reader=test_reader, feeding=feeding)
                print "[%s]: Pass: %d, TestCost: %f, %s" % (
                    cur_time, event.pass_id, result.cost, result.metrics)
                sys.stdout.flush()
            with gzip.open("output/params.pass-%d.tar.gz" % event.pass_id,
                           'w') as f:
                trainer.save_parameter_to_tar(f)

    if not os.path.exists('output'):
        os.mkdir('output')

    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        num_passes=num_passes,
        feeding=feeding)


def main():
    args = parse_args()
    enc_conv_blocks = eval(args.enc_blocks)
    dec_conv_blocks = eval(args.dec_blocks)

    sys.setrecursionlimit(10000)

    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)

    train(
        train_data_path=args.train_data_path,
        test_data_path=args.test_data_path,
        src_dict_path=args.src_dict_path,
        trg_dict_path=args.trg_dict_path,
        enc_conv_blocks=enc_conv_blocks,
        dec_conv_blocks=dec_conv_blocks,
        emb_dim=args.emb_size,
        pos_size=args.pos_size,
        drop_rate=args.drop_rate,
        use_bn=args.use_bn,
        batch_size=args.batch_size,
        num_passes=args.num_passes)


if __name__ == '__main__':
    main()
Running the example programs in this directory requires PaddlePaddle v0.10.0. If your installed PaddlePaddle version is lower than this, please update it by following the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
---
# Click-Through Rate Prediction

The files in this example directory and their descriptions are listed below:
```
├── README.md               # this tutorial (markdown)
├── dataset.md              # dataset processing tutorial
├── images                  # images used in this tutorial
│   ├── lr_vs_dnn.jpg
│   └── wide_deep.png
├── infer.py                # inference script
├── network_conf.py         # model network configuration
├── reader.py               # data reader
├── train.py                # training script
├── utils.py                # helper functions
└── avazu_data_processer.py # sample data preprocessing script
```
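## Background

CTR (Click-Through Rate) prediction \[[1](https://en.wikipedia.org/wiki/Click-through_rate)\] estimates the probability that a user clicks a particular link, and is an important step in ad serving. Accurate CTR prediction is important for maximizing the revenue of an online advertising system.

When there are multiple ad slots, the CTR estimate generally serves as the basis for ranking. For example, in the advertising system of a search engine, when a user enters a query with commercial value, the system roughly performs the following steps to display ads:

1. Retrieve the set of ads related to the user's query
2. Filter by business rules and relevance
3. Rank according to the auction mechanism and CTR
4. Display the ads

As can be seen, CTR plays an important role in the final ranking.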
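
As a rough illustration of step 3, the sketch below orders candidate ads by expected revenue per impression, that is, bid multiplied by predicted CTR. The text does not specify the auction rule, so this scoring rule is an assumption, and the `rank_ads` helper, ad names, bids and CTR values are all hypothetical.

```python
def rank_ads(candidates):
    """Sort (ad_id, bid, predicted_ctr) tuples by bid * predicted_ctr, highest first.

    The scoring rule is an assumption for illustration only; real ad systems
    may use a different auction mechanism.
    """
    return sorted(candidates, key=lambda ad: ad[1] * ad[2], reverse=True)


# Hypothetical candidates: (ad_id, bid, predicted_ctr)
ads = [("ad_a", 2.0, 0.010), ("ad_b", 0.5, 0.080), ("ad_c", 1.0, 0.030)]
ranked = rank_ads(ads)
print([ad_id for ad_id, _, _ in ranked])
```

Under this rule a low bid can still win a slot when the predicted CTR is high enough, which is why the accuracy of the CTR estimate directly affects the final ordering.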
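### Stages of Development

In industry, CTR models have gone through the following stages:

- Logistic Regression (LR) / GBDT + feature engineering
- LR + DNN features
- DNN + feature engineering

LR dominated the early stage, but recently DNN models have gradually taken over the CTR prediction task thanks to their strong learning capacity and increasingly mature performance optimization.

### LR vs DNN

The figure below shows the structure of LR and of a \(3x2\) DNN model: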
<p align="center">
<img src="images/lr_vs_dnn.jpg" width="620" hspace='10'/> <br/>
Figure 1. Structure comparison of the LR and DNN models
</p>
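The part of LR marked with blue arrows maps directly onto the corresponding structure in the DNN. LR and DNN share some common ground (such as weighted summation), but for the same input dimension the former may have much lower model complexity than the latter (in a sense, the more complex a model is, the more potential it has to learn complex information). For LR to match the learning capacity of a DNN, the input dimension, that is, the number of features, must be increased, which is why LR is necessarily tied to large-scale feature engineering.

The advantage of LR over DNN models is its capacity for large-scale sparse features: in terms of memory, computation and so on, industry has very mature optimization methods. A DNN model, on the other hand, can learn new features by itself, which improves the efficiency of feature use to some extent and makes it more likely to achieve better results with features of the same scale.

The following sections demonstrate how to use PaddlePaddle to write a model that combines the strengths of both.

## Data and Task Abstraction

We can take `click` as the learning target. The task can be formulated in several ways:

1. Learn click directly, with 0 and 1 as a binary classification
2. Learning to rank, specifically pairwise rank (label 1>0) or listwise rank
3. Compute the click rate of each ad, pair up the ads under the same query so that the higher click rate ranks above the lower one, and do ranking or classification

We use the first approach directly and treat the task as classification.

We use the dataset of the Kaggle `Click-through rate prediction` task \[[2](https://www.kaggle.com/c/avazu-ctr-prediction/data)\] to demonstrate the model in this example.

For the feature processing details, see [data process](./dataset.md).

The input format of the model demonstrated in this tutorial is as follows: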
```
# <dnn input ids> \t <lr input sparse values> \t click
1 23 190 \t 230:0.12 3421:0.9 23451:0.12 \t 0
23 231 \t 1230:0.12 13421:0.9 \t 1
```
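The format is described in detail below:

- `dnn input ids` uses a one-hot representation; only the IDs whose value is 1 need to be listed (note that this is not a variable-length input)
- `lr input sparse values` uses the `ID:VALUE` representation; the values are best normalized to the range `[-1, 1]`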
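
To make the record layout concrete, here is a minimal parsing sketch in plain Python, independent of PaddlePaddle. It assumes the three fields are separated by real tab characters; `parse_line` is a hypothetical helper and not part of the example's `reader.py`.

```python
def parse_line(line):
    """Split one training record into (dnn_ids, lr_pairs, click).

    Assumes tab-separated fields: space-separated integer IDs, then
    space-separated ID:VALUE pairs, then the 0/1 click label.
    """
    dnn_field, lr_field, click_field = line.rstrip("\n").split("\t")
    dnn_ids = [int(tok) for tok in dnn_field.split()]
    lr_pairs = []
    for tok in lr_field.split():
        feature_id, value = tok.split(":")
        lr_pairs.append((int(feature_id), float(value)))
    return dnn_ids, lr_pairs, int(click_field)


sample = "1 23 190\t230:0.12 3421:0.9 23451:0.12\t0"
print(parse_line(sample))
```

A line with a different number of tab-separated fields raises `ValueError` at the unpacking step, which is a simple way to surface malformed records early.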
In addition, training requires a file that describes the input dimensions of the two sub-models, dnn and lr. The file format is as follows:
```
dnn_input_dim: <int>
lr_input_dim: <int>
```
Here, `<int>` denotes an integer value.
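
Reading such a meta file back could look like the sketch below. `parse_meta` is a hypothetical helper written for illustration, and the two dimension values in the example are made up.

```python
def parse_meta(text):
    """Parse 'key: <int>' lines into a dict mapping each key to an int."""
    meta = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        key, value = line.split(":", 1)
        meta[key.strip()] = int(value)
    return meta


# Made-up dimensions, for illustration only
example = "dnn_input_dim: 61\nlr_input_dim: 10040001\n"
dims = parse_meta(example)
print(dims["dnn_input_dim"], dims["lr_input_dim"])
```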
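The `avazu_data_processer.py` script in this directory processes the downloaded demo dataset \[[2](https://www.kaggle.com/c/avazu-ctr-prediction/data)\]. Its usage is as follows: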
```
usage: avazu_data_processer.py [-h] --data_path DATA_PATH --output_dir
                               OUTPUT_DIR
                               [--num_lines_to_detect NUM_LINES_TO_DETECT]
                               [--test_set_size TEST_SET_SIZE]
                               [--train_size TRAIN_SIZE]

PaddlePaddle CTR example

optional arguments:
  -h, --help            show this help message and exit
  --data_path DATA_PATH
                        path of the Avazu dataset
  --output_dir OUTPUT_DIR
                        directory to output
  --num_lines_to_detect NUM_LINES_TO_DETECT
                        number of records to detect dataset's meta info
  --test_set_size TEST_SET_SIZE
                        size of the validation dataset(default: 10000)
  --train_size TRAIN_SIZE
                        size of the trainset (default: 100000)
```
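- `data_path`: path of the data to process
- `output_dir`: output path for the generated data
- `num_lines_to_detect`: how much data to pre-scan in order to generate IDs; here it is the number of file lines to scan
- `test_set_size`: number of lines in the generated test set
- `train_size`: number of lines in the generated training set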
## Wide & Deep Learning Model
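In 2016 Google proposed the Wide & Deep Learning model framework, which combines the strengths of two kinds of models: DNN, suited to learning abstract features, and LR, suited to large-scale sparse features.

### Model Overview

The Wide & Deep Learning Model \[[3](https://arxiv.org/pdf/1606.07792.pdf)\] can be used as a relatively mature model framework and has seen some use in industry for CTR prediction, so this tutorial uses it to demonstrate the CTR prediction task.

The model structure is as follows: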
<p align="center">
<img src="images/wide_deep.png" width="820" hspace='10'/> <br/>
Figure 2. Wide & Deep Model
</p>
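The Wide part at the top of the model can accommodate large-scale sparse features and can memorize certain specific information (such as IDs) to some degree; the Deep part at the bottom can learn implicit relations between features and has better learning and inference ability for the same number of features.

### Writing the Model Input

The model takes only 3 inputs:

- `dnn_input`: the input of the Deep part
- `lr_input`: the input of the Wide part
- `click`: whether the ad was clicked, used as the label for the binary classification model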
```python
dnn_merged_input = layer.data(
    name='dnn_input',
    type=paddle.data_type.sparse_binary_vector(data_meta_info['dnn_input']))

lr_merged_input = layer.data(
    name='lr_input',
    type=paddle.data_type.sparse_binary_vector(data_meta_info['lr_input']))

click = paddle.layer.data(name='click', type=dtype.dense_vector(1))
```
### Writing the Wide Part

The Wide part uses the LR model directly, but the activation function is changed to `RELU` for speed.
```python
def build_lr_submodel():
    fc = layer.fc(
        input=lr_merged_input, size=1, name='lr', act=paddle.activation.Relu())
    return fc
```
### Writing the Deep Part

The Deep part uses a standard multi-layer feed-forward DNN model.
```python
def build_dnn_submodel(dnn_layer_dims):
    dnn_embedding = layer.fc(input=dnn_merged_input, size=dnn_layer_dims[0])
    _input_layer = dnn_embedding
    for i, dim in enumerate(dnn_layer_dims[1:]):
        fc = layer.fc(
            input=_input_layer,
            size=dim,
            act=paddle.activation.Relu(),
            name='dnn-fc-%d' % i)
        _input_layer = fc
    return _input_layer
```
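### Combining the Two

The top-level outputs of the two submodels are combined by a weighted sum to produce the output of the whole model. The output uses `sigmoid` as the activation function and yields a prediction in the interval (0,1), which approximates the distribution of the binary classes in the training data and is ultimately used as the CTR estimate.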
```python
# combine DNN and LR submodels
def combine_submodels(dnn, lr):
    merge_layer = layer.concat(input=[dnn, lr])
    fc = layer.fc(
        input=merge_layer,
        size=1,
        name='output',
        # use sigmoid function to approximate ctr, which is a float value between 0 and 1.
        act=paddle.activation.Sigmoid())
    return fc
```
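
The arithmetic behind the combination can be sketched in plain Python without PaddlePaddle: concatenating the two sub-model outputs and feeding them to a one-unit sigmoid layer is a weighted sum squashed into (0, 1). All weights and inputs below are made-up values, and `dense` and `wide_and_deep` are hypothetical helpers, not the example's API.

```python
import math


def relu(x):
    return max(0.0, x)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dense(inputs, weights, biases, act):
    """Fully connected layer; weights[j] holds the input weights of unit j."""
    return [act(sum(w * x for w, x in zip(row, inputs)) + b)
            for row, b in zip(weights, biases)]


def wide_and_deep(dnn_out, lr_out, out_weights, out_bias):
    """Concatenate both sub-model outputs and apply a 1-unit sigmoid layer."""
    merged = dnn_out + lr_out  # list concatenation
    return dense(merged, [out_weights], [out_bias], sigmoid)[0]


# Made-up inputs and weights, for illustration only
dnn_out = dense([1.0, 0.0, 1.0],
                [[0.5, -0.2, 0.1], [-0.3, 0.8, 0.4]], [0.0, 0.0], relu)
lr_out = dense([0.12, 0.9], [[1.0, -0.5]], [0.0], relu)
ctr = wide_and_deep(dnn_out, lr_out, [1.0, 1.0, 1.0], 0.0)
print(ctr)
```

Whatever the sub-models produce, the final sigmoid keeps the output strictly between 0 and 1, so it can be read as a click probability.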
### Defining the Training Task
```python
dnn = build_dnn_submodel(dnn_layer_dims)
lr = build_lr_submodel()
output = combine_submodels(dnn, lr)

# ==============================================================================
#                   cost and train period
# ==============================================================================
classification_cost = paddle.layer.multi_binary_label_cross_entropy_cost(
    input=output, label=click)

paddle.init(use_gpu=False, trainer_count=11)

params = paddle.parameters.create(classification_cost)

optimizer = paddle.optimizer.Momentum(momentum=0)

trainer = paddle.trainer.SGD(
    cost=classification_cost, parameters=params, update_equation=optimizer)

dataset = AvazuDataset(train_data_path, n_records_as_test=test_set_size)


def event_handler(event):
    if isinstance(event, paddle.event.EndIteration):
        if event.batch_id % 100 == 0:
            logging.warning("Pass %d, Samples %d, Cost %f" % (
                event.pass_id, event.batch_id * batch_size, event.cost))

        if event.batch_id % 1000 == 0:
            result = trainer.test(
                reader=paddle.batch(dataset.test, batch_size=1000),
                feeding=field_index)
            logging.warning("Test %d-%d, Cost %f" % (event.pass_id, event.batch_id,
                                                     result.cost))


trainer.train(
    reader=paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=500),
        batch_size=batch_size),
    feeding=field_index,
    event_handler=event_handler,
    num_passes=100)
```
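
With a single sigmoid output and a 0/1 `click` label, the cost above presumably reduces to the ordinary binary cross entropy (log loss) per sample; that reading is an assumption rather than something stated in the text. The sketch below computes this loss in plain Python for a few made-up (prediction, label) pairs; `binary_cross_entropy` is a hypothetical helper.

```python
import math


def binary_cross_entropy(p, y, eps=1e-12):
    """Log loss for one sample: p is the predicted click probability, y is 0 or 1."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


# Made-up (prediction, label) pairs
pairs = [(0.9, 1), (0.9, 0), (0.5, 1)]
losses = [binary_cross_entropy(p, y) for p, y in pairs]
print(losses)
```

A confident correct prediction gives a small loss, a confident wrong one gives a large loss, and an uninformative prediction of 0.5 always costs ln 2.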
## Running Training and Testing

Training the model takes the following steps:

1. Prepare the training data
    1. Download train.gz from [Kaggle CTR](https://www.kaggle.com/c/avazu-ctr-prediction/data)
    2. Decompress train.gz to get train.txt
    3. Run `mkdir -p output; python avazu_data_processer.py --data_path train.txt --output_dir output --num_lines_to_detect 1000 --test_set_size 100` to generate the demo data
2. Run `python train.py --train_data_path ./output/train.txt --test_data_path ./output/test.txt --data_meta_file ./output/data.meta.txt --model_type=0` to start training

In step 2 above, command-line arguments can be passed to `train.py` to customize the training process. The arguments and their usage are as follows:
```
usage: train.py [-h] --train_data_path TRAIN_DATA_PATH
                [--test_data_path TEST_DATA_PATH] [--batch_size BATCH_SIZE]
                [--num_passes NUM_PASSES]
                [--model_output_prefix MODEL_OUTPUT_PREFIX] --data_meta_file
                DATA_META_FILE --model_type MODEL_TYPE

PaddlePaddle CTR example

optional arguments:
  -h, --help            show this help message and exit
  --train_data_path TRAIN_DATA_PATH
                        path of training dataset
  --test_data_path TEST_DATA_PATH
                        path of testing dataset
  --batch_size BATCH_SIZE
                        size of mini-batch (default:10000)
  --num_passes NUM_PASSES
                        number of passes to train
  --model_output_prefix MODEL_OUTPUT_PREFIX
                        prefix of path for model to store (default:
                        ./ctr_models)
  --data_meta_file DATA_META_FILE
                        path of data meta info file
  --model_type MODEL_TYPE
                        model type, classification: 0, regression 1 (default
                        classification)
```
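- `train_data_path`: path of the training set
- `test_data_path`: path of the test set
- `num_passes`: number of passes to train
- `data_meta_file`: see the description of the meta file in the data and task abstraction section above
- `model_type`: whether the model does classification or regression

## Predicting with the Trained Model

A trained model can be used to predict on new data. The format of the prediction data is: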
```
# <dnn input ids> \t <lr input sparse values>
1 23 190 \t 230:0.12 3421:0.9 23451:0.12
23 231 \t 1230:0.12 13421:0.9
```
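The only difference from the training data format is the absence of the label, that is, the value in the third column `click` of the training data.

`infer.py` is used as follows: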
```
usage: infer.py [-h] --model_gz_path MODEL_GZ_PATH --data_path DATA_PATH
                --prediction_output_path PREDICTION_OUTPUT_PATH
                [--data_meta_path DATA_META_PATH] --model_type MODEL_TYPE

PaddlePaddle CTR example

optional arguments:
  -h, --help            show this help message and exit
  --model_gz_path MODEL_GZ_PATH
                        path of model parameters gz file
  --data_path DATA_PATH
                        path of the dataset to infer
  --prediction_output_path PREDICTION_OUTPUT_PATH
                        path to output the prediction
  --data_meta_path DATA_META_PATH
                        path of trainset's meta info, default is ./data.meta
  --model_type MODEL_TYPE
                        model type, classification: 0, regression 1 (default
                        classification)
```
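- `model_gz_path`: path of the `gz`-compressed model
- `data_path`: path of the data to predict
- `prediction_output_path`: path of the prediction output
- `data_meta_path`: see the description of the meta file in the data and task abstraction section above
- `model_type`: classification or regression

The demo data can be predicted with the following command: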
```
python infer.py --model_gz_path <model_path> --data_path output/infer.txt --prediction_output_path predictions.txt --data_meta_path data.meta.txt
```
The final predictions are written to `predictions.txt`.

## References
1. <https://en.wikipedia.org/wiki/Click-through_rate>
2. <https://www.kaggle.com/c/avazu-ctr-prediction/data>
3. Cheng H T, Koc L, Harmsen J, et al. [Wide & deep learning for recommender systems](https://arxiv.org/pdf/1606.07792.pdf)[C]//Proceedings of the 1st Workshop on Deep Learning for Recommender Systems. ACM, 2016: 7-10.
The minimum PaddlePaddle version needed for the code sample in this directory is v0.10.0. If you are on a version of PaddlePaddle earlier than v0.10.0, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
# Click-Through Rate Prediction
## Introduction
CTR (Click-Through Rate) \[[1](https://en.wikipedia.org/wiki/Click-through_rate)\] prediction estimates the probability that a user clicks on an advertisement. It is widely used in the advertising industry. Accurate CTR estimates are important for maximizing online advertising revenue.
When there are multiple ad slots, CTR estimates are generally used as a baseline for ranking. For example, in a search engine's ad system, when the user enters a query, the system typically performs the following steps to show relevant ads.
1. Get the ad collection associated with the user's search term.
2. Business rules and relevance filtering.
3. Rank by auction mechanism and CTR.
4. Show ads.
Here, CTR plays a crucial role.
### Brief history
Historically, CTR prediction models have evolved through the following stages.
- Logistic Regression(LR) / Gradient Boosting Decision Trees (GBDT) + feature engineering
- LR + Deep Neural Network (DNN)
- DNN + feature engineering
In the early stages of development LR dominated, but in recent years DNN-based models have become the mainstream.
### LR vs DNN
The following figure shows the structure of LR and DNN model:
<p align="center">
<img src="images/lr_vs_dnn.jpg" width="620" hspace='10'/> <br/>
Figure 1. LR and DNN model structure comparison
</p>
As the figure shows, LR and DNN share some common structure. However, a DNN can model non-linear relations between input and output values by adding activation units and further layers. This enables a DNN to achieve better learning results in CTR estimation.
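
A small, self-contained illustration of that non-linearity: XOR is not linearly separable, so no single weighted sum of the two inputs can reproduce it, whereas one hidden layer of two ReLU units can. The weights below are hand-picked for illustration and are not learned; `tiny_dnn` is a hypothetical name.

```python
def relu(x):
    return max(0.0, x)


def tiny_dnn(x1, x2):
    """Two ReLU hidden units followed by a linear output, hand-set to compute XOR."""
    h1 = relu(x1 + x2)        # fires when at least one input is 1
    h2 = relu(x1 + x2 - 1.0)  # fires only when both inputs are 1
    return h1 - 2.0 * h2


inputs = [(0, 0), (0, 1), (1, 0), (1, 1)]
outputs = [tiny_dnn(a, b) for a, b in inputs]
print(outputs)
```

For LR to express the same relation, an engineered feature such as the product of the two inputs would have to be added by hand, which is the link between LR and feature engineering in the history above.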
In the following, we demonstrate how to use PaddlePaddle to learn to predict CTR.
## Data and Model formation
Here `click` is the learning objective. There are several ways to formulate the task:
1. Learn click directly, with 0 and 1 as a binary classification
2. Learning to rank, with pairwise rank or listwise rank
3. Measure the click rate of each ad, then rank by the click rate.
In this example, we use the first method.
We use the Kaggle `Click-through rate prediction` task \[[2](https://www.kaggle.com/c/avazu-ctr-prediction/data)\].
Please see the [data process](./dataset.md) for pre-processing data.
The input data format for the demo model in this tutorial is as follows:
```
# <dnn input ids> \t <lr input sparse values> \t click
1 23 190 \t 230:0.12 3421:0.9 23451:0.12 \t 0
23 231 \t 1230:0.12 13421:0.9 \t 1
```
Description:
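- `dnn input ids`: one-hot encoded; only the IDs whose value is 1 are listed.
- `lr input sparse values`: given as `ID:VALUE`; the values are preferably scaled to the range `[-1, 1]`.

In addition, training requires a file that describes the input dimensions of the two sub-models, dnn and lr. The file format is as follows:

```
dnn_input_dim: <int>
lr_input_dim: <int>
```

`<int>` represents an integer value.

`avazu_data_processer.py` can be used to pre-process the downloaded data set \[[2](https://www.kaggle.com/c/avazu-ctr-prediction/data)\].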
```
usage: avazu_data_processer.py [-h] --data_path DATA_PATH --output_dir
                               OUTPUT_DIR
                               [--num_lines_to_detect NUM_LINES_TO_DETECT]
                               [--test_set_size TEST_SET_SIZE]
                               [--train_size TRAIN_SIZE]

PaddlePaddle CTR example

optional arguments:
  -h, --help            show this help message and exit
  --data_path DATA_PATH
                        path of the Avazu dataset
  --output_dir OUTPUT_DIR
                        directory to output
  --num_lines_to_detect NUM_LINES_TO_DETECT
                        number of records to detect dataset's meta info
  --test_set_size TEST_SET_SIZE
                        size of the validation dataset(default: 10000)
  --train_size TRAIN_SIZE
                        size of the trainset (default: 100000)
```
- `data_path`: path of the data to be processed
- `output_dir`: output directory for the processed data
- `num_lines_to_detect`: number of records scanned to detect the data set's meta information
- `test_set_size`: number of rows in the test set
- `train_size`: number of rows in the training set
## Wide & Deep Learning Model
Google proposed the Wide & Deep Learning framework to combine the advantages of DNNs, which are suited to learning abstract features, and LR models, which are suited to large-scale sparse features.
### Introduction to the model
The Wide & Deep Learning Model\[[3](#references)\] is a relatively mature model that is still used in CTR prediction tasks. Here we demonstrate how to use it to complete a CTR prediction task.
The model structure is as follows:
<p align="center">
<img src="images/wide_deep.png" width="820" hspace='10'/> <br/>
Figure 2. Wide & Deep Model
</p>
The wide part at the top side of the model can accommodate large-scale sparse features and memorizes specific information (such as IDs); the deep part at the bottom side of the model can learn implicit relationships between features.
### Model Input
The model has three inputs:
- `dnn_input`: the input of the deep part
- `lr_input`: the input of the wide part
- `click`: the label, i.e. whether the ad was clicked
```python
dnn_merged_input = layer.data(
    name='dnn_input',
    type=paddle.data_type.sparse_binary_vector(self.dnn_input_dim))

lr_merged_input = layer.data(
    name='lr_input',
    type=paddle.data_type.sparse_vector(self.lr_input_dim))

click = paddle.layer.data(name='click', type=dtype.dense_vector(1))
```
### Wide part
The wide part uses the LR model, but with the activation function changed to `ReLU` for speed.
```python
def build_lr_submodel():
    fc = layer.fc(
        input=lr_merged_input, size=1, name='lr', act=paddle.activation.Relu())
    return fc
```
### Deep part
The Deep part uses a standard multi-layer DNN.
```python
def build_dnn_submodel(dnn_layer_dims):
    dnn_embedding = layer.fc(input=dnn_merged_input, size=dnn_layer_dims[0])
    _input_layer = dnn_embedding
    for i, dim in enumerate(dnn_layer_dims[1:]):
        fc = layer.fc(
            input=_input_layer,
            size=dim,
            act=paddle.activation.Relu(),
            name='dnn-fc-%d' % i)
        _input_layer = fc
    return _input_layer
```
### Combine
The output section uses `sigmoid` function to output (0,1) as the prediction value.
```python
# combine DNN and LR submodels
def combine_submodels(dnn, lr):
    merge_layer = layer.concat(input=[dnn, lr])
    fc = layer.fc(
        input=merge_layer,
        size=1,
        name='output',
        # use sigmoid function to approximate ctr, which is a float value between 0 and 1.
        act=paddle.activation.Sigmoid())
    return fc
```
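To make the arithmetic of the combine step concrete, here is a small plain-Python sketch, independent of PaddlePaddle. It concatenates the two submodel outputs, applies a single sigmoid unit, and evaluates the binary cross-entropy for one sample. The function names, weights and inputs are made up for illustration and are not taken from the example code.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def combine(dnn_out, lr_out, weights, bias):
    """Concatenate both submodel outputs and apply a one-unit sigmoid layer."""
    merged = list(dnn_out) + [lr_out]
    z = sum(w * x for w, x in zip(weights, merged)) + bias
    return sigmoid(z)


def log_loss(p, label):
    """Binary cross-entropy of prediction p for a 0/1 label."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))


# made-up numbers: two DNN outputs and one LR output
p = combine(dnn_out=[0.5, -1.0], lr_out=2.0, weights=[0.2, 0.1, 0.3], bias=-0.6)
print(p)               # the weighted sum is 0, so the predicted CTR is 0.5
print(log_loss(p, 1))  # equals log(2) for p = 0.5
```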
### Training
```python
dnn = build_dnn_submodel(dnn_layer_dims)
lr = build_lr_submodel()
output = combine_submodels(dnn, lr)

# ==============================================================================
#                   cost and train period
# ==============================================================================
classification_cost = paddle.layer.multi_binary_label_cross_entropy_cost(
    input=output, label=click)

paddle.init(use_gpu=False, trainer_count=11)

params = paddle.parameters.create(classification_cost)

optimizer = paddle.optimizer.Momentum(momentum=0)

trainer = paddle.trainer.SGD(
    cost=classification_cost, parameters=params, update_equation=optimizer)

dataset = AvazuDataset(train_data_path, n_records_as_test=test_set_size)


def event_handler(event):
    if isinstance(event, paddle.event.EndIteration):
        # report the training cost every 100 batches
        if event.batch_id % 100 == 0:
            logging.warning("Pass %d, Samples %d, Cost %f" % (
                event.pass_id, event.batch_id * batch_size, event.cost))

        # evaluate on the test set every 1000 batches
        if event.batch_id % 1000 == 0:
            result = trainer.test(
                reader=paddle.batch(dataset.test, batch_size=1000),
                feeding=field_index)
            logging.warning("Test %d-%d, Cost %f" % (event.pass_id, event.batch_id,
                                                     result.cost))


trainer.train(
    reader=paddle.batch(
        paddle.reader.shuffle(dataset.train, buf_size=500),
        batch_size=batch_size),
    feeding=field_index,
    event_handler=event_handler,
    num_passes=100)
```
## Run training and testing
Training goes through the following steps:
1. Prepare the training data
    1. Download train.gz from [Kaggle CTR](https://www.kaggle.com/c/avazu-ctr-prediction/data).
    2. Unzip train.gz to get train.txt.
    3. Run `mkdir -p output; python avazu_data_processer.py --data_path train.txt --output_dir output --num_lines_to_detect 1000 --test_set_size 100` to generate the demo data.
2. Run `python train.py --train_data_path ./output/train.txt --test_data_path ./output/test.txt --data_meta_file ./output/data.meta.txt --model_type=0` to start training.
The argument options for `train.py` are as follows.
```
usage: train.py [-h] --train_data_path TRAIN_DATA_PATH
[--test_data_path TEST_DATA_PATH] [--batch_size BATCH_SIZE]
[--num_passes NUM_PASSES]
[--model_output_prefix MODEL_OUTPUT_PREFIX] --data_meta_file
DATA_META_FILE --model_type MODEL_TYPE
PaddlePaddle CTR example
optional arguments:
-h, --help show this help message and exit
--train_data_path TRAIN_DATA_PATH
path of training dataset
--test_data_path TEST_DATA_PATH
path of testing dataset
--batch_size BATCH_SIZE
size of mini-batch (default:10000)
--num_passes NUM_PASSES
number of passes to train
--model_output_prefix MODEL_OUTPUT_PREFIX
prefix of path for model to store (default:
./ctr_models)
--data_meta_file DATA_META_FILE
path of data meta info file
--model_type MODEL_TYPE
model type, classification: 0, regression 1 (default
classification)
```
- `train_data_path`: path of the training set
- `test_data_path`: path of the test set
- `num_passes`: number of training passes
- `data_meta_file`: see the description in [Data and Model formation](#data-and-model-formation).
- `model_type`: model type, classification (0) or regression (1)
## Use the trained model for prediction
The trained model can be used to predict on new data. The format of the inference data is as follows.
```
# <dnn input ids> \t <lr input sparse values>
1 23 190 \t 230:0.12 3421:0.9 23451:0.12
23 231 \t 1230:0.12 13421:0.9
```
The only difference from the training data is that there is no label (i.e. no `click` value).
We can now use `infer.py` to perform inference.
```
usage: infer.py [-h] --model_gz_path MODEL_GZ_PATH --data_path DATA_PATH
--prediction_output_path PREDICTION_OUTPUT_PATH
[--data_meta_path DATA_META_PATH] --model_type MODEL_TYPE
PaddlePaddle CTR example
optional arguments:
-h, --help show this help message and exit
--model_gz_path MODEL_GZ_PATH
path of model parameters gz file
--data_path DATA_PATH
path of the dataset to infer
--prediction_output_path PREDICTION_OUTPUT_PATH
path to output the prediction
--data_meta_path DATA_META_PATH
path of trainset's meta info, default is ./data.meta
--model_type MODEL_TYPE
model type, classification: 0, regression 1 (default
classification)
```
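- `model_gz_path`: path of the `gz`-compressed model parameters file
- `data_path`: path of the data set to run inference on
- `prediction_output_path`: path of the file the predictions are written to
- `data_meta_path`: see the description in [Data and Model formation](#data-and-model-formation).
- `model_type`: model type, classification (0) or regression (1)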
The sample data can be predicted with the following command (`--model_type` is a required argument of `infer.py`):
```
python infer.py --model_gz_path <model_path> --data_path output/infer.txt --prediction_output_path predictions.txt --data_meta_path output/data.meta.txt --model_type 0
```
The final predictions are written to `predictions.txt`.
## References
1. <https://en.wikipedia.org/wiki/Click-through_rate>
2. <https://www.kaggle.com/c/avazu-ctr-prediction/data>
3. Cheng H T, Koc L, Harmsen J, et al. [Wide & deep learning for recommender systems](https://arxiv.org/pdf/1606.07792.pdf)[C]//Proceedings of the 1st Workshop on Deep Learning for Recommender Systems. ACM, 2016: 7-10.
import sys
import csv
import cPickle
import argparse
import os
import numpy as np
from utils import logger, TaskMode
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
'--data_path', type=str, required=True, help="path of the Avazu dataset")
parser.add_argument(
'--output_dir', type=str, required=True, help="directory to output")
parser.add_argument(
'--num_lines_to_detect',
type=int,
default=500000,
help="number of records to detect dataset's meta info")
parser.add_argument(
'--test_set_size',
type=int,
default=10000,
help="size of the validation dataset(default: 10000)")
parser.add_argument(
'--train_size',
type=int,
default=100000,
help="size of the trainset (default: 100000)")
args = parser.parse_args()
'''
The fields of the dataset are:
0. id: ad identifier
1. click: 0/1 for non-click/click
2. hour: format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC.
3. C1 -- anonymized categorical variable
4. banner_pos
5. site_id
6. site_domain
7. site_category
8. app_id
9. app_domain
10. app_category
11. device_id
12. device_ip
13. device_model
14. device_type
15. device_conn_type
16. C14-C21 -- anonymized categorical variables
We will treat the following fields as categorical features:
- C1
- banner_pos
- site_category
- app_category
- device_type
- device_conn_type
and some other features as id features:
- id
- site_id
- app_id
- device_id
The `hour` field will be treated as a continuous feature and will be transformed
to one-hot representation which has 24 bits.
This script will output 3 files:
1. train.txt
2. test.txt
3. infer.txt
all the files are for demo.
'''
feature_dims = {}
categorial_features = (
'C1 banner_pos site_category app_category ' + 'device_type device_conn_type'
).split()
id_features = 'id site_id app_id device_id _device_id_cross_site_id'.split()
def get_all_field_names(mode=0):
'''
@mode: int
0 for train, 1 for test
@return: list of str
'''
return categorial_features + ['hour'] + id_features + ['click'] \
if mode == 0 else []
class CategoryFeatureGenerator(object):
'''
Generator category features.
Register all records by calling `register` first, then call `gen` to generate
one-hot representation for a record.
'''
def __init__(self):
self.dic = {'unk': 0}
self.counter = 1
def register(self, key):
'''
Register record.
'''
if key not in self.dic:
self.dic[key] = self.counter
self.counter += 1
def size(self):
return len(self.dic)
def gen(self, key):
'''
Generate one-hot representation for a record.
'''
if key not in self.dic:
res = self.dic['unk']
else:
res = self.dic[key]
return [res]
def __repr__(self):
return '<CategoryFeatureGenerator %d>' % len(self.dic)
class IDfeatureGenerator(object):
def __init__(self, max_dim, cross_fea0=None, cross_fea1=None):
'''
@max_dim: int
Size of the id elements' space
'''
self.max_dim = max_dim
self.cross_fea0 = cross_fea0
self.cross_fea1 = cross_fea1
def gen(self, key):
'''
Generate one-hot representation for records
'''
return [hash(key) % self.max_dim]
def gen_cross_fea(self, fea1, fea2):
key = str(fea1) + str(fea2)
return self.gen(key)
def size(self):
return self.max_dim
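class ContinuousFeatureGenerator(object):
    def __init__(self, n_intervals):
        # start from the extreme values so that the first registered value
        # replaces both bounds (`sys.minint` does not exist)
        self.min = sys.maxint
        self.max = -sys.maxint - 1
        self.n_intervals = n_intervals

    def register(self, val):
        self.min = min(self.min, val)
        self.max = max(self.max, val)

    def gen(self, val):
        # width of one interval; `register` must have seen a wide enough
        # range of values, otherwise the width is zero
        self.len_part = (self.max - self.min) / self.n_intervals
        return (val - self.min) / self.len_part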
# init all feature generators
fields = {}
for key in categorial_features:
fields[key] = CategoryFeatureGenerator()
for key in id_features:
# for cross features
if 'cross' in key:
feas = key[1:].split('_cross_')
fields[key] = IDfeatureGenerator(10000000, *feas)
# for normal ID features
else:
fields[key] = IDfeatureGenerator(10000)
# used as feed_dict in PaddlePaddle
field_index = dict((key, id)
for id, key in enumerate(['dnn_input', 'lr_input', 'click']))
def detect_dataset(path, topn, id_fea_space=10000):
    '''
    Parse the first `topn` records to collect meta information of this dataset.

    NOTE the records should be randomly shuffled first.
    '''
    # create categorical statistics objects.
    logger.warning('detecting dataset')

    with open(path, 'rb') as csvfile:
        reader = csv.DictReader(csvfile)
        for row_id, row in enumerate(reader):
            # row_id starts at 0, so stop once `topn` rows have been read
            if row_id >= topn:
                break

            for key in categorial_features:
                fields[key].register(row[key])

    for key, item in fields.items():
        feature_dims[key] = item.size()

    feature_dims['hour'] = 24
    feature_dims['click'] = 1

    # use the builtin `sum`; calling `np.sum` on a generator is deprecated
    feature_dims['dnn_input'] = sum(
        feature_dims[key] for key in categorial_features + ['hour']) + 1
    feature_dims['lr_input'] = sum(feature_dims[key]
                                   for key in id_features) + 1
    return feature_dims
def load_data_meta(meta_path):
'''
Load dataset's meta information.
'''
feature_dims, fields = cPickle.load(open(meta_path, 'rb'))
return feature_dims, fields
def concat_sparse_vectors(inputs, dims):
'''
Concatenate multiple sparse vectors into one.
@inputs: list
list of sparse vector
@dims: list of int
dimension of each sparse vector
'''
res = []
assert len(inputs) == len(dims)
start = 0
for no, vec in enumerate(inputs):
for v in vec:
res.append(v + start)
start += dims[no]
return res
class AvazuDataset(object):
'''
Load AVAZU dataset as train set.
'''
def __init__(self,
train_path,
n_records_as_test=-1,
fields=None,
feature_dims=None):
self.train_path = train_path
self.n_records_as_test = n_records_as_test
self.fields = fields
# default is train mode.
self.mode = TaskMode.create_train()
self.categorial_dims = [
feature_dims[key] for key in categorial_features + ['hour']
]
self.id_dims = [feature_dims[key] for key in id_features]
def train(self):
'''
Load trainset.
'''
logger.info("load trainset from %s" % self.train_path)
self.mode = TaskMode.create_train()
with open(self.train_path) as f:
reader = csv.DictReader(f)
for row_id, row in enumerate(reader):
# skip top n lines
if self.n_records_as_test > 0 and row_id < self.n_records_as_test:
continue
rcd = self._parse_record(row)
if rcd:
yield rcd
def test(self):
'''
Load testset.
'''
logger.info("load testset from %s" % self.train_path)
self.mode = TaskMode.create_test()
with open(self.train_path) as f:
reader = csv.DictReader(f)
for row_id, row in enumerate(reader):
# skip top n lines
if self.n_records_as_test > 0 and row_id >= self.n_records_as_test:
break
rcd = self._parse_record(row)
if rcd:
yield rcd
def infer(self):
'''
Load inferset.
'''
logger.info("load inferset from %s" % self.train_path)
self.mode = TaskMode.create_infer()
with open(self.train_path) as f:
reader = csv.DictReader(f)
for row_id, row in enumerate(reader):
rcd = self._parse_record(row)
if rcd:
yield rcd
def _parse_record(self, row):
'''
Parse a CSV row and get a record.
'''
record = []
for key in categorial_features:
record.append(self.fields[key].gen(row[key]))
record.append([int(row['hour'][-2:])])
dense_input = concat_sparse_vectors(record, self.categorial_dims)
record = []
for key in id_features:
if 'cross' not in key:
record.append(self.fields[key].gen(row[key]))
else:
fea0 = self.fields[key].cross_fea0
fea1 = self.fields[key].cross_fea1
record.append(self.fields[key].gen_cross_fea(row[fea0], row[
fea1]))
sparse_input = concat_sparse_vectors(record, self.id_dims)
record = [dense_input, sparse_input]
if not self.mode.is_infer():
record.append(list((int(row['click']), )))
return record
def ids2dense(vec, dim):
return vec
def ids2sparse(vec):
return ["%d:1" % x for x in vec]
detect_dataset(args.data_path, args.num_lines_to_detect)
dataset = AvazuDataset(
args.data_path,
args.test_set_size,
fields=fields,
feature_dims=feature_dims)
output_trainset_path = os.path.join(args.output_dir, 'train.txt')
output_testset_path = os.path.join(args.output_dir, 'test.txt')
output_infer_path = os.path.join(args.output_dir, 'infer.txt')
output_meta_path = os.path.join(args.output_dir, 'data.meta.txt')
with open(output_trainset_path, 'w') as f:
for id, record in enumerate(dataset.train()):
if id and id % 10000 == 0:
logger.info("load %d records" % id)
if id > args.train_size:
break
dnn_input, lr_input, click = record
dnn_input = ids2dense(dnn_input, feature_dims['dnn_input'])
lr_input = ids2sparse(lr_input)
line = "%s\t%s\t%d\n" % (' '.join(map(str, dnn_input)),
' '.join(map(str, lr_input)), click[0])
f.write(line)
logger.info('write to %s' % output_trainset_path)
with open(output_testset_path, 'w') as f:
for id, record in enumerate(dataset.test()):
dnn_input, lr_input, click = record
dnn_input = ids2dense(dnn_input, feature_dims['dnn_input'])
lr_input = ids2sparse(lr_input)
line = "%s\t%s\t%d\n" % (' '.join(map(str, dnn_input)),
' '.join(map(str, lr_input)), click[0])
f.write(line)
logger.info('write to %s' % output_testset_path)
with open(output_infer_path, 'w') as f:
for id, record in enumerate(dataset.infer()):
dnn_input, lr_input = record
dnn_input = ids2dense(dnn_input, feature_dims['dnn_input'])
lr_input = ids2sparse(lr_input)
line = "%s\t%s\n" % (
' '.join(map(str, dnn_input)),
' '.join(map(str, lr_input)), )
f.write(line)
if id > args.test_set_size:
break
logger.info('write to %s' % output_infer_path)
with open(output_meta_path, 'w') as f:
lines = [
"dnn_input_dim: %d" % feature_dims['dnn_input'],
"lr_input_dim: %d" % feature_dims['lr_input']
]
f.write('\n'.join(lines))
logger.info('write data meta into %s' % output_meta_path)
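# Data and processing
## Dataset introduction
This tutorial demonstrates how to pre-process the data set of the Kaggle CTR task\[[3](#references)\] into the format required by this model. For details of the data format, see [README.md](./README.md).
The strength of the Wide & Deep Model\[[2](#references)\] is that it combines dense features with large-scale sparse features,
so feature processing also handles the dense and the sparse features separately:
the dense values of the Deep part are all converted into ID features,
which an embedding then turns into dense vector input; the Wide part mainly raises the dimension by crossing IDs.
The data set is stored in `csv` format, with the following fields: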
- `id` : ad identifier
- `click` : 0/1 for non-click/click
- `hour` : format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC.
- `C1` : anonymized categorical variable
- `banner_pos`
- `site_id`
- `site_domain`
- `site_category`
- `app_id`
- `app_domain`
- `app_category`
- `device_id`
- `device_ip`
- `device_model`
- `device_type`
- `device_conn_type`
- `C14-C21` : anonymized categorical variables
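## Feature extraction
Below we briefly demonstrate how several kinds of features are extracted.
The features in the raw data fall into the following categories:
1. ID features (sparse, with a large number of values)
    - `id`
    - `site_id`
    - `app_id`
    - `device_id`
2. Categorical features (sparse, but with a limited number of values)
    - `C1`
    - `site_category`
    - `device_type`
    - `C14-C21`
3. Numeric features converted to categorical features
    - `hour` (can be converted to a number, or to a category with one-hour granularity)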
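### Categorical features
There are two ways to extract categorical features:
1. Use the one-hot representation as the feature
2. As with word vectors, use an embedding to map each category to a vector
### ID features
ID features are sparse but take a very large number of values, so a direct one-hot representation has too many dimensions.
They are usually processed as follows:
1. Choose the maximum dimension N of the representation
2. newid = id % N
3. Use newid as a categorical feature
Although this method has a certain probability of collisions, it can handle any number of ID values while retaining reasonable effectiveness\[[2](#references)\].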
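The three steps can be sketched as follows. This is a minimal illustration, not the example's own code: the helper `hashed_id` is hypothetical, and it swaps in `zlib.crc32` as the hash so that the bucket of an ID is the same in every run (on Python 3 the builtin `hash()` of a string is randomized per process by default).

```python
import zlib


def hashed_id(raw_id, n_buckets):
    """Map an arbitrary ID string to a bucket in [0, n_buckets)."""
    # crc32 is deterministic across runs; the mask keeps the value
    # non-negative on Python 2 as well
    return (zlib.crc32(raw_id.encode("utf-8")) & 0xffffffff) % n_buckets


N = 10000  # step 1: maximum dimension of the representation
a = hashed_id("device_0a1b2c", N)  # step 2: newid = hash(id) % N
b = hashed_id("device_0a1b2c", N)
print(a == b)      # True: the same ID always lands in the same bucket
print(0 <= a < N)  # True: the bucket index stays inside the chosen space
```

Different IDs may still share a bucket; a larger N lowers the collision rate at the cost of a wider input layer.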
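### Numeric features
They are usually processed in one of the following ways:
- Normalize the value and feed it to the model directly as a feature
- Split the value range into intervals to obtain a categorical feature; this sparse representation blurs small differences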
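Both options can be sketched in a few lines. The helpers `min_max_normalize` and `bucketize` are hypothetical names for illustration, and min-max scaling is only one possible normalization.

```python
def min_max_normalize(val, lo, hi):
    """Scale val from the range [lo, hi] to [0, 1]."""
    return (val - lo) / float(hi - lo)


def bucketize(val, lo, hi, n_intervals):
    """Return the index of the equal-width interval that contains val."""
    width = (hi - lo) / float(n_intervals)
    idx = int((val - lo) / width)
    # the upper bound itself belongs to the last interval
    return min(idx, n_intervals - 1)


# option 1: the hour 18 as a normalized value
print(min_max_normalize(18, 0, 23))
# option 2: the hour 18 in one of four 6-hour buckets
print(bucketize(18, 0, 24, 4))  # 3
```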
## Feature processing
### Categorical features
A categorical feature takes a finite number of values. In the model, we usually use an embedding to map each value to a vector of continuous values.
When fed to the model, such a feature is usually given in one-hot representation. It is processed as follows:
```python
class CategoryFeatureGenerator(object):
    '''
    Generator category features.

    Register all records by calling `register` first, then call `gen` to generate
    one-hot representation for a record.
    '''

    def __init__(self):
        self.dic = {'unk': 0}
        self.counter = 1

    def register(self, key):
        '''
        Register record.
        '''
        if key not in self.dic:
            self.dic[key] = self.counter
            self.counter += 1

    def size(self):
        return len(self.dic)

    def gen(self, key):
        '''
        Generate one-hot representation for a record.
        '''
        if key not in self.dic:
            res = self.dic['unk']
        else:
            res = self.dic[key]
        return [res]

    def __repr__(self):
        return '<CategoryFeatureGenerator %d>' % len(self.dic)
```
`CategoryFeatureGenerator` must first scan the data set to collect the set of values of each category before it can generate features.
Our experimental data set\[[3](https://www.kaggle.com/c/avazu-ctr-prediction/data)\] has already been shuffled, so scanning a certain number of records from the beginning approximates the full set of category values (this is equivalent to random sampling).
Low-frequency values that the sampling misses are represented by the special value UNK.
```python
fields = {}
for key in categorial_features:
    fields[key] = CategoryFeatureGenerator()


def detect_dataset(path, topn, id_fea_space=10000):
    '''
    Parse the first `topn` records to collect meta information of this dataset.

    NOTE the records should be randomly shuffled first.
    '''
    # create categorical statistics objects.
    with open(path, 'rb') as csvfile:
        reader = csv.DictReader(csvfile)
        for row_id, row in enumerate(reader):
            # row_id starts at 0, so stop once `topn` rows have been read
            if row_id >= topn:
                break

            for key in categorial_features:
                fields[key].register(row[key])
```
Once `CategoryFeatureGenerator` has registered the category information of the data set, it can generate the feature representation of a record:
```python
record = []
for key in categorial_features:
    record.append(fields[key].gen(row[key]))
```
In this task, categorical features are fed to the DNN.
### ID features
ID features are sparse and take values from a very large space. They are usually reduced to a finite space with a modulo operation,
after which they can be used like categorical features. Here we feed the ID features to the LR model.
```python
class IDfeatureGenerator(object):
    def __init__(self, max_dim):
        '''
        @max_dim: int
            Size of the id elements' space
        '''
        self.max_dim = max_dim

    def gen(self, key):
        '''
        Generate one-hot representation for records
        '''
        return [hash(key) % self.max_dim]

    def size(self):
        return self.max_dim
```
`IDfeatureGenerator` needs no prior initialization and can generate features directly, for example:
```python
record = []
for key in id_features:
    if 'cross' not in key:
        record.append(fields[key].gen(row[key]))
```
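### Cross features
As the `wide` part of the Wide & Deep model, the LR model can take very wide input (a feature space of very high dimension).
To make full use of this, we demonstrate how to combine features by crossing them into a higher-dimensional feature, which is then fed to the model for training.
Here we again use the modulo operation to bound the size of the combined feature space. The implementation simply adds a `gen_cross_fea` method to `IDfeatureGenerator`:
```python
def gen_cross_fea(self, fea1, fea2):
    key = str(fea1) + str(fea2)
    return self.gen(key)
```
For example, we believe that `device_id` and `site_id` in the raw data are related (for instance, a device tends to browse particular sites),
so we cross the two to capture this kind of information.
```python
fea0 = fields[key].cross_fea0
fea1 = fields[key].cross_fea1
record.append(
    fields[key].gen_cross_fea(row[fea0], row[fea1]))
```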
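One detail worth noting about crossing by plain string concatenation: two different pairs can produce the same key before hashing. The sketch below shows a possible variation that joins the values with a separator. It is an assumption for illustration, not what the example code does; `cross_key`, the standalone `gen_cross_fea` and the separator character are hypothetical, and `zlib.crc32` stands in for the hash.

```python
import zlib


def cross_key(fea1, fea2, sep="\x01"):
    """Join two feature values with a separator unlikely to occur in the data."""
    return str(fea1) + sep + str(fea2)


def gen_cross_fea(fea1, fea2, max_dim):
    """Hash the crossed key into [0, max_dim) and return it as a one-element list."""
    key = cross_key(fea1, fea2)
    return [(zlib.crc32(key.encode("utf-8")) & 0xffffffff) % max_dim]


# plain concatenation cannot tell these two pairs apart ...
print("ab" + "c" == "a" + "bc")                        # True
# ... while the separated keys differ
print(cross_key("ab", "c") == cross_key("a", "bc"))    # False
print(gen_cross_fea("device_1", "site_9", 1000000))
```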
### Feature dimensions
#### Features of the Deep submodel (DNN)
| Feature          | Dimension |
|------------------|-----------|
| app_category | 21 |
| site_category | 22 |
| device_conn_type | 5 |
| hour | 24 |
| banner_pos | 7 |
| **Total** | 79 |
#### Features of the Wide submodel (LR)
| Feature             | Dimension |
|---------------------|-----------|
| id | 10000 |
| site_id | 10000 |
| app_id | 10000 |
| device_id | 10000 |
| device_id X site_id | 1000000 |
| **Total** | 1,040,000 |
## Feeding the data to PaddlePaddle
Both the Deep and the Wide part are fed in the `sparse_binary_vector` format \[[1](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/api/v1/data_provider/pydataprovider2_en.rst)\]. The features must be concatenated beforehand, and the model ultimately accepts only 3 inputs:
1. `dnn input`: the input of the DNN
2. `lr input`: the input of the LR
3. `click`: the label
The features are concatenated as follows:
```python
def concat_sparse_vectors(inputs, dims):
    '''
    concatenate sparse vectors into one

    @inputs: list
        list of sparse vector
    @dims: list of int
        dimension of each sparse vector
    '''
    res = []
    assert len(inputs) == len(dims)
    start = 0
    for no, vec in enumerate(inputs):
        for v in vec:
            res.append(v + start)
        start += dims[no]
    return res
```
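A small worked example may help here. It repeats the function so that the snippet runs on its own; the three input vectors and their dimensions are made up. Each index is shifted by the total dimension of the vectors before it, so the result indexes into one concatenated space of 7 + 5 + 24 = 36 dimensions.

```python
def concat_sparse_vectors(inputs, dims):
    """Concatenate sparse vectors (lists of active indices) into one."""
    res = []
    assert len(inputs) == len(dims)
    start = 0
    for no, vec in enumerate(inputs):
        for v in vec:
            res.append(v + start)
        start += dims[no]
    return res


# made-up example: three features with 7, 5 and 24 possible values
inputs = [[2], [4], [23]]
dims = [7, 5, 24]
merged = concat_sparse_vectors(inputs, dims)
print(merged)  # [2, 11, 35]  (offsets 0, 7 and 12)
```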
The final features are generated by the following code:
```python
# dimensions of the features
categorial_dims = [
    feature_dims[key] for key in categorial_features + ['hour']
]
id_dims = [feature_dims[key] for key in id_features]
dense_input = concat_sparse_vectors(record, categorial_dims)
sparse_input = concat_sparse_vectors(record, id_dims)
record = [dense_input, sparse_input]
record.append(list((int(row['click']), )))
yield record
```
## References
1. <https://github.com/PaddlePaddle/Paddle/blob/develop/doc/api/v1/data_provider/pydataprovider2_en.rst>
2. Mikolov T, Deoras A, Povey D, et al. [Strategies for training large scale neural network language models](https://www.researchgate.net/profile/Lukas_Burget/publication/241637478_Strategies_for_training_large_scale_neural_network_language_models/links/542c14960cf27e39fa922ed3.pdf)[C]//Automatic Speech Recognition and Understanding (ASRU), 2011 IEEE Workshop on. IEEE, 2011: 196-201.
3. <https://www.kaggle.com/c/avazu-ctr-prediction/data>
import gzip
import argparse
import itertools
import paddle.v2 as paddle
import network_conf
from train import dnn_layer_dims
import reader
from utils import logger, ModelType
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
'--model_gz_path',
type=str,
required=True,
help="path of model parameters gz file")
parser.add_argument(
'--data_path', type=str, required=True, help="path of the dataset to infer")
parser.add_argument(
'--prediction_output_path',
type=str,
required=True,
help="path to output the prediction")
parser.add_argument(
'--data_meta_path',
type=str,
default="./data.meta",
help="path of trainset's meta info, default is ./data.meta")
parser.add_argument(
'--model_type',
type=int,
required=True,
default=ModelType.CLASSIFICATION,
help='model type, classification: %d, regression %d (default classification)'
% (ModelType.CLASSIFICATION, ModelType.REGRESSION))
args = parser.parse_args()
paddle.init(use_gpu=False, trainer_count=1)
class CTRInferer(object):
def __init__(self, param_path):
logger.info("create CTR model")
dnn_input_dim, lr_input_dim = reader.load_data_meta(args.data_meta_path)
# create the model
self.ctr_model = network_conf.CTRmodel(
dnn_layer_dims,
dnn_input_dim,
lr_input_dim,
model_type=ModelType(args.model_type),
is_infer=True)
# load parameter
logger.info("load model parameters from %s" % param_path)
self.parameters = paddle.parameters.Parameters.from_tar(
gzip.open(param_path, 'r'))
self.inferer = paddle.inference.Inference(
output_layer=self.ctr_model.model,
parameters=self.parameters, )
def infer(self, data_path):
logger.info("infer data...")
dataset = reader.Dataset()
infer_reader = paddle.batch(
dataset.infer(args.data_path), batch_size=1000)
logger.warning('write predictions to %s' % args.prediction_output_path)
output_f = open(args.prediction_output_path, 'w')
for id, batch in enumerate(infer_reader()):
res = self.inferer.infer(input=batch)
predictions = [x for x in itertools.chain.from_iterable(res)]
assert len(batch) == len(
predictions), "predict error, %d inputs, but %d predictions" % (
len(batch), len(predictions))
output_f.write('\n'.join(map(str, predictions)) + '\n')
if __name__ == '__main__':
ctr_inferer = CTRInferer(args.model_gz_path)
ctr_inferer.infer(args.data_path)
import paddle.v2 as paddle
from paddle.v2 import layer
from paddle.v2 import data_type as dtype
from utils import logger, ModelType
class CTRmodel(object):
'''
A CTR model which implements wide && deep learning model.
'''
def __init__(self,
dnn_layer_dims,
dnn_input_dim,
lr_input_dim,
model_type=ModelType.create_classification(),
is_infer=False):
'''
@dnn_layer_dims: list of integer
dims of each layer in dnn
@dnn_input_dim: int
size of dnn's input layer
@lr_input_dim: int
size of lr's input layer
@is_infer: bool
whether to build a infer model
'''
self.dnn_layer_dims = dnn_layer_dims
self.dnn_input_dim = dnn_input_dim
self.lr_input_dim = lr_input_dim
self.model_type = model_type
self.is_infer = is_infer
self._declare_input_layers()
self.dnn = self._build_dnn_submodel_(self.dnn_layer_dims)
self.lr = self._build_lr_submodel_()
# model's prediction
# TODO(superjom) rename it to prediction
if self.model_type.is_classification():
self.model = self._build_classification_model(self.dnn, self.lr)
if self.model_type.is_regression():
self.model = self._build_regression_model(self.dnn, self.lr)
def _declare_input_layers(self):
self.dnn_merged_input = layer.data(
name='dnn_input',
type=paddle.data_type.sparse_binary_vector(self.dnn_input_dim))
self.lr_merged_input = layer.data(
name='lr_input',
type=paddle.data_type.sparse_float_vector(self.lr_input_dim))
if not self.is_infer:
self.click = paddle.layer.data(
name='click', type=dtype.dense_vector(1))
def _build_dnn_submodel_(self, dnn_layer_dims):
'''
build DNN submodel.
'''
dnn_embedding = layer.fc(input=self.dnn_merged_input,
size=dnn_layer_dims[0])
_input_layer = dnn_embedding
for i, dim in enumerate(dnn_layer_dims[1:]):
fc = layer.fc(input=_input_layer,
size=dim,
act=paddle.activation.Relu(),
name='dnn-fc-%d' % i)
_input_layer = fc
return _input_layer
def _build_lr_submodel_(self):
'''
config LR submodel
'''
fc = layer.fc(input=self.lr_merged_input,
size=1,
act=paddle.activation.Relu())
return fc
def _build_classification_model(self, dnn, lr):
merge_layer = layer.concat(input=[dnn, lr])
self.output = layer.fc(
input=merge_layer,
size=1,
# use sigmoid function to approximate ctr rate, a float value between 0 and 1.
act=paddle.activation.Sigmoid())
if not self.is_infer:
self.train_cost = paddle.layer.multi_binary_label_cross_entropy_cost(
input=self.output, label=self.click)
return self.output
def _build_regression_model(self, dnn, lr):
merge_layer = layer.concat(input=[dnn, lr])
self.output = layer.fc(input=merge_layer,
size=1,
act=paddle.activation.Sigmoid())
if not self.is_infer:
self.train_cost = paddle.layer.square_error_cost(
input=self.output, label=self.click)
return self.output
from utils import logger, TaskMode, load_dnn_input_record, load_lr_input_record
feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2}
class Dataset(object):
def train(self, path):
'''
Load trainset.
'''
logger.info("load trainset from %s" % path)
mode = TaskMode.create_train()
return self._parse_creator(path, mode)
def test(self, path):
'''
Load testset.
'''
logger.info("load testset from %s" % path)
mode = TaskMode.create_test()
return self._parse_creator(path, mode)
def infer(self, path):
'''
Load infer set.
'''
logger.info("load inferset from %s" % path)
mode = TaskMode.create_infer()
return self._parse_creator(path, mode)
def _parse_creator(self, path, mode):
'''
Parse dataset.
'''
def _parse():
with open(path) as f:
for line_id, line in enumerate(f):
fs = line.strip().split('\t')
dnn_input = load_dnn_input_record(fs[0])
lr_input = load_lr_input_record(fs[1])
if not mode.is_infer():
click = [int(fs[2])]
yield dnn_input, lr_input, click
else:
yield dnn_input, lr_input
return _parse
def load_data_meta(path):
'''
load data meta info from path, return (dnn_input_dim, lr_input_dim)
'''
with open(path) as f:
lines = f.read().split('\n')
err_info = "wrong meta format"
assert len(lines) == 2, err_info
assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[
1], err_info
res = map(int, [_.split(':')[1] for _ in lines])
logger.info('dnn input dim: %d' % res[0])
logger.info('lr input dim: %d' % res[1])
return res
import argparse
import gzip
import reader
import paddle.v2 as paddle
from utils import logger, ModelType
from network_conf import CTRmodel
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
'--train_data_path',
type=str,
required=True,
help="path of training dataset")
parser.add_argument(
'--test_data_path', type=str, help='path of testing dataset')
parser.add_argument(
'--batch_size',
type=int,
default=10000,
help="size of mini-batch (default:10000)")
parser.add_argument(
'--num_passes', type=int, default=10, help="number of passes to train")
parser.add_argument(
'--model_output_prefix',
type=str,
default='./ctr_models',
help='prefix of path for model to store (default: ./ctr_models)')
parser.add_argument(
'--data_meta_file',
type=str,
required=True,
help='path of data meta info file', )
parser.add_argument(
'--model_type',
type=int,
required=True,
default=ModelType.CLASSIFICATION,
help='model type, classification: %d, regression %d (default classification)'
% (ModelType.CLASSIFICATION, ModelType.REGRESSION))
return parser.parse_args()
dnn_layer_dims = [128, 64, 32, 1]
# ==============================================================================
# cost and train period
# ==============================================================================
def train():
args = parse_args()
args.model_type = ModelType(args.model_type)
paddle.init(use_gpu=False, trainer_count=1)
dnn_input_dim, lr_input_dim = reader.load_data_meta(args.data_meta_file)
# create ctr model.
model = CTRmodel(
dnn_layer_dims,
dnn_input_dim,
lr_input_dim,
model_type=args.model_type,
is_infer=False)
params = paddle.parameters.create(model.train_cost)
optimizer = paddle.optimizer.AdaGrad()
trainer = paddle.trainer.SGD(cost=model.train_cost,
parameters=params,
update_equation=optimizer)
dataset = reader.Dataset()
def __event_handler__(event):
if isinstance(event, paddle.event.EndIteration):
num_samples = event.batch_id * args.batch_size
if event.batch_id % 100 == 0:
logger.warning("Pass %d, Samples %d, Cost %f, %s" % (
event.pass_id, num_samples, event.cost, event.metrics))
if event.batch_id % 1000 == 0:
if args.test_data_path:
result = trainer.test(
reader=paddle.batch(
dataset.test(args.test_data_path),
batch_size=args.batch_size),
feeding=reader.feeding_index)
logger.warning("Test %d-%d, Cost %f, %s" %
(event.pass_id, event.batch_id, result.cost,
result.metrics))
path = "{}-pass-{}-batch-{}-test-{}.tar.gz".format(
args.model_output_prefix, event.pass_id, event.batch_id,
result.cost)
with gzip.open(path, 'w') as f:
trainer.save_parameter_to_tar(f)
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
dataset.train(args.train_data_path), buf_size=500),
batch_size=args.batch_size),
feeding=reader.feeding_index,
event_handler=__event_handler__,
num_passes=args.num_passes)
if __name__ == '__main__':
train()
import logging
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
class TaskMode:
TRAIN_MODE = 0
TEST_MODE = 1
INFER_MODE = 2
def __init__(self, mode):
self.mode = mode
def is_train(self):
return self.mode == self.TRAIN_MODE
def is_test(self):
return self.mode == self.TEST_MODE
def is_infer(self):
return self.mode == self.INFER_MODE
@staticmethod
def create_train():
return TaskMode(TaskMode.TRAIN_MODE)
@staticmethod
def create_test():
return TaskMode(TaskMode.TEST_MODE)
@staticmethod
def create_infer():
return TaskMode(TaskMode.INFER_MODE)
class ModelType:
CLASSIFICATION = 0
REGRESSION = 1
def __init__(self, mode):
self.mode = mode
def is_classification(self):
return self.mode == self.CLASSIFICATION
def is_regression(self):
return self.mode == self.REGRESSION
@staticmethod
def create_classification():
return ModelType(ModelType.CLASSIFICATION)
@staticmethod
def create_regression():
return ModelType(ModelType.REGRESSION)
def load_dnn_input_record(sent):
return map(int, sent.split())
def load_lr_input_record(sent):
res = []
for _ in [x.split(':') for x in sent.split()]:
res.append((
int(_[0]),
float(_[1]), ))
return res
The code sample in this directory requires PaddlePaddle v0.11.0 or later. If your installed version is older, please update it by following the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
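# Deep Factorization Machine for Click-Through Rate Prediction
## Introduction
This model implements the DeepFM model proposed in the following paper: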
```text
@inproceedings{guo2017deepfm,
title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
pages={1725--1731},
year={2017}
}
```
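DeepFM combines a factorization machine with a deep neural network to model both low-order and high-order feature interactions. For details on factorization machines, see the paper [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf).
## Dataset
This example uses the Criteo dataset from the [Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/) hosted by Kaggle.
Each row holds the features of one ad impression. The first column is a label indicating whether the ad was clicked. There are 39 features in total: 13 take integer values and the other 26 are categorical. The test set has no labels.
Download the dataset: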
```bash
cd data && ./download.sh && cd ..
```
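## Model
The DeepFM model consists of a factorization machine (FM) and a deep neural network (DNN). All input features are fed to both the FM and the DNN, and the outputs of the two are combined to form the final output. The embedding layer for sparse features in the DNN shares its parameters with the latent vectors (factors) of the FM layer.
The factorization machine layer in PaddlePaddle computes the second-order feature interactions. The following code example combines the factorization machine layer with a fully connected layer to form a complete factorization machine: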
```python
def fm_layer(input, factor_size):
    # First-order (linear) term: a fully connected layer with one output.
    first_order = paddle.layer.fc(input=input, size=1, act=paddle.activation.Linear())
    # Second-order term: pairwise feature interactions.
    second_order = paddle.layer.factorization_machine(input=input, factor_size=factor_size)
    fm = paddle.layer.addto(input=[first_order, second_order],
                            act=paddle.activation.Linear(),
                            bias_attr=False)
    return fm
```
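## Data preparation
To preprocess the raw dataset, the integer features are min-max normalized to [0, 1] and the categorical features are one-hot encoded. The raw dataset is split into two parts: 90% for training and the remaining 10% for validation during training.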
```bash
python preprocess.py --datadir ./data/raw --outdir ./data
```
## Training
The command-line options for training can be listed with `python train.py -h`.
To train the model:
```bash
python train.py \
--train_data_path data/train.txt \
--test_data_path data/valid.txt \
2>&1 | tee train.log
```
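After batch 40000 of pass 9, the test AUC is 0.807178 and the test cost is 0.445196.
## Inference
The command-line options for inference can be listed with `python infer.py -h`.
To run inference on the test dataset: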
```bash
python infer.py \
--model_gz_path models/model-pass-9-batch-10000.tar.gz \
--data_path data/test.txt \
--prediction_output_path ./predict.txt
```
The minimum PaddlePaddle version needed for the code sample in this directory is v0.11.0. If you are on a version of PaddlePaddle earlier than v0.11.0, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
# Deep Factorization Machine for Click-Through Rate prediction
## Introduction
This model implements the DeepFM proposed in the following paper:
```text
@inproceedings{guo2017deepfm,
title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
pages={1725--1731},
year={2017}
}
```
DeepFM combines a factorization machine and a deep neural network to model
both low-order and high-order feature interactions. For details on
factorization machines, please refer to the paper [Factorization
Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf).
## Dataset
This example uses the Criteo dataset, which was used for the [Display Advertising
Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/)
hosted by Kaggle.
Each row is the features for an ad display and the first column is a label
indicating whether this ad has been clicked or not. There are 39 features in
total. 13 features take integer values and the other 26 features are
categorical features. For the test dataset, the labels are omitted.
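The row layout described above can be sketched as a small parser. This is only an illustration: the helper name `parse_criteo_row` and its `has_label` flag are hypothetical and not part of this example's code, and the sketch assumes tab-separated fields with empty strings for missing values, as `preprocess.py` does.

```python
def parse_criteo_row(line, has_label=True):
    """Split one tab-separated Criteo row into (label, integers, categories).

    Missing values stay as empty strings; pass has_label=False for test rows.
    """
    fields = line.rstrip('\n').split('\t')
    label = None
    if has_label:
        label = int(fields[0])
        fields = fields[1:]
    if len(fields) != 39:
        raise ValueError("expected 39 features, got %d" % len(fields))
    # The first 13 fields are integer features, the remaining 26 categorical.
    return label, fields[:13], fields[13:]


# A synthetic row: label 1, integer features 0..12, categorical c0..c25.
row = '\t'.join(['1'] + [str(i) for i in range(13)] +
                ['c%d' % i for i in range(26)])
label, integers, categories = parse_criteo_row(row)
```

For a test-set row, which carries no label, the same helper would be called with `has_label=False` and returns `None` in the label position.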
Download dataset:
```bash
cd data && ./download.sh && cd ..
```
## Model
The DeepFM model is composed of a factorization machine layer (FM) and a deep
neural network (DNN). All the input features are fed to both the FM and the DNN.
The outputs of the FM and the DNN are combined to form the final output. The embedding
layer for sparse features in the DNN shares its parameters with the latent
vectors (factors) of the FM layer.
The factorization machine layer in PaddlePaddle computes the second order
interactions. The following code example combines the factorization machine
layer and fully connected layer to form the full version of factorization
machine:
```python
def fm_layer(input, factor_size):
    # First-order (linear) term: a fully connected layer with one output.
    first_order = paddle.layer.fc(input=input, size=1, act=paddle.activation.Linear())
    # Second-order term: pairwise feature interactions.
    second_order = paddle.layer.factorization_machine(input=input, factor_size=factor_size)
    fm = paddle.layer.addto(input=[first_order, second_order],
                            act=paddle.activation.Linear(),
                            bias_attr=False)
    return fm
```
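To make the "second order interactions" concrete, here is a minimal pure-Python sketch of the quantity a factorization machine computes for one sample, following the formulation in the factorization machines paper. It is not PaddlePaddle's implementation; the function names and the toy inputs `x` and `v` are made up for illustration.

```python
from itertools import combinations


def fm_second_order_naive(x, v):
    """Sum of <v_i, v_j> * x_i * x_j over all feature pairs i < j."""
    total = 0.0
    for i, j in combinations(range(len(x)), 2):
        dot = sum(a * b for a, b in zip(v[i], v[j]))
        total += dot * x[i] * x[j]
    return total


def fm_second_order_fast(x, v):
    """Same value in O(n * k): 0.5 * sum_f ((sum_i v_if x_i)^2 - sum_i (v_if x_i)^2)."""
    factor_size = len(v[0])
    total = 0.0
    for f in range(factor_size):
        s = sum(v[i][f] * x[i] for i in range(len(x)))
        s_sq = sum((v[i][f] * x[i]) ** 2 for i in range(len(x)))
        total += s * s - s_sq
    return 0.5 * total


# Toy sample with 3 features and factor size 2 (values are arbitrary).
x = [1.0, 0.0, 2.0]
v = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
naive = fm_second_order_naive(x, v)
fast = fm_second_order_fast(x, v)
```

Both functions return the same value; the second form is what makes factorization machines practical for inputs with many sparse features, since its cost grows linearly with the number of features.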
## Data preparation
To preprocess the raw dataset, the integer features are clipped and then min-max
normalized to [0, 1], and the categorical features are one-hot encoded. The raw
training dataset is split so that 90% is used for training and the other
10% is used for validation during training.
```bash
python preprocess.py --datadir ./data/raw --outdir ./data
```
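The two transformations described above can be sketched as follows. This is a simplified illustration, not the code in `preprocess.py`: the helper names are hypothetical, and clipping is applied at transform time here, which is an assumption about the intended behavior.

```python
def clip_and_normalize(raw, clip_point, lo, hi):
    """Clip `raw` at `clip_point`, then min-max scale it into [0, 1].

    `lo` and `hi` are the min and max observed on the (clipped) training data.
    An empty string stands for a missing value and maps to 0.0.
    """
    if raw == '':
        return 0.0
    val = min(float(raw), clip_point)
    return (val - lo) / (hi - lo)


def category_offsets(dict_sizes):
    """Start index of each categorical field inside one shared sparse vector.

    Field i with local id k is encoded as offsets[i] + k, so the 26 one-hot
    vectors can be combined into a single sparse binary vector.
    """
    offsets = [0]
    for size in dict_sizes[:-1]:
        offsets.append(offsets[-1] + size)
    return offsets


scaled = clip_and_normalize('30', clip_point=20, lo=0, hi=20)
offsets = category_offsets([3, 5, 2])
```

With these toy inputs, the raw value 30 is clipped to 20 before scaling, and the three categorical fields start at indices 0, 3, and 8 of the combined vector.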
## Train
The command line options for training can be listed by `python train.py -h`.
To train the model:
```bash
python train.py \
--train_data_path data/train.txt \
--test_data_path data/valid.txt \
2>&1 | tee train.log
```
After batch 40000 of pass 9, the testing AUC is `0.807178` and the testing
cost is `0.445196`.
## Infer
The command line options for inference can be listed by `python infer.py -h`.
To run inference on the test dataset:
```bash
python infer.py \
--model_gz_path models/model-pass-9-batch-10000.tar.gz \
--data_path data/test.txt \
--prediction_output_path ./predict.txt
```
#!/bin/bash
wget --no-check-certificate https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
tar zxf dac.tar.gz
rm -f dac.tar.gz
mkdir raw
mv ./*.txt raw/
import os
import gzip
import argparse
import itertools
import paddle.v2 as paddle
from network_conf import DeepFM
import reader
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument(
'--model_gz_path',
type=str,
required=True,
help="The path of model parameters gz file")
parser.add_argument(
'--data_path',
type=str,
required=True,
help="The path of the dataset to infer")
parser.add_argument(
'--prediction_output_path',
type=str,
required=True,
help="The path to output the prediction")
parser.add_argument(
'--factor_size',
type=int,
default=10,
help="The factor size for the factorization machine (default:10)")
return parser.parse_args()
def infer():
args = parse_args()
paddle.init(use_gpu=False, trainer_count=1)
model = DeepFM(args.factor_size, infer=True)
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.model_gz_path, 'r'))
inferer = paddle.inference.Inference(
output_layer=model, parameters=parameters)
dataset = reader.Dataset()
infer_reader = paddle.batch(dataset.infer(args.data_path), batch_size=1000)
with open(args.prediction_output_path, 'w') as out:
for id, batch in enumerate(infer_reader()):
res = inferer.infer(input=batch)
predictions = [x for x in itertools.chain.from_iterable(res)]
out.write('\n'.join(map(str, predictions)) + '\n')
if __name__ == '__main__':
infer()
import paddle.v2 as paddle
dense_feature_dim = 13
sparse_feature_dim = 117568
def fm_layer(input, factor_size, fm_param_attr):
first_order = paddle.layer.fc(input=input,
size=1,
act=paddle.activation.Linear())
second_order = paddle.layer.factorization_machine(
input=input,
factor_size=factor_size,
act=paddle.activation.Linear(),
param_attr=fm_param_attr)
out = paddle.layer.addto(
input=[first_order, second_order],
act=paddle.activation.Linear(),
bias_attr=False)
return out
def DeepFM(factor_size, infer=False):
dense_input = paddle.layer.data(
name="dense_input",
type=paddle.data_type.dense_vector(dense_feature_dim))
sparse_input = paddle.layer.data(
name="sparse_input",
type=paddle.data_type.sparse_binary_vector(sparse_feature_dim))
sparse_input_ids = [
paddle.layer.data(
name="C" + str(i),
type=paddle.data_type.integer_value(sparse_feature_dim))
for i in range(1, 27)
]
dense_fm = fm_layer(
dense_input,
factor_size,
fm_param_attr=paddle.attr.Param(name="DenseFeatFactors"))
sparse_fm = fm_layer(
sparse_input,
factor_size,
fm_param_attr=paddle.attr.Param(name="SparseFeatFactors"))
def embedding_layer(input):
return paddle.layer.embedding(
input=input,
size=factor_size,
param_attr=paddle.attr.Param(name="SparseFeatFactors"))
sparse_embed_seq = map(embedding_layer, sparse_input_ids)
sparse_embed = paddle.layer.concat(sparse_embed_seq)
fc1 = paddle.layer.fc(input=[sparse_embed, dense_input],
size=400,
act=paddle.activation.Relu())
fc2 = paddle.layer.fc(input=fc1, size=400, act=paddle.activation.Relu())
fc3 = paddle.layer.fc(input=fc2, size=400, act=paddle.activation.Relu())
predict = paddle.layer.fc(input=[dense_fm, sparse_fm, fc3],
size=1,
act=paddle.activation.Sigmoid())
if not infer:
label = paddle.layer.data(
name="label", type=paddle.data_type.dense_vector(1))
cost = paddle.layer.multi_binary_label_cross_entropy_cost(
input=predict, label=label)
paddle.evaluator.classification_error(
name="classification_error", input=predict, label=label)
paddle.evaluator.auc(name="auc", input=predict, label=label)
return cost
else:
return predict
"""
Preprocess Criteo dataset. This dataset was used for the Display Advertising
Challenge (https://www.kaggle.com/c/criteo-display-ad-challenge).
"""
import os
import sys
import click
import random
import collections
# There are 13 integer features and 26 categorical features
continous_features = range(1, 14)
categorial_features = range(14, 40)
# Clip integer features. The clip point for each integer feature
# is derived from the 95% quantile of the total values in each feature
continous_clip = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
class CategoryDictGenerator:
"""
Generate dictionary for each of the categorical features
"""
def __init__(self, num_feature):
self.dicts = []
self.num_feature = num_feature
for i in range(0, num_feature):
self.dicts.append(collections.defaultdict(int))
def build(self, datafile, categorial_features, cutoff=0):
with open(datafile, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
for i in range(0, self.num_feature):
if features[categorial_features[i]] != '':
self.dicts[i][features[categorial_features[i]]] += 1
for i in range(0, self.num_feature):
self.dicts[i] = filter(lambda x: x[1] >= cutoff,
self.dicts[i].items())
self.dicts[i] = sorted(self.dicts[i], key=lambda x: (-x[1], x[0]))
vocabs, _ = list(zip(*self.dicts[i]))
self.dicts[i] = dict(zip(vocabs, range(1, len(vocabs) + 1)))
self.dicts[i]['<unk>'] = 0
def gen(self, idx, key):
if key not in self.dicts[idx]:
res = self.dicts[idx]['<unk>']
else:
res = self.dicts[idx][key]
return res
def dicts_sizes(self):
return map(len, self.dicts)
class ContinuousFeatureGenerator:
"""
Normalize the integer features to [0, 1] by min-max normalization
"""
def __init__(self, num_feature):
self.num_feature = num_feature
self.min = [sys.maxint] * num_feature
self.max = [-sys.maxint] * num_feature
def build(self, datafile, continous_features):
with open(datafile, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
for i in range(0, self.num_feature):
val = features[continous_features[i]]
if val != '':
val = int(val)
if val > continous_clip[i]:
val = continous_clip[i]
self.min[i] = min(self.min[i], val)
self.max[i] = max(self.max[i], val)
def gen(self, idx, val):
if val == '':
return 0.0
val = float(val)
return (val - self.min[idx]) / (self.max[idx] - self.min[idx])
@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw criteo dataset")
@click.option("--outdir", type=str, help="Path to save the processed data")
def preprocess(datadir, outdir):
"""
All the 13 integer features are normalized to continuous values and these
continuous features are combined into one vector with dimension 13.
Each of the 26 categorical features is one-hot encoded and all the one-hot
vectors are combined into one sparse binary vector.
"""
dists = ContinuousFeatureGenerator(len(continous_features))
dists.build(os.path.join(datadir, 'train.txt'), continous_features)
dicts = CategoryDictGenerator(len(categorial_features))
dicts.build(
os.path.join(datadir, 'train.txt'), categorial_features, cutoff=200)
dict_sizes = dicts.dicts_sizes()
categorial_feature_offset = [0]
for i in range(1, len(categorial_features)):
offset = categorial_feature_offset[i - 1] + dict_sizes[i - 1]
categorial_feature_offset.append(offset)
random.seed(0)
# 90% of the data are used for training, and 10% of the data are used
# for validation.
with open(os.path.join(outdir, 'train.txt'), 'w') as out_train:
with open(os.path.join(outdir, 'valid.txt'), 'w') as out_valid:
with open(os.path.join(datadir, 'train.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i]])
continous_vals.append("{0:.6f}".format(val).rstrip('0')
.rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i]]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
label = features[0]
if random.randint(0, 9999) % 10 != 0:
out_train.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
else:
out_valid.write('\t'.join(
[continous_vals, categorial_vals, label]) + '\n')
with open(os.path.join(outdir, 'test.txt'), 'w') as out:
with open(os.path.join(datadir, 'test.txt'), 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
continous_vals = []
for i in range(0, len(continous_features)):
val = dists.gen(i, features[continous_features[i] - 1])
continous_vals.append("{0:.6f}".format(val).rstrip('0')
.rstrip('.'))
categorial_vals = []
for i in range(0, len(categorial_features)):
val = dicts.gen(i, features[categorial_features[
i] - 1]) + categorial_feature_offset[i]
categorial_vals.append(str(val))
continous_vals = ','.join(continous_vals)
categorial_vals = ','.join(categorial_vals)
out.write('\t'.join([continous_vals, categorial_vals]) + '\n')
if __name__ == "__main__":
preprocess()
class Dataset:
def _reader_creator(self, path, is_infer):
def reader():
with open(path, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
dense_feature = map(float, features[0].split(','))
sparse_feature = map(int, features[1].split(','))
if not is_infer:
label = [float(features[2])]
yield [dense_feature, sparse_feature
] + sparse_feature + [label]
else:
yield [dense_feature, sparse_feature] + sparse_feature
return reader
def train(self, path):
return self._reader_creator(path, False)
def test(self, path):
return self._reader_creator(path, False)
def infer(self, path):
return self._reader_creator(path, True)
feeding = {
'dense_input': 0,
'sparse_input': 1,
'C1': 2,
'C2': 3,
'C3': 4,
'C4': 5,
'C5': 6,
'C6': 7,
'C7': 8,
'C8': 9,
'C9': 10,
'C10': 11,
'C11': 12,
'C12': 13,
'C13': 14,
'C14': 15,
'C15': 16,
'C16': 17,
'C17': 18,
'C18': 19,
'C19': 20,
'C20': 21,
'C21': 22,
'C22': 23,
'C23': 24,
'C24': 25,
'C25': 26,
'C26': 27,
'label': 28
}
import os
import gzip
import logging
import argparse
import paddle.v2 as paddle
from network_conf import DeepFM
import reader
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument(
'--train_data_path',
type=str,
required=True,
help="The path of training dataset")
parser.add_argument(
'--test_data_path',
type=str,
required=True,
help="The path of testing dataset")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
parser.add_argument(
'--num_passes',
type=int,
default=10,
help="The number of passes to train (default: 10)")
parser.add_argument(
'--factor_size',
type=int,
default=10,
help="The factor size for the factorization machine (default:10)")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help='The path for model to store (default: models)')
return parser.parse_args()
def train():
args = parse_args()
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
paddle.init(use_gpu=False, trainer_count=1)
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
model = DeepFM(args.factor_size)
params = paddle.parameters.create(model)
trainer = paddle.trainer.SGD(cost=model,
parameters=params,
update_equation=optimizer)
dataset = reader.Dataset()
def __event_handler__(event):
if isinstance(event, paddle.event.EndIteration):
num_samples = event.batch_id * args.batch_size
if event.batch_id % 100 == 0:
logger.warning("Pass %d, Batch %d, Samples %d, Cost %f, %s" %
(event.pass_id, event.batch_id, num_samples,
event.cost, event.metrics))
if event.batch_id % 10000 == 0:
if args.test_data_path:
result = trainer.test(
reader=paddle.batch(
dataset.test(args.test_data_path),
batch_size=args.batch_size),
feeding=reader.feeding)
logger.warning("Test %d-%d, Cost %f, %s" %
(event.pass_id, event.batch_id, result.cost,
result.metrics))
path = "{}/model-pass-{}-batch-{}.tar.gz".format(
args.model_output_dir, event.pass_id, event.batch_id)
with gzip.open(path, 'w') as f:
trainer.save_parameter_to_tar(f)
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
dataset.train(args.train_data_path),
buf_size=args.batch_size * 10000),
batch_size=args.batch_size),
feeding=reader.feeding,
event_handler=__event_handler__,
num_passes=args.num_passes)
if __name__ == '__main__':
train()
The code sample in this directory requires PaddlePaddle v0.10.0 or later. If your installed version is older, please update it by following the [installation guide](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
---
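# Deep Structured Semantic Models (DSSM)
DSSM uses a DNN to learn low-dimensional representation vectors of text in a continuous semantic space and models the semantic similarity between two sentences. This example shows how to implement a generic DSSM model with PaddlePaddle for modeling the semantic similarity between two strings. The implementation accepts a generic data format, so you can use the model in a real setting by swapping in your own data.
## Background
DSSM \[[1](#参考文献)\] is a classic semantic model proposed by Microsoft Research in 2013 for learning the semantic distance between two texts. More broadly, the model also applies to the following scenarios:
1. CTR prediction: measuring the relevance between a user's search query and a set of candidate web pages (documents).
2. Text relevance: measuring the semantic relatedness of two strings.
3. Recommendation: measuring the association between a user and a recommended item.
DSSM has grown into a framework that naturally models the distance between two records. For text relevance, for example, cosine similarity can describe the semantic distance; for ranking search engine results, a rank loss can be added on top of DSSM to train a ranking model.
## Model Overview
In the original paper \[[1](#参考文献)\], DSSM measures the latent semantic relation between a user's search query and a set of documents. The model structure is as follows: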
<p align="center">
<img src="./images/dssm.png"/><br/><br/>
Figure 1. Original DSSM structure
</p>
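The underlying idea is to **use a DNN to map high-dimensional feature vectors to continuous vectors in a low-dimensional space (the red box in the figure)** and to **use cosine similarity in the upper layer to measure the semantic relevance between the search query and the candidate documents**.
For the top-level loss, the original model uses an approach similar to negative sampling in Word2Vec: each query is paired with a positive example $D+$ and four negative examples $D-$, a conditional probability is computed over them, and the log-likelihood serves as the loss. This is the part of Figure 1 that looks like $P(D_1|Q)$; see the original paper for details.
With later refinements, the structure of DSSM was simplified \[[3](#参考文献)\] into: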
<p align="center">
<img src="./images/dssm2.png" width="600"/><br/><br/>
Figure 2. Generic DSSM structure
</p>
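The blank boxes in the figure can be replaced by any model, such as fully connected layers (FC), a CNN, or an RNN. The structure is designed specifically to measure the semantic distance between two elements (such as strings). In practice, DSSM serves as a basic building block that is paired with different loss functions for specific purposes, for example:
- In learning to rank, adding a pairwise rank loss to the structure in Figure 2 yields a ranking model.
- In CTR prediction, treating click or no click as a 0/1 binary classification and adding a cross-entropy loss yields a classification model.
- When a string needs to be scored, using cosine similarity as the score yields a regression model.
This example provides a fairly general solution. It supports the following task types:
- Classification
- Regression within the range [-1, 1]
- Pairwise rank
It supports the following three structures for producing the low-dimensional semantic vectors:
- FC: multiple fully connected layers
- CNN: convolutional neural network
- RNN: recurrent neural network
## Model Implementation
The DSSM model has three parts: the left DNN, the right DNN, and the loss function on top. In complex tasks the left and right DNNs can have different structures. In the original paper the two networks learn the semantic vectors of the query and the document respectively; because the two kinds of data differ, it is advisable to tailor each DNN accordingly.
**For simplicity and generality, this example uses the same structure for the left and right DNNs, so it offers only three options: FC, CNN, and RNN.**
The loss function also comes in three types: classification, regression, and ranking. In regression and ranking, the degree of match between the two sides is computed with cosine similarity; in classification, the predicted class distribution is computed with softmax.
Other tutorials cover much of this material in detail, for example:
- For text feature extraction with CNN and FC, see [text classification](https://github.com/PaddlePaddle/models/blob/develop/text_classification/README.md#模型详解)
- For RNN/GRU, see [Machine Translation](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.md#gated-recurrent-unit-gru)
- For pairwise rank (learning to rank), see [learn to rank](https://github.com/PaddlePaddle/models/blob/develop/ltr/README.md)
The underlying theory is not repeated here. The rest of this document focuses on implementing these structures with PaddlePaddle.
As Figure 3 shows, the regression and classification models have similar structures: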
<p align="center">
<img src="./images/dssm3.jpg"/><br/><br/>
Figure 3. DSSM for REGRESSION or CLASSIFICATION
</p>
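The most important components are the word embeddings, the two learners of low-dimensional vectors marked `(1)` and `(2)` in the figure (each can be implemented with an RNN, a CNN, or FC layers), and the loss function at the top.
The pairwise rank structure is more complex: the scoring structure of Figure 3 appears twice, and a corresponding loss function is added. The overall idea is:
- For the same source, score the left and right targets separately, giving `(a)` and `(b)`; the learning objective is the order relation between `(a)` and `(b)`.
- `(a)` and `(b)` resemble the structure in Figure 3, and each scores a (source, target) pair.
- `(1)` and `(2)` actually share one structure and both represent the same source; the figure unrolls them into two for clarity.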
<p align="center">
<img src="./images/dssm2.jpg"/><br/><br/>
Figure 4. DSSM for Pairwise Rank
</p>
The implementation of each part is described below. All the code is in `./network_conf.py`.
### Creating the word embedding table
```python
def create_embedding(self, input, prefix=''):
    """
    Create word embedding. The `prefix` is added in front of the name of
    embedding's learnable parameter.
    """
    logger.info("Create embedding table [%s] whose dimension is %d" %
                (prefix, self.dnn_dims[0]))
    emb = paddle.layer.embedding(
        input=input,
        size=self.dnn_dims[0],
        param_attr=ParamAttr(name='%s_emb.w' % prefix))
    return emb
```
Because the input to the embedding table is the list of word IDs of a sentence, the embedding table outputs a sequence of word vectors.
### CNN implementation
```python
def create_cnn(self, emb, prefix=''):
    """
    A multi-layer CNN.
    :param emb: The word embedding.
    :type emb: paddle.layer
    :param prefix: The prefix will be added to the layers' names.
    :type prefix: str
    """
    def create_conv(context_len, hidden_size, prefix):
        key = "%s_%d_%d" % (prefix, context_len, hidden_size)
        conv = paddle.networks.sequence_conv_pool(
            input=emb,
            context_len=context_len,
            hidden_size=hidden_size,
            # set parameter attr for parameter sharing
            context_proj_param_attr=ParamAttr(name=key + "contex_proj.w"),
            fc_param_attr=ParamAttr(name=key + "_fc.w"),
            fc_bias_attr=ParamAttr(name=key + "_fc.b"),
            pool_bias_attr=ParamAttr(name=key + "_pool.b"))
        return conv
    conv_3 = create_conv(3, self.dnn_dims[1], "cnn")
    conv_4 = create_conv(4, self.dnn_dims[1], "cnn")
    return paddle.layer.concat(input=[conv_3, conv_4])
```
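The CNN takes the sequence of word vectors, captures the key information of the original sentence through convolution and pooling, and finally outputs a semantic vector (which can be regarded as a sentence vector).
In this example, the sentence vectors learned by CNNs with window lengths 3 and 4 are concatenated (`paddle.layer.concat`) to form the final sentence vector.
### RNN implementation
RNNs are well suited to learning from variable-length sequences, and using an RNN to learn sentence information is close to standard practice in natural language processing.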
```python
def create_rnn(self, emb, prefix=''):
    """
    A GRU sentence vector learner.
    """
    gru = paddle.networks.simple_gru(
        input=emb,
        size=self.dnn_dims[1],
        mixed_param_attr=ParamAttr(name='%s_gru_mixed.w' % prefix),
        mixed_bias_param_attr=ParamAttr(name="%s_gru_mixed.b" % prefix),
        gru_param_attr=ParamAttr(name='%s_gru.w' % prefix),
        gru_bias_attr=ParamAttr(name="%s_gru.b" % prefix))
    sent_vec = paddle.layer.last_seq(gru)
    return sent_vec
```
### Multi-layer fully connected network (FC)
```python
def create_fc(self, emb, prefix=''):
    """
    A multi-layer fully connected neural networks.
    :param emb: The output of the embedding layer
    :type emb: paddle.layer
    :param prefix: A prefix will be added to the layers' names.
    :type prefix: str
    """
    _input_layer = paddle.layer.pooling(
        input=emb, pooling_type=paddle.pooling.Max())
    fc = paddle.layer.fc(
        input=_input_layer,
        size=self.dnn_dims[1],
        param_attr=ParamAttr(name='%s_fc.w' % prefix),
        bias_attr=ParamAttr(name="%s_fc.b" % prefix))
    return fc
```
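When building the fully connected network, `paddle.layer.pooling` first applies max pooling to the word vector sequence, turning the variable-length sequence into a fixed-dimensional vector that serves as the semantic representation of the whole sentence. Max pooling reduces the influence of sentence length on the sentence vector.
### Multi-layer DNN
After the CNN/RNN/FC extracts the semantic vector, more FC layers can be stacked on top to build a deep DNN structure.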
```python
def create_dnn(self, sent_vec, prefix):
    if len(self.dnn_dims) > 1:
        _input_layer = sent_vec
        for id, dim in enumerate(self.dnn_dims[1:]):
            name = "%s_fc_%d_%d" % (prefix, id, dim)
            fc = paddle.layer.fc(
                input=_input_layer,
                size=dim,
                act=paddle.activation.Tanh(),
                param_attr=ParamAttr(name='%s.w' % name),
                bias_attr=ParamAttr(name='%s.b' % name),
            )
            _input_layer = fc
    return _input_layer
```
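### Classification and regression
The classification and regression structures are similar. For the implementation, see the
`_build_classification_or_regression_model` function in [network_conf.py](https://github.com/PaddlePaddle/models/blob/develop/dssm/network_conf.py).
### Pairwise Rank
Pairwise rank reuses the DNN structure above. The same source is scored for similarity against two targets; the prediction is 1 if the left target scores higher and 0 otherwise. For the implementation, see the `_build_rank_model` function in [network_conf.py](https://github.com/PaddlePaddle/models/blob/develop/dssm/network_conf.py).
## Data Format
`./data` contains simple sample data.
### Regression data format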
```
# 3 fields each line:
# - source word list
# - target word list
# - target
<word list> \t <word list> \t <float>
```
For example:
```
苹果 六 袋 苹果 6s 0.1
新手 汽车 驾驶 驾校 培训 0.9
```
### Classification data format
```
# 3 fields each line:
# - source word list
# - target word list
# - target
<word list> \t <word list> \t <label>
```
For example:
```
苹果 六 袋 苹果 6s 0
新手 汽车 驾驶 驾校 培训 1
```
### Ranking data format
```
# 4 fields each line:
# - source word list
# - target1 word list
# - target2 word list
# - label
<word list> \t <word list> \t <word list> \t <label>
```
For example:
```
苹果 六 袋 苹果 6s 新手 汽车 驾驶 1
新手 汽车 驾驶 驾校 培训 苹果 6s 1
```
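The three layouts above can be read with one small helper. This is a hedged sketch rather than the reader shipped with the example: `parse_record` and its string-valued `model_type` argument are hypothetical names, and the sketch assumes fields are separated by a tab and words by a space.

```python
def parse_record(line, model_type):
    """Parse one data line into word lists plus a label.

    model_type is 'classification', 'regression', or 'rank' (names chosen
    for this sketch only).
    """
    fields = line.rstrip('\n').split('\t')
    if model_type == 'rank':
        source, left_target, right_target, label = fields
        return (source.split(), left_target.split(), right_target.split(),
                int(label))
    source, target, label = fields
    # Regression labels are floats; classification labels are class ids.
    label = float(label) if model_type == 'regression' else int(label)
    return source.split(), target.split(), label


regression = parse_record('new driver car\tdriving school\t0.9', 'regression')
ranking = parse_record('apple six\tapple 6s\tnew driver\t1', 'rank')
```

Unpacking `fields` directly means a line with the wrong number of tab-separated fields raises a `ValueError`, which surfaces malformed data early.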
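## Training
Run `python train.py -y 0 --model_arch 0 --class_num 2` to check that training a classification FC model works out of the box with the sample data in the `./data/classification` directory.
Other model structures can also be selected from the command line. Run `python train.py --help` for the full list of command-line options.
The most important options are:
- `train_data_path`: path to the training data
- `test_data_path`: path to the test data; optional
- `source_dic_path`: path to the source dictionary
- `target_dic_path`: path to the target dictionary
- `model_type`: type of loss function; 0 for classification, 1 for ranking, 2 for regression
- `model_arch`: model structure; 0 for FC, 1 for CNN, 2 for RNN
- `dnn_dims`: dimensions of the model's layers; the default is `256,128,64,32`, i.e. four layers with those dimensions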
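As a rough illustration of how a `dnn_dims` string maps onto the network, the sketch below parses the comma-separated value and labels each entry according to how the snippets in this README index `self.dnn_dims`: index 0 sizes the embedding, index 1 sizes the sentence vector, and `dnn_dims[1:]` sizes the stacked FC layers. The helper names are hypothetical and not part of `train.py`.

```python
def parse_dnn_dims(text):
    """Turn a string such as '256,128,64,32' into a list of layer sizes."""
    dims = [int(part) for part in text.split(',')]
    if len(dims) < 2 or any(d <= 0 for d in dims):
        raise ValueError("need at least two positive layer sizes: %r" % text)
    return dims


def describe_dims(dims):
    """Label each entry by the role it plays in the README's code snippets."""
    return {
        'embedding': dims[0],        # create_embedding uses dnn_dims[0]
        'sentence_vector': dims[1],  # create_cnn / create_rnn / create_fc
        'stacked_fc': dims[1:],      # create_dnn iterates over dnn_dims[1:]
    }


dims = parse_dnn_dims('256,128,64,32')
roles = describe_dims(dims)
```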
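## Inference with a trained model
Run `python infer.py --help` for the full list of command-line options. The important options are:
- `data_path`: path to the data to predict on
- `prediction_output_path`: path for the prediction output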
## 参考文献
1. Huang P S, He X, Gao J, et al. Learning deep structured semantic models for web search using clickthrough data[C]//Proceedings of the 22nd ACM International Conference on Information & Knowledge Management. ACM, 2013: 2333-2338.
2. [Microsoft Learning to Rank Datasets](https://www.microsoft.com/en-us/research/project/mslr/)
3. [Gao J, He X, Deng L. Deep Learning for Web Search and Natural Language Processing[J]. Microsoft Research Technical Report, 2015.](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/wsdm2015.v3.pdf)
The minimum PaddlePaddle version needed for the code sample in this directory is v0.10.0. If you are on a version of PaddlePaddle earlier than v0.10.0, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
---
# Deep Structured Semantic Models (DSSM)
Deep Structured Semantic Models (DSSM) is a simple but powerful DNN-based model for matching web search queries against URL-based documents. This example demonstrates how to use PaddlePaddle to implement a generic DSSM model for modeling the semantic similarity between two strings.
## Background Introduction
DSSM \[[1](#References)] is a classic semantic model proposed by Microsoft Research in 2013. It is used to learn the semantic distance between two texts. More broadly, the model also applies to the following scenarios:
1. CTR prediction: measuring the degree of association between a user search query and a set of candidate web pages.
2. Text relevance: measuring the degree of semantic correlation between two strings.
3. Recommendation: measuring the degree of association between a user and a recommended item.
## Model Architecture
In the original paper \[[1](#References)], the DSSM model measures the latent semantic relation between the user search query and the documents. The model structure is as follows:
<p align="center">
<img src="./images/dssm.png"/><br/><br/>
Figure 1. DSSM in the original paper
</p>
With later refinements, the structure of the DSSM model was simplified \[[3](#References)] into:
<p align="center">
<img src="./images/dssm2.png" width="600"/><br/><br/>
Figure 2. DSSM generic structure
</p>
The blank boxes in the figure can be replaced by any model, such as fully connected layers (FC), a CNN, or an RNN. The structure is designed to measure the semantic distance between two elements (such as strings).
In practice, the DSSM model serves as a basic building block that is paired with different loss functions for specific purposes, such as:
- In a ranking system, adding a pairwise rank loss yields a ranking model.
- In CTR prediction, treating click or no click as a binary classification and adding a cross-entropy loss yields a classification model.
- When a string needs to be scored, using cosine similarity as the score yields a regression model.
## Model Implementation
At a high level, the DSSM model is composed of three components: the left DNN, the right DNN, and the loss function on top of them. In complex tasks, the structures of the left DNN and the right DNN can differ. In this example, we keep the two DNN structures the same, and each can be any of FC, CNN, or RNN.
This example supports three types of loss function: classification, regression, and ranking. In regression and ranking, the degree of match between the left and right DNN outputs is calculated with cosine similarity. In the classification task, the predicted distribution is calculated by softmax.
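For readers who want the matching score spelled out, here is a minimal pure-Python sketch of cosine similarity between two semantic vectors. It only illustrates the formula; the function name is made up for this sketch, and returning 0.0 for a zero vector is a convention chosen here, not something the example's code specifies.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b, a value in [-1, 1]."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Undefined for zero vectors; this sketch returns 0.0 by convention.
        return 0.0
    return dot / (norm_a * norm_b)


same_direction = cosine_similarity([1.0, 2.0], [2.0, 4.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
opposite = cosine_similarity([1.0, 0.0], [-1.0, 0.0])
```

Because the score depends only on direction and not on vector length, it stays within [-1, 1] regardless of the magnitude of the two vectors.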
Other tutorials cover much of this material in detail:
- For text feature extraction with CNN and FC, see [text classification](https://github.com/PaddlePaddle/models/blob/develop/text_classification/README.md#模型详解)
- For RNN/GRU, see [Machine Translation](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.md#gated-recurrent-unit-gru)
- For pairwise rank learning, see [learn to rank](https://github.com/PaddlePaddle/models/blob/develop/ltr/README.md)
Figure 3 shows the general architecture for both regression and classification models.
<p align="center">
<img src="./images/dssm3.jpg"/><br/><br/>
Figure 3. DSSM for REGRESSION or CLASSIFICATION
</p>
The structure of the Pairwise Rank is more complex, as shown in Figure 4.
<p align="center">
<img src="./images/dssm2.jpg"/><br/><br/>
Figure 4. DSSM for Pairwise Rank
</p>
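One common way to turn two scores for the same source into a pairwise rank loss is the RankNet-style formulation sketched below. This is an assumption for illustration: the exact cost used by this example's rank model is not shown in this section, and the function name is hypothetical.

```python
import math


def pairwise_rank_cost(score_left, score_right, label):
    """Cross-entropy on sigmoid(score_left - score_right).

    label is 1 if the left target should rank higher, 0 otherwise.
    Equivalent to log(1 + exp(o)) - label * o with o = score_left - score_right.
    """
    o = score_left - score_right
    # Numerically stable log(1 + exp(o)).
    softplus = max(o, 0.0) + math.log1p(math.exp(-abs(o)))
    return softplus - label * o


tie = pairwise_rank_cost(0.5, 0.5, 1)
correct_order = pairwise_rank_cost(2.0, 0.0, 1)
wrong_order = pairwise_rank_cost(0.0, 2.0, 1)
```

The cost is small when the score difference agrees with the label and grows roughly linearly when the pair is ordered the wrong way, so minimizing it pushes the preferred target's score above the other one.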
Below, we describe how each part of the DSSM model is implemented in PaddlePaddle. All the code is included in `./network_conf.py`.
### Create a word vector table for the text
```python
def create_embedding(self, input, prefix=''):
    """
    Create word embedding. The `prefix` is added in front of the name of
    embedding's learnable parameter.
    """
    logger.info("Create embedding table [%s] whose dimension is %d" %
                (prefix, self.dnn_dims[0]))
    emb = paddle.layer.embedding(
        input=input,
        size=self.dnn_dims[0],
        param_attr=ParamAttr(name='%s_emb.w' % prefix))
    return emb
```
Since the input to the embedding table is the list of word IDs of a sentence, the embedding table outputs a sequence of word vectors.
### CNN implementation
```python
def create_cnn(self, emb, prefix=''):
    """
    A multi-layer CNN.
    :param emb: The word embedding.
    :type emb: paddle.layer
    :param prefix: The prefix will be added to the layers' names.
    :type prefix: str
    """
    def create_conv(context_len, hidden_size, prefix):
        key = "%s_%d_%d" % (prefix, context_len, hidden_size)
        conv = paddle.networks.sequence_conv_pool(
            input=emb,
            context_len=context_len,
            hidden_size=hidden_size,
            # set parameter attr for parameter sharing
            context_proj_param_attr=ParamAttr(name=key + "contex_proj.w"),
            fc_param_attr=ParamAttr(name=key + "_fc.w"),
            fc_bias_attr=ParamAttr(name=key + "_fc.b"),
            pool_bias_attr=ParamAttr(name=key + "_pool.b"))
        return conv
    conv_3 = create_conv(3, self.dnn_dims[1], "cnn")
    conv_4 = create_conv(4, self.dnn_dims[1], "cnn")
    return paddle.layer.concat(input=[conv_3, conv_4])
```
The CNN takes the word vector sequence produced by the embedding table, processes it with convolution and pooling, and outputs a semantic vector.
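The idea of sliding a context window over the sequence and then pooling can be sketched in plain Python. This is a simplified illustration under stated assumptions: it omits the learned projection and fully connected weights as well as any padding, so it is not equivalent to `sequence_conv_pool`.

```python
def context_windows(seq, context_len):
    """Concatenate each run of `context_len` neighbouring word vectors."""
    return [sum(seq[i:i + context_len], [])
            for i in range(len(seq) - context_len + 1)]


def max_pool(rows):
    """Element-wise maximum over all rows."""
    return [max(column) for column in zip(*rows)]


# four hypothetical 2-dimensional word vectors
word_vectors = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [0.4, 0.4]]
windows = context_windows(word_vectors, context_len=3)
semantic_vec = max_pool(windows)
```

Whatever the sentence length, pooling over the windows yields a single vector of fixed size.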
### RNN implementation
An RNN is well suited to learning from variable-length input.
```python
def create_rnn(self, emb, prefix=''):
    """
    A GRU sentence vector learner.
    """
    gru = paddle.networks.simple_gru(
        input=emb,
        size=self.dnn_dims[1],
        mixed_param_attr=ParamAttr(name='%s_gru_mixed.w' % prefix),
        mixed_bias_param_attr=ParamAttr(name="%s_gru_mixed.b" % prefix),
        gru_param_attr=ParamAttr(name='%s_gru.w' % prefix),
        gru_bias_attr=ParamAttr(name="%s_gru.b" % prefix))
    # use the output at the last time step as the sentence vector
    sent_vec = paddle.layer.last_seq(gru)
    return sent_vec
```
### FC implementation
```python
def create_fc(self, emb, prefix=''):
    """
    A multi-layer fully connected neural networks.

    :param emb: The output of the embedding layer
    :type emb: paddle.layer
    :param prefix: A prefix will be added to the layers' names.
    :type prefix: str
    """
    # max-pool the word vector sequence into one fixed-size vector
    _input_layer = paddle.layer.pooling(
        input=emb, pooling_type=paddle.pooling.Max())
    fc = paddle.layer.fc(
        input=_input_layer,
        size=self.dnn_dims[1],
        param_attr=ParamAttr(name='%s_fc.w' % prefix),
        bias_attr=ParamAttr(name="%s_fc.b" % prefix))
    return fc
```
When building the FC network, we first apply `paddle.layer.pooling` to max-pool the word vector sequence, which turns the variable-length sequence into a fixed-dimensional vector, and then feed that vector to the fully connected layer.
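To see why max pooling produces a fixed-dimensional vector, here is a minimal plain-Python sketch with made-up sentences of different lengths; it only illustrates the operation and does not use PaddlePaddle.

```python
def max_pool_over_time(word_vectors):
    """Element-wise maximum across all word vectors of one sentence."""
    return [max(column) for column in zip(*word_vectors)]


# hypothetical sentences with one and three 3-dimensional word vectors
short_sentence = [[0.1, 0.9, -0.3]]
long_sentence = [[0.1, 0.2, 0.0], [0.5, -0.4, 0.3], [0.2, 0.7, -0.1]]

short_vec = max_pool_over_time(short_sentence)
long_vec = max_pool_over_time(long_sentence)
```

Both results have three components, regardless of how many words each sentence contains.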
### Multi-layer DNN implementation
```python
def create_dnn(self, sent_vec, prefix):
    # start from the sentence vector so the return value is always defined
    _input_layer = sent_vec
    if len(self.dnn_dims) > 1:
        # stack one Tanh FC layer per remaining entry of dnn_dims
        for id, dim in enumerate(self.dnn_dims[1:]):
            name = "%s_fc_%d_%d" % (prefix, id, dim)
            fc = paddle.layer.fc(
                input=_input_layer,
                size=dim,
                act=paddle.activation.Tanh(),
                param_attr=ParamAttr(name='%s.w' % name),
                bias_attr=ParamAttr(name='%s.b' % name), )
            _input_layer = fc
    return _input_layer
```
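As a small illustration of how the loop above walks over `dnn_dims`, the following sketch lists the layer names and output sizes that the naming pattern `"%s_fc_%d_%d"` would produce. The helper `plan_dnn_layers` and the prefix `"source"` are hypothetical and are not part of `network_conf.py`.

```python
def plan_dnn_layers(dnn_dims, prefix):
    """List (layer name, output size) for each FC layer in the stack."""
    dims = [int(d) for d in dnn_dims.split(",")]
    # the first entry is the embedding size, so the stack starts at dims[1]
    return [("%s_fc_%d_%d" % (prefix, idx, dim), dim)
            for idx, dim in enumerate(dims[1:])]


layers = plan_dnn_layers("256,128,64,32", prefix="source")
for name, size in layers:
    print("%s -> %d" % (name, size))
```

With `256,128,64,32`, the first value is used as the embedding dimension and the remaining three values each produce one FC layer.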
### Classification / Regression
Classification and regression share a similar structure, so a single function builds both.
See the function `_build_classification_or_regression_model` in [network_conf.py](https://github.com/PaddlePaddle/models/blob/develop/dssm/network_conf.py) for the detailed implementation.
### Pairwise Rank
See the function `_build_rank_model` in [network_conf.py](https://github.com/PaddlePaddle/models/blob/develop/dssm/network_conf.py) for the implementation.
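For readers unfamiliar with pairwise ranking, the general idea can be sketched with a RankNet-style cost. This is an assumption made for illustration only; the loss actually used by `_build_rank_model` should be checked in `network_conf.py`. The function name and scores below are hypothetical.

```python
import math


def pairwise_rank_cost(left_score, right_score, label):
    """RankNet-style cost; label is 1 if the left target should rank higher, else 0."""
    diff = left_score - right_score
    return math.log(1.0 + math.exp(diff)) - label * diff


# hypothetical similarity scores of a source against two targets
good_order = pairwise_rank_cost(2.0, 0.5, label=1)
bad_order = pairwise_rank_cost(0.5, 2.0, label=1)
tie = pairwise_rank_cost(0.0, 0.0, label=1)
```

The cost is small when the scores agree with the label and grows when the preferred target receives the lower score.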
## Data Format
The sections below show small examples of the data under `./data`.
### Regression data format
```
# 3 fields each line:
# - source word list
# - target word list
# - target
<word list> \t <word list> \t <float>
```
An example in this format (fields are separated by tabs):
```
Six bags of apples	Apple 6s	0.1
The new driver	The driving school	0.9
```
### Classification data format
```
# 3 fields each line:
# - source word list
# - target word list
# - target
<word list> \t <word list> \t <label>
```
An example in this format (fields are separated by tabs):
```
Six bags of apples	Apple 6s	0
The new driver	The driving school	1
```
### Ranking data format
```
# 4 fields each line:
# - source word list
# - target1 word list
# - target2 word list
# - label
<word list> \t <word list> \t <word list> \t <label>
```
An example in this format (fields are separated by tabs):
```
Six bags of apples	Apple 6s	The new driver	1
The new driver	The driving school	Apple 6s	1
```
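The three formats differ only in the number of target fields and the label type, so a single tab-splitting helper can illustrate all of them. The sketch below is a hypothetical stand-alone parser, not the project's `reader.py`, and it keeps words as strings instead of mapping them to dictionary IDs.

```python
def parse_record(line, num_targets=1, label_type=float):
    """Split one tab-separated record into (source, targets, label)."""
    fields = line.rstrip("\n").split("\t")
    expected = num_targets + 2  # source + targets + label
    if len(fields) != expected:
        raise ValueError("expected %d fields, got %d" % (expected, len(fields)))
    source = fields[0].split()
    targets = [field.split() for field in fields[1:-1]]
    return source, targets, label_type(fields[-1])


regression = parse_record("Six bags of apples\tApple 6s\t0.1")
ranking = parse_record(
    "The new driver\tThe driving school\tApple 6s\t1",
    num_targets=2,
    label_type=int)
```

Classification records can be parsed the same way with `label_type=int` and the default single target.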
## Training
To train a DSSM classification model on the data in `./data/classification`, run `python train.py -y 0 --model_arch 0 --class_num 2`. Run `python train.py --help` to list all parameters of `train.py`. Some important ones are:
- `train_data_path`: path to the training data
- `test_data_path`: path to the test data, optional
- `source_dic_path`: path to the source dictionary
- `target_dic_path`: path to the target dictionary
- `model_type`: the type of loss function: 0 for classification, 1 for rank, 2 for regression
- `model_arch`: the model structure: 0 for FC, 1 for CNN, 2 for RNN
- `dnn_dims`: the dimension of each layer of the model; the default is `256,128,64,32`, which gives 4 layers
## Predicting with the trained model
Run `python infer.py --help` to list all parameters of `infer.py`. Some important ones are:
- `data_path` Path for the data to predict
- `prediction_output_path` Prediction output path
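`infer.py` writes one line per input record, with the predicted values joined by spaces. Assuming that, for a classification model, each line holds one score per class, the output could be post-processed roughly as follows. The helper names and the sample lines are hypothetical.

```python
def read_predictions(lines):
    """Parse space-separated prediction values, one record per line."""
    return [[float(value) for value in line.split()]
            for line in lines if line.strip()]


def predicted_classes(rows):
    """Index of the largest value in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


# hypothetical contents of the prediction output file for two records
sample_output = ["0.9 0.1", "0.2 0.8"]
rows = read_predictions(sample_output)
classes = predicted_classes(rows)
```

In practice the lines would be read from the file given by `prediction_output_path`.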
## References
1. Huang P S, He X, Gao J, et al. Learning deep structured semantic models for web search using clickthrough data[C]//Proceedings of the 22nd ACM International Conference on Information & Knowledge Management. ACM, 2013: 2333-2338.
2. [Microsoft Learning to Rank Datasets](https://www.microsoft.com/en-us/research/project/mslr/)
3. [Gao J, He X, Deng L. Deep Learning for Web Search and Natural Language Processing[J]. Microsoft Research Technical Report, 2015.](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/wsdm2015.v3.pdf)
苹果 苹果 6s 0
汽车 驾驶 驾校 培训 1
苹果 六 袋 苹果 6s 0
新手 汽车 驾驶 驾校 培训 1
苹果 六 袋 苹果 6s 新手 汽车 驾驶 1
新手 汽车 驾驶 驾校 培训 苹果 6s 0
苹果 六 袋 苹果 6s 新手 汽车 驾驶 1
新手 汽车 驾驶 驾校 培训 苹果 6s 1
UNK
苹果
6s
新手
汽车
驾驶
驾校
培训
\ No newline at end of file
import argparse
import itertools
import distutils.util
import reader
import paddle.v2 as paddle
from network_conf import DSSM
from utils import logger, ModelType, ModelArch, load_dic
parser = argparse.ArgumentParser(description="PaddlePaddle DSSM infer")
parser.add_argument(
"--model_path", type=str, required=True, help="The path of trained model.")
parser.add_argument(
"-i",
"--data_path",
type=str,
required=True,
help="The path of the data for inferring.")
parser.add_argument(
"-o",
"--prediction_output_path",
type=str,
required=True,
help="The path to save the predictions.")
parser.add_argument(
"-y",
"--model_type",
type=int,
required=True,
default=ModelType.CLASSIFICATION_MODE,
help=("The model type: %d for classification, %d for pairwise rank, "
"%d for regression (default: classification).") %
(ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
ModelType.REGRESSION_MODE))
parser.add_argument(
"-s",
"--source_dic_path",
type=str,
required=False,
help="The path of the source's word dictionary.")
parser.add_argument(
"--target_dic_path",
type=str,
required=False,
help=("The path of the target's word dictionary, "
"if this parameter is not set, the `source_dic_path` will be used."))
parser.add_argument(
"-a",
"--model_arch",
type=int,
required=True,
default=ModelArch.CNN_MODE,
help="model architecture, %d for CNN, %d for FC, %d for RNN" %
(ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE))
parser.add_argument(
"--share_network_between_source_target",
type=distutils.util.strtobool,
default=False,
help="whether to share network parameters between source and target")
parser.add_argument(
"--share_embed",
type=distutils.util.strtobool,
default=False,
help="whether to share word embedding between source and target")
parser.add_argument(
"--dnn_dims",
type=str,
default="256,128,64,32",
help=("The dimensions of dnn layers, default is `256,128,64,32`, "
"which means a dnn with 4 layers with "
"dimensions 256, 128, 64 and 32 will be created."))
parser.add_argument(
"-c",
"--class_num",
type=int,
default=0,
help="The number of categories for classification task.")
args = parser.parse_args()
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)
if args.model_type.is_classification():
    assert args.class_num > 1, ("The parameter class_num should be set "
                                "in classification task.")
layer_dims = map(int, args.dnn_dims.split(","))
args.target_dic_path = args.source_dic_path if not args.target_dic_path \
    else args.target_dic_path
paddle.init(use_gpu=False, trainer_count=1)
class Inferer(object):
    def __init__(self, param_path):
        prediction = DSSM(
            dnn_dims=layer_dims,
            vocab_sizes=[
                len(load_dic(path))
                for path in [args.source_dic_path, args.target_dic_path]
            ],
            model_type=args.model_type,
            model_arch=args.model_arch,
            share_semantic_generator=args.share_network_between_source_target,
            class_num=args.class_num,
            share_embed=args.share_embed,
            is_infer=True)()

        # load parameter
        logger.info("Load the trained model from %s." % param_path)
        self.parameters = paddle.parameters.Parameters.from_tar(
            open(param_path, "r"))
        self.inferer = paddle.inference.Inference(
            output_layer=prediction, parameters=self.parameters)

    def infer(self, data_path):
        dataset = reader.Dataset(
            train_path=data_path,
            test_path=None,
            source_dic_path=args.source_dic_path,
            target_dic_path=args.target_dic_path,
            model_type=args.model_type, )
        infer_reader = paddle.batch(dataset.infer, batch_size=1000)
        logger.warning("Write predictions to %s." % args.prediction_output_path)

        # the with-block closes the output file once all batches are written
        with open(args.prediction_output_path, "w") as output_f:
            for id, batch in enumerate(infer_reader()):
                res = self.inferer.infer(input=batch)
                predictions = [" ".join(map(str, x)) for x in res]
                assert len(batch) == len(predictions), (
                    "Error! %d inputs are given, "
                    "but only %d predictions are returned.") % (
                        len(batch), len(predictions))
                output_f.write("\n".join(map(str, predictions)) + "\n")


if __name__ == "__main__":
    inferer = Inferer(args.model_path)
    inferer.infer(args.data_path)
from utils import UNK, ModelType, TaskType, load_dic, \
    sent2ids, logger


class Dataset(object):
    def __init__(self, train_path, test_path, source_dic_path, target_dic_path,
                 model_type):
        self.train_path = train_path
        self.test_path = test_path
        self.source_dic_path = source_dic_path
        self.target_dic_path = target_dic_path
        self.model_type = ModelType(model_type)

        self.source_dic = load_dic(self.source_dic_path)
        self.target_dic = load_dic(self.target_dic_path)

        # choose the record parser that matches the model type
        _record_reader = {
            ModelType.CLASSIFICATION_MODE: self._read_classification_record,
            ModelType.REGRESSION_MODE: self._read_regression_record,
            ModelType.RANK_MODE: self._read_rank_record,
        }

        assert isinstance(model_type, ModelType)
        self.record_reader = _record_reader[model_type.mode]
        self.is_infer = False

    def train(self):
        '''
        Load trainset.
        '''
        logger.info("[reader] load trainset from %s" % self.train_path)
        with open(self.train_path) as f:
            for line_id, line in enumerate(f):
                yield self.record_reader(line)

    def test(self):
        '''
        Load testset.
        '''
        with open(self.test_path) as f:
            for line_id, line in enumerate(f):
                yield self.record_reader(line)

    def infer(self):
        # in inference mode the records are returned without labels
        self.is_infer = True
        with open(self.train_path) as f:
            for line in f:
                yield self.record_reader(line)

    def _read_classification_record(self, line):
        '''
        data format:
            <source words> [TAB] <target words> [TAB] <label>

        @line: str
            a string line which represent a record.
        '''
        fs = line.strip().split('\t')
        assert len(fs) == 3, "wrong format for classification\n" + \
            "the format should be " +\
            "<source words> [TAB] <target words> [TAB] <label>"
        source = sent2ids(fs[0], self.source_dic)
        target = sent2ids(fs[1], self.target_dic)
        if not self.is_infer:
            label = int(fs[2])
            return (
                source,
                target,
                label, )
        return source, target

    def _read_regression_record(self, line):
        '''
        data format:
            <source words> [TAB] <target words> [TAB] <label>

        @line: str
            a string line which represent a record.
        '''
        fs = line.strip().split('\t')
        assert len(fs) == 3, "wrong format for regression\n" + \
            "the format should be " +\
            "<source words> [TAB] <target words> [TAB] <label>"
        source = sent2ids(fs[0], self.source_dic)
        target = sent2ids(fs[1], self.target_dic)
        if not self.is_infer:
            label = float(fs[2])
            return (
                source,
                target,
                [label], )
        return source, target

    def _read_rank_record(self, line):
        '''
        data format:
            <source words> [TAB] <left_target words> [TAB] <right_target words> [TAB] <label>
        '''
        fs = line.strip().split('\t')
        assert len(fs) == 4, "wrong format for rank\n" + \
            "the format should be " +\
            "<source words> [TAB] <left_target words> [TAB] <right_target words> [TAB] <label>"

        source = sent2ids(fs[0], self.source_dic)
        left_target = sent2ids(fs[1], self.target_dic)
        right_target = sent2ids(fs[2], self.target_dic)
        if not self.is_infer:
            label = int(fs[3])
            return (source, left_target, right_target, label)
        return source, left_target, right_target


if __name__ == '__main__':
    path = './data/classification/train.txt'
    test_path = './data/classification/test.txt'
    source_dic = './data/vocab.txt'
    dataset = Dataset(path, test_path, source_dic, source_dic,
                      ModelType.CLASSIFICATION)

    for rcd in dataset.train():
        print(rcd)
#!/bin/bash
git clone https://github.com/chinese-poetry/chinese-poetry.git
if [ ! -d raw ]
then
mkdir raw
fi
mv chinese-poetry/json/poet.tang.* raw/
rm -rf chinese-poetry
# -*- coding: utf-8 -*-
import os
import io
import re
import json
import click
import collections
def build_vocabulary(dataset, cutoff=0):
    # count how often each character occurs across all poems
    dictionary = collections.defaultdict(int)
    for data in dataset:
        for sent in data[2]:
            for char in sent:
                dictionary[char] += 1
    dictionary = filter(lambda x: x[1] >= cutoff, dictionary.items())
    # most frequent characters first, ties broken by the character itself
    dictionary = sorted(dictionary, key=lambda x: (-x[1], x[0]))
    vocab, _ = list(zip(*dictionary))
    return (u"<s>", u"<e>", u"<unk>") + vocab


@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw data")
@click.option("--outfile", type=str, help="Path to save the training data")
@click.option("--dictfile", type=str, help="Path to save the dictionary file")
def preprocess(datadir, outfile, dictfile):
    dataset = []
    # the annotation patterns below match full-width CJK punctuation
    note_pattern1 = re.compile(u"（.*?）", re.U)
    note_pattern2 = re.compile(u"〖.*?〗", re.U)
    note_pattern3 = re.compile(u"-.*?-。?", re.U)
    note_pattern4 = re.compile(u"（.*$", re.U)
    note_pattern5 = re.compile(u"。。.*）$", re.U)
    note_pattern6 = re.compile(u"。。", re.U)
    note_pattern7 = re.compile(u"[《》「」\[\]]", re.U)
    print("Load raw data...")
    for fn in os.listdir(datadir):
        with io.open(os.path.join(datadir, fn), "r", encoding="utf8") as f:
            for data in json.load(f):
                title = data['title']
                author = data['author']
                p = "".join(data['paragraphs'])
                p = "".join(p.split())
                p = note_pattern1.sub(u"", p)
                p = note_pattern2.sub(u"", p)
                p = note_pattern3.sub(u"", p)
                p = note_pattern4.sub(u"", p)
                p = note_pattern5.sub(u"。", p)
                p = note_pattern6.sub(u"。", p)
                p = note_pattern7.sub(u"", p)
                # skip poems that are empty or still contain unwanted marks
                if (p == u"" or u"{" in p or u"}" in p or u"｛" in p or
                        u"｝" in p or u"、" in p or u"：" in p or u"；" in p or
                        u"！" in p or u"？" in p or u"●" in p or u"□" in p or
                        u"囗" in p or u"）" in p):
                    continue
                paragraphs = re.split(u"。|，", p)
                paragraphs = filter(lambda x: len(x), paragraphs)
                if len(paragraphs) > 1:
                    dataset.append((title, author, paragraphs))

    print("Construct vocabularies...")
    vocab = build_vocabulary(dataset, cutoff=10)
    with io.open(dictfile, "w", encoding="utf8") as f:
        for v in vocab:
            f.write(v + "\n")

    print("Write processed data...")
    with io.open(outfile, "w", encoding="utf8") as f:
        for data in dataset:
            title = data[0]
            author = data[1]
            paragraphs = ".".join(data[2])
            f.write("\t".join((title, author, paragraphs)) + "\n")


if __name__ == "__main__":
    preprocess()
我们 不会 伤害 你 的 。 他们 也 这么 说 。
你 拥有 你 父亲 皇室 的 血统 。 是 合法 的 继承人 。
叫 什么 你 可以 告诉 我 。
你 并 没有 留言 说 要 去 哪里 。 是 的 , 因为 我 必须 要 去 完成 这件 事 。
你 查出 是 谁 住 在 隔壁 房间 吗 ?
#!/bin/bash
wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O train.json
wget --no-check-certificate https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O dev.json
wget http://nlp.stanford.edu/data/glove.840B.300d.zip
unzip glove.840B.300d.zip