提交 7d501da2 编写于 作者: S seiriosPlus

add online learning

上级 18371de5
# PaddleRec 启动训练
# PaddleRec 流式训练(OnlineLearning)任务启动及配置流程
## 流式训练简介
流式训练是按照一定顺序进行数据的接收和处理,每接收一个数据,模型会对它进行预测并对当前模型进行更新,然后处理下一个数据。 像信息流、小视频、电商等场景,每天都会新增大量的数据, 让每天(每一刻)新增的数据基于上一天(上一刻)的模型进行新的预测和模型更新。
在大规模流式训练场景下, 需要使用的深度学习框架有对应的能力支持, 即:
* 支持大规模分布式训练的能力, 数据量巨大, 需要有良好的分布式训练及扩展能力,才能满足训练的时效要求
* 支持超大规模的Embedding, 能够支持十亿甚至千亿级别的Embedding, 拥有合理的参数输出的能力,能够快速输出模型参数并和线上其他系统进行对接
* Embedding的特征ID需要支持HASH映射,不要求ID的编码,能够自动增长及控制特征的准入(原先不存在的特征可以以适当的条件创建), 能够定期淘汰(能够以一定的策略进行过期的特征的清理) 并拥有准入及淘汰策略
* 最后就是要基于框架开发一套完备的流式训练的 trainer.py, 能够拥有完善的流式训练流程
## 使用PaddleRec内置的 online learning 进行模型的训练
目前,PaddleRec基于飞桨分布式训练框架的能力,实现了这套流式训练的流程。 供大家参考和使用。我们在`models/online_learning`目录下提供了一个ctr-dnn的online_training的版本,供大家更好的理解和参考。
**注意**
1. 使用online learning 需要安装目前Paddle最新的开发者版本, 你可以从 https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev 此处获得它,需要先卸载当前已经安装的飞桨版本,根据自己的Python环境下载相应的安装包。
2. 使用流式训练及大规模稀疏的能力,需要对模型做一些微调, 因此需要你修改部分代码。
3. 当前只有参数服务器的分布式训练是支持带大规模稀疏的流式训练的,因此运行时,请直接选择参数服务器本地训练或集群训练方法。
## 启动方法
......
# 基于DNN模型的点击率预估模型
## 介绍
`CTR(Click Through Rate)`,即点击率,是“推荐系统/计算广告”等领域的重要指标,对其进行预估是商品推送/广告投放等决策的基础。简单来说,CTR预估对每次广告的点击情况做出预测,预测用户是点击还是不点击。CTR预估模型综合考虑各种因素、特征,在大量历史数据上训练,最终对商业决策提供帮助。本模型实现了下述论文中提出的DNN模型:
```text
@inproceedings{guo2017deepfm,
title={DeepFM: A Factorization-Machine based Neural Network for CTR Prediction},
author={Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li and Xiuqiang He},
booktitle={the Twenty-Sixth International Joint Conference on Artificial Intelligence (IJCAI)},
pages={1725--1731},
year={2017}
}
```
#
## 数据准备
### 数据来源
训练及测试数据集选用[Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/)所用的Criteo数据集。该数据集包括两部分:训练集和测试集。训练集包含一段时间内Criteo的部分流量,测试集则对应训练数据后一天的广告点击流量。
每一行数据格式如下所示:
```bash
<label> <integer feature 1> ... <integer feature 13> <categorical feature 1> ... <categorical feature 26>
```
其中```<label>```表示广告是否被点击,点击用1表示,未点击用0表示。```<integer feature>```代表数值特征(连续特征),共有13个连续特征。```<categorical feature>```代表分类特征(离散特征),共有26个离散特征。相邻两个特征用```\t```分隔,缺失特征用空格表示。测试集中```<label>```特征已被移除。
### 数据预处理
数据预处理共包括两步:
- 将原始训练集按9:1划分为训练集和验证集
- 数值特征(连续特征)需进行归一化处理,但需要注意的是,对每一个特征```<integer feature i>```,归一化时用到的最大值并不是用全局最大值,而是取排序后95%位置处的特征值作为最大值,同时保留极值。
### 一键下载训练及测试数据
```bash
sh download_data.sh
```
执行该脚本,会从国内源的服务器上下载Criteo数据集,并解压到指定文件夹。全量训练数据放置于`./train_data_full/`,全量测试数据放置于`./test_data_full/`,用于快速验证的训练数据与测试数据放置于`./train_data/``./test_data/`
执行该脚本的理想输出为:
```bash
> sh download_data.sh
--2019-11-26 06:31:33-- https://fleet.bj.bcebos.com/ctr_data.tar.gz
Resolving fleet.bj.bcebos.com... 10.180.112.31
Connecting to fleet.bj.bcebos.com|10.180.112.31|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4041125592 (3.8G) [application/x-gzip]
Saving to: “ctr_data.tar.gz”
100%[==================================================================================================================>] 4,041,125,592 120M/s in 32s
2019-11-26 06:32:05 (120 MB/s) - “ctr_data.tar.gz” saved [4041125592/4041125592]
raw_data/
raw_data/part-55
raw_data/part-113
...
test_data/part-227
test_data/part-222
Complete data download.
Full Train data stored in ./train_data_full
Full Test data stored in ./test_data_full
Rapid Verification train data stored in ./train_data
Rapid Verification test data stored in ./test_data
```
至此,我们已完成数据准备的全部工作。
## 数据读取
为了能高速运行CTR模型的训练,`PaddleRec`封装了`dataset``dataloader`API进行高性能的数据读取。
如何在我们的训练中引入dataset读取方式呢?无需变更数据格式,只需在我们的训练代码中加入以下内容,便可达到媲美二进制读取的高效率,以下是一个比较完整的流程:
### 引入dataset
1. 通过工厂类`fluid.DatasetFactory()`创建一个dataset对象。
2. 将我们定义好的数据输入格式传给dataset,通过`dataset.set_use_var(inputs)`实现。
3. 指定我们的数据读取方式,由`dataset_generator.py`实现数据读取的规则,后面将会介绍读取规则的实现。
4. 指定数据读取的batch_size。
5. 指定数据读取的线程数,该线程数和训练线程应保持一致,两者为耦合的关系。
6. 指定dataset读取的训练文件的列表。
```python
def get_dataset(inputs, args)
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var(inputs)
dataset.set_pipe_command("python dataset_generator.py")
dataset.set_batch_size(args.batch_size)
dataset.set_thread(int(args.cpu_num))
file_list = [
str(args.train_files_path) + "/%s" % x
for x in os.listdir(args.train_files_path)
]
logger.info("file list: {}".format(file_list))
return dataset, file_list
```
### 如何指定数据读取规则
在上文我们提到了由`dataset_generator.py`实现具体的数据读取规则,那么,怎样为dataset创建数据读取的规则呢?
以下是`dataset_generator.py`的全部代码,具体流程如下:
1. 首先我们需要引入dataset的库,位于`paddle.fluid.incubate.data_generator`
2. 声明一些在数据读取中会用到的变量,如示例代码中的`cont_min_``categorical_range_`等。
3. 创建一个子类,继承dataset的基类,基类有多种选择,如果是多种数据类型混合,并且需要转化为数值进行预处理的,建议使用`MultiSlotDataGenerator`;若已经完成了预处理并保存为数据文件,可以直接以`string`的方式进行读取,使用`MultiSlotStringDataGenerator`,能够进一步加速。在示例代码,我们继承并实现了名为`CriteoDataset`的dataset子类,使用`MultiSlotDataGenerator`方法。
4. 继承并实现基类中的`generate_sample`函数,逐行读取数据。该函数应返回一个可以迭代的reader方法(带有yield的函数不再是一个普通的函数,而是一个生成器generator,成为了可以迭代的对象,等价于一个数组、链表、文件、字符串etc.)
5. 在这个可以迭代的函数中,如示例代码中的`def reader()`,我们定义数据读取的逻辑。例如对以行为单位的数据进行截取,转换及预处理。
6. 最后,我们需要将数据整理为特定的格式,才能够被dataset正确读取,并灌入的训练的网络中。简单来说,数据的输出顺序与我们在网络中创建的`inputs`必须是严格一一对应的,并转换为类似字典的形式。在示例代码中,我们使用`zip`的方法将参数名与数值构成的元组组成了一个list,并将其yield输出。如果展开来看,我们输出的数据形如`[('dense_feature',[value]),('C1',[value]),('C2',[value]),...,('C26',[value]),('label',[value])]`
```python
import paddle.fluid.incubate.data_generator as dg
cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
hash_dim_ = 1000001
continuous_range_ = range(1, 14)
categorical_range_ = range(14, 40)
class CriteoDataset(dg.MultiSlotDataGenerator):
def generate_sample(self, line):
def reader():
features = line.rstrip('\n').split('\t')
dense_feature = []
sparse_feature = []
for idx in continuous_range_:
if features[idx] == "":
dense_feature.append(0.0)
else:
dense_feature.append(
(float(features[idx]) - cont_min_[idx - 1]) /
cont_diff_[idx - 1])
for idx in categorical_range_:
sparse_feature.append(
[hash(str(idx) + features[idx]) % hash_dim_])
label = [int(features[0])]
process_line = dense_feature, sparse_feature, label
feature_name = ["dense_feature"]
for idx in categorical_range_:
feature_name.append("C" + str(idx - 13))
feature_name.append("label")
yield zip(feature_name, [dense_feature] + sparse_feature + [label])
return reader
d = CriteoDataset()
d.run_from_stdin()
```
### 快速调试Dataset
我们可以脱离组网架构,单独验证Dataset的输出是否符合我们预期。使用命令
`cat 数据文件 | python dataset读取python文件`进行dataset代码的调试:
```bash
cat train_data/part-0 | python dataset_generator.py
```
输出的数据格式如下:
` dense_input:size ; dense_input:value ; sparse_input:size ; sparse_input:value ; ... ; sparse_input:size ; sparse_input:value ; label:size ; label:value `
理想的输出为(截取了一个片段):
```bash
...
13 0.05 0.00663349917081 0.05 0.0 0.02159375 0.008 0.15 0.04 0.362 0.1 0.2 0.0 0.04 1 715353 1 817085 1 851010 1 833725 1 286835 1 948614 1 881652 1 507110 1 27346 1 646986 1 643076 1 200960 1 18464 1 202774 1 532679 1 729573 1 342789 1 562805 1 880474 1 984402 1 666449 1 26235 1 700326 1 452909 1 884722 1 787527 1 0
...
```
#
## 模型组网
### 数据输入声明
正如数据准备章节所介绍,Criteo数据集中,分为连续数据与离散(稀疏)数据,所以整体而言,CTR-DNN模型的数据输入层包括三个,分别是:`dense_input`用于输入连续数据,维度由超参数`dense_feature_dim`指定,数据类型是归一化后的浮点型数据。`sparse_input_ids`用于记录离散数据,在Criteo数据集中,共有26个slot,所以我们创建了名为`C1~C26`的26个稀疏参数输入,并设置`lod_level=1`,代表其为变长数据,数据类型为整数;最后是每条样本的`label`,代表了是否被点击,数据类型是整数,0代表负样例,1代表正样例。
在Paddle中数据输入的声明使用`paddle.fluid.data()`,会创建指定类型的占位符,数据IO会依据此定义进行数据的输入。
```python
dense_input = fluid.data(name="dense_input",
shape=[-1, args.dense_feature_dim],
dtype="float32")
sparse_input_ids = [
fluid.data(name="C" + str(i),
shape=[-1, 1],
lod_level=1,
dtype="int64") for i in range(1, 27)
]
label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
inputs = [dense_input] + sparse_input_ids + [label]
```
### CTR-DNN模型组网
CTR-DNN模型的组网比较直观,本质是一个二分类任务,代码参考`model.py`。模型主要组成是一个`Embedding`层,三个`FC`层,以及相应的分类任务的loss计算和auc计算。
#### Embedding层
首先介绍Embedding层的搭建方式:`Embedding`层的输入是`sparse_input`,shape由超参的`sparse_feature_dim``embedding_size`定义。需要特别解释的是`is_sparse`参数,当我们指定`is_sprase=True`后,计算图会将该参数视为稀疏参数,反向更新以及分布式通信时,都以稀疏的方式进行,会极大的提升运行效率,同时保证效果一致。
各个稀疏的输入通过Embedding层后,将其合并起来,置于一个list内,以方便进行concat的操作。
```python
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
size=[args.sparse_feature_dim,
args.embedding_size],
param_attr=fluid.ParamAttr(
name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()),
)
sparse_embed_seq = list(map(embedding_layer, inputs[1:-1])) # [C1~C26]
```
#### FC层
将离散数据通过embedding查表得到的值,与连续数据的输入进行`concat`操作,合为一个整体输入,作为全链接层的原始输入。我们共设计了3层FC,每层FC的输出维度都为400,每层FC都后接一个`relu`激活函数,每层FC的初始化方式为符合正态分布的随机初始化,标准差与上一层的输出维度的平方根成反比。
```python
concated = fluid.layers.concat(sparse_embed_seq + inputs[0:1], axis=1)
fc1 = fluid.layers.fc(
input=concated,
size=400,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(concated.shape[1]))),
)
fc2 = fluid.layers.fc(
input=fc1,
size=400,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc1.shape[1]))),
)
fc3 = fluid.layers.fc(
input=fc2,
size=400,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc2.shape[1]))),
)
```
#### Loss及Auc计算
- 预测的结果通过一个输出shape为2的FC层给出,该FC层的激活函数是softmax,会给出每条样本分属于正负样本的概率。
- 每条样本的损失由交叉熵给出,交叉熵的输入维度为[batch_size,2],数据类型为float,label的输入维度为[batch_size,1],数据类型为int。
- 该batch的损失`avg_cost`是各条样本的损失之和
- 我们同时还会计算预测的auc,auc的结果由`fluid.layers.auc()`给出,该层的返回值有三个,分别是全局auc: `auc_var`,当前batch的auc: `batch_auc_var`,以及auc_states: `auc_states`,auc_states包含了`batch_stat_pos, batch_stat_neg, stat_pos, stat_neg`信息。`batch_auc`我们取近20个batch的平均,由参数`slide_steps=20`指定,roc曲线的离散化的临界数值设置为4096,由`num_thresholds=2**12`指定。
```
predict = fluid.layers.fc(
input=fc3,
size=2,
act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc3.shape[1]))),
)
cost = fluid.layers.cross_entropy(input=predict, label=inputs[-1])
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=inputs[-1])
auc_var, batch_auc_var, auc_states = fluid.layers.auc(
input=predict,
label=inputs[-1],
num_thresholds=2**12,
slide_steps=20)
```
完成上述组网后,我们最终可以通过训练拿到`avg_cost``auc`两个重要指标。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
backend: "PaddleCloud"
cluster_type: k8s # mpi 可选
config:
fs_name: "afs://xxx.com"
fs_ugi: "usr,pwd"
output_path: "" # 填远程地址,如afs:/user/your/path/ 则此处填 /user/your/path
# for mpi
train_data_path: "" # 填远程地址,如afs:/user/your/path/ 则此处填 /user/your/path
test_data_path: "" # 填远程地址,如afs:/user/your/path/ 则此处填 /user/your/path
thirdparty_path: "" # 填远程地址,如afs:/user/your/path/ 则此处填 /user/your/path
paddle_version: "1.7.2" # 填写paddle官方版本号 >= 1.7.2
# for k8s
afs_remote_mount_point: "" # 填远程地址,如afs:/user/your/path/ 则此处填 /user/your/path
# paddle分布式底层超参,无特殊需求不理不改
communicator:
FLAGS_communicator_is_sgd_optimizer: 0
FLAGS_communicator_send_queue_size: 5
FLAGS_communicator_thread_pool_size: 32
FLAGS_communicator_max_merge_var_num: 5
FLAGS_communicator_max_send_grad_num_before_recv: 5
FLAGS_communicator_fake_rpc: 0
FLAGS_rpc_retry_times: 3
submit:
ak: ""
sk: ""
priority: "high"
job_name: "PaddleRec_CTR"
group: ""
start_cmd: "python -m paddlerec.run -m ./config.yaml"
files: ./*.py ./*.yaml
# for mpi ps-cpu
nodes: 2
# for k8s gpu
k8s_trainers: 2
k8s_cpu_cores: 2
k8s_gpu_card: 1
# for k8s ps-cpu
k8s_trainers: 2
k8s_cpu_cores: 4
k8s_ps_num: 2
k8s_ps_cores: 4
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# workspace
workspace: "models/rank/dnn"
# list of dataset
dataset:
- name: dataloader_train # name of dataset to distinguish different datasets
batch_size: 2
type: DataLoader # or QueueDataset
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
- name: dataset_train # name of dataset to distinguish different datasets
batch_size: 2
type: QueueDataset # or DataLoader
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
- name: dataset_infer # name
batch_size: 2
type: DataLoader # or QueueDataset
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
optimizer:
class: Adam
learning_rate: 0.001
strategy: async
# user-defined <key, value> pairs
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 9
dense_input_dim: 13
fc_sizes: [512, 256, 128, 32]
# select runner by name
mode: [ps_cluster, single_cpu_infer]
# config of each runner.
# runner is a kind of paddle training class, which wraps the train/infer process.
runner:
- name: single_cpu_infer
class: infer
# num of epochs
epochs: 1
# device to run training or infer
device: cpu
init_model_path: "increment_dnn" # load model path
phases: [phase2]
- name: ps_cluster
class: cluster_train
epochs: 2
device: cpu
fleet_mode: ps
save_checkpoint_interval: 1 # save model interval of epochs
save_checkpoint_path: "increment_dnn" # save checkpoint path
init_model_path: "" # load model path
print_interval: 1
phases: [phase1]
# runner will run all the phase in each epoch
phase:
- name: phase1
model: "{workspace}/model.py" # user-defined model
dataset_name: dataloader_train # select dataset by name
thread_num: 1
- name: phase2
model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_infer # select dataset by name
thread_num: 1
wget --no-check-certificate https://fleet.bj.bcebos.com/ctr_data.tar.gz
tar -zxvf ctr_data.tar.gz
mv ./raw_data ./train_data_full
mkdir train_data && cd train_data
cp ../train_data_full/part-0 ../train_data_full/part-1 ./ && cd ..
mv ./test_data ./test_data_full
mkdir test_data && cd test_data
cp ../test_data_full/part-220 ./ && cd ..
echo "Complete data download."
echo "Full Train data stored in ./train_data_full "
echo "Full Test data stored in ./test_data_full "
echo "Rapid Verification train data stored in ./train_data "
echo "Rapid Verification test data stored in ./test_data "
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.incubate.data_generator as dg
cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
hash_dim_ = 1000001
continuous_range_ = range(1, 14)
categorical_range_ = range(14, 40)
class CriteoDataset(dg.MultiSlotDataGenerator):
"""
DacDataset: inheritance MultiSlotDataGeneratior, Implement data reading
Help document: http://wiki.baidu.com/pages/viewpage.action?pageId=728820675
"""
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.rstrip('\n').split('\t')
dense_feature = []
sparse_feature = []
for idx in continuous_range_:
if features[idx] == "":
dense_feature.append(0.0)
else:
dense_feature.append(
(float(features[idx]) - cont_min_[idx - 1]) /
cont_diff_[idx - 1])
for idx in categorical_range_:
sparse_feature.append(
[hash(str(idx) + features[idx]) % hash_dim_])
label = [int(features[0])]
process_line = dense_feature, sparse_feature, label
feature_name = ["dense_feature"]
for idx in categorical_range_:
feature_name.append("C" + str(idx - 13))
feature_name.append("label")
s = "click:" + str(label[0])
for i in dense_feature:
s += " dense_feature:" + str(i)
for i in range(1, 1 + len(categorical_range_)):
s += " " + str(i) + ":" + str(sparse_feature[i - 1][0])
print(s.strip())
yield None
return reader
d = CriteoDataset()
d.run_from_stdin()
sh download.sh
mkdir slot_train_data_full
for i in `ls ./train_data_full`
do
cat train_data_full/$i | python get_slot_data.py > slot_train_data_full/$i
done
mkdir slot_test_data_full
for i in `ls ./test_data_full`
do
cat test_data_full/$i | python get_slot_data.py > slot_test_data_full/$i
done
mkdir slot_train_data
for i in `ls ./train_data`
do
cat train_data/$i | python get_slot_data.py > slot_train_data/$i
done
mkdir slot_test_data
for i in `ls ./test_data`
do
cat test_data/$i | python get_slot_data.py > slot_test_data/$i
done
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.model import ModelBase
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def _init_hyper_parameters(self):
self.is_distributed = True if envs.get_fleet_mode().upper(
) == "PSLIB" else False
self.sparse_feature_number = envs.get_global_env(
"hyper_parameters.sparse_feature_number")
self.sparse_feature_dim = envs.get_global_env(
"hyper_parameters.sparse_feature_dim")
self.learning_rate = envs.get_global_env(
"hyper_parameters.optimizer.learning_rate")
def net(self, input, is_infer=False):
self.sparse_inputs = self._sparse_data_var[1:]
self.dense_input = self._dense_data_var[0]
self.label_input = self._sparse_data_var[0]
def embedding_layer(input):
emb = fluid.contrib.layers.sparse_embedding(
input=input,
is_test=False,
# for distributed sparse embedding, dim0 just fake.
size=[1024, self.sparse_feature_dim],
param_attr=fluid.ParamAttr(
name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()), )
emb_sum = fluid.layers.sequence_pool(input=emb, pool_type='sum')
return emb_sum
sparse_embed_seq = list(map(embedding_layer, self.sparse_inputs))
concated = fluid.layers.concat(
sparse_embed_seq + [self.dense_input], axis=1)
fcs = [concated]
hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes")
for size in hidden_layers:
output = fluid.layers.fc(
input=fcs[-1],
size=size,
act='relu',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Normal(
scale=1.0 / math.sqrt(fcs[-1].shape[1]))))
fcs.append(output)
predict = fluid.layers.fc(
input=fcs[-1],
size=2,
act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fcs[-1].shape[1]))))
self.predict = predict
auc, batch_auc, _ = fluid.layers.auc(input=self.predict,
label=self.label_input,
num_thresholds=2**12,
slide_steps=20)
if is_infer:
self._infer_results["AUC"] = auc
self._infer_results["BATCH_AUC"] = batch_auc
return
self._metrics["AUC"] = auc
self._metrics["BATCH_AUC"] = batch_auc
cost = fluid.layers.cross_entropy(
input=self.predict, label=self.label_input)
avg_cost = fluid.layers.reduce_mean(cost)
self._cost = avg_cost
def optimizer(self):
optimizer = fluid.optimizer.Adam(self.learning_rate, lazy_mode=True)
return optimizer
def infer_net(self):
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import time
import warnings
import numpy as np
import logging
import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.metric import Metric
from paddlerec.core.trainers.framework.runner import RunnerBase
logging.basicConfig(
format='%(asctime)s - %(levelname)s: %(message)s', level=logging.INFO)
__all__ = [
"RunnerBase", "SingleRunner", "PSRunner", "CollectiveRunner", "PslibRunner"
]
def as_numpy(tensor):
"""
Convert a Tensor to a numpy.ndarray, its only support Tensor without LoD information.
For higher dimensional sequence data, please use LoDTensor directly.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy
new_scope = fluid.Scope()
with fluid.scope_guard(new_scope):
fluid.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), fluid.CPUPlace())
tensor = new_scope.find_var("data").get_tensor()
fluid.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor())
Args:
tensor(Variable): a instance of Tensor
Returns:
numpy.ndarray
"""
if isinstance(tensor, fluid.core.LoDTensorArray):
return [as_numpy(t) for t in tensor]
if isinstance(tensor, list):
return [as_numpy(t) for t in tensor]
assert isinstance(tensor, fluid.core.LoDTensor)
lod = tensor.lod()
# (todo) need print lod or return it for user
if tensor._is_initialized():
return np.array(tensor)
else:
return None
class OnlineLearningRunner(RunnerBase):
def __init__(self, context):
print("Running OnlineLearningRunner.")
def run(self, context):
epochs = int(
envs.get_global_env("runner." + context["runner_name"] +
".epochs"))
model_dict = context["env"]["phase"][0]
model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._metrics
dataset_list = []
dataset_index = 0
for day_index in range(len(days)):
day = days[day_index]
cur_path = "%s/%s" % (path, str(day))
fleet_util.rank0_print("dataset_index=%s, path=%s" %
(dataset_index, cur_path))
filelist = fleet.split_files(hdfs_ls([cur_path]))
dataset = create_dataset(use_var, filelist)
dataset_list.append(dataset)
dataset_index += 1
dataset_index = 0
for epoch in range(len(days)):
day = days[day_index]
begin_time = time.time()
result = self._run(context, model_dict)
end_time = time.time()
seconds = end_time - begin_time
message = "epoch {} done, use time: {}".format(epoch, seconds)
# TODO, wait for PaddleCloudRoleMaker supports gloo
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
if context["fleet"] is not None and isinstance(context["fleet"],
GeneralRoleMaker):
metrics_result = []
for key in metrics:
if isinstance(metrics[key], Metric):
_str = metrics[key].calc_global_metrics(
context["fleet"],
context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str)
elif result is not None:
_str = "{}={}".format(key, result[key])
metrics_result.append(_str)
if len(metrics_result) > 0:
message += ", global metrics: " + ", ".join(metrics_result)
print(message)
with fluid.scope_guard(context["model"][model_dict["name"]][
"scope"]):
train_prog = context["model"][model_dict["name"]][
"main_program"]
startup_prog = context["model"][model_dict["name"]][
"startup_program"]
with fluid.program_guard(train_prog, startup_prog):
self.save(epoch, context, True)
context["status"] = "terminal_pass"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册