提交 16a48fc4 编写于 作者: O overlordmax

add fibinet

上级 997f8cb5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# workspace
workspace: "paddlerec.models.rank.fibinet"
# list of dataset
dataset:
- name: dataloader_train # name of dataset to distinguish different datasets
batch_size: 2
type: DataLoader # or QueueDataset
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
- name: dataset_train # name of dataset to distinguish different datasets
batch_size: 2
type: QueueDataset # or DataLoader
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
- name: dataset_infer # name
batch_size: 2
type: DataLoader # or QueueDataset
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
# hyper parameters of user-defined network
hyper_parameters:
# optimizer config
optimizer:
class: Adam
learning_rate: 0.001
strategy: async
# user-defined <key, value> pairs
sparse_inputs_slots: 27
sparse_feature_number: 1000001
sparse_feature_dim: 9
dense_input_dim: 13
bilinear_type: 'all'
reduction_ratio: 3
dropout_rate: 0.5
# select runner by name
mode: [single_cpu_train, single_cpu_infer]
# config of each runner.
# runner is a kind of paddle training class, which wraps the train/infer process.
runner:
- name: single_cpu_train
class: train
# num of epochs
epochs: 4
# device to run training or infer
device: cpu
save_checkpoint_interval: 2 # save model interval of epochs
save_inference_interval: 4 # save inference
save_checkpoint_path: "increment_model" # save checkpoint path
save_inference_path: "inference" # save inference path
save_inference_feed_varnames: [] # feed vars of save inference
save_inference_fetch_varnames: [] # fetch vars of save inference
init_model_path: "" # load model path
print_interval: 10
phases: [phase1]
- name: single_cpu_infer
class: infer
# num of epochs
epochs: 1
# device to run training or infer
device: cpu
init_model_path: "increment_model" # load model path
phases: [phase2]
# runner will run all the phase in each epoch
phase:
- name: phase1
model: "{workspace}/model.py" # user-defined model
dataset_name: dataloader_train # select dataset by name
thread_num: 1
- name: phase2
model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_infer # select dataset by name
thread_num: 1
wget --no-check-certificate https://fleet.bj.bcebos.com/ctr_data.tar.gz
tar -zxvf ctr_data.tar.gz
mv ./raw_data ./train_data_full
mkdir train_data && cd train_data
cp ../train_data_full/part-0 ../train_data_full/part-1 ./ && cd ..
mv ./test_data ./test_data_full
mkdir test_data && cd test_data
cp ../test_data_full/part-220 ./ && cd ..
echo "Complete data download."
echo "Full Train data stored in ./train_data_full "
echo "Full Test data stored in ./test_data_full "
echo "Rapid Verification train data stored in ./train_data "
echo "Rapid Verification test data stored in ./test_data "
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.incubate.data_generator as dg
cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
hash_dim_ = 1000001
continuous_range_ = range(1, 14)
categorical_range_ = range(14, 40)
class CriteoDataset(dg.MultiSlotDataGenerator):
"""
DacDataset: inheritance MultiSlotDataGeneratior, Implement data reading
Help document: http://wiki.baidu.com/pages/viewpage.action?pageId=728820675
"""
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.rstrip('\n').split('\t')
dense_feature = []
sparse_feature = []
for idx in continuous_range_:
if features[idx] == "":
dense_feature.append(0.0)
else:
dense_feature.append(
(float(features[idx]) - cont_min_[idx - 1]) /
cont_diff_[idx - 1])
for idx in categorical_range_:
sparse_feature.append(
[hash(str(idx) + features[idx]) % hash_dim_])
label = [int(features[0])]
process_line = dense_feature, sparse_feature, label
feature_name = ["dense_feature"]
for idx in categorical_range_:
feature_name.append("C" + str(idx - 13))
feature_name.append("label")
s = "click:" + str(label[0])
for i in dense_feature:
s += " dense_feature:" + str(i)
for i in range(1, 1 + len(categorical_range_)):
s += " " + str(i) + ":" + str(sparse_feature[i - 1][0])
print(s.strip())
yield None
return reader
d = CriteoDataset()
d.run_from_stdin()
sh download.sh
mkdir slot_train_data_full
for i in `ls ./train_data_full`
do
cat train_data_full/$i | python get_slot_data.py > slot_train_data_full/$i
done
mkdir slot_test_data_full
for i in `ls ./test_data_full`
do
cat test_data_full/$i | python get_slot_data.py > slot_test_data_full/$i
done
mkdir slot_train_data
for i in `ls ./train_data`
do
cat train_data/$i | python get_slot_data.py > slot_train_data/$i
done
mkdir slot_test_data
for i in `ls ./test_data`
do
cat test_data/$i | python get_slot_data.py > slot_test_data/$i
done
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import itertools
from paddlerec.core.utils import envs
from paddlerec.core.model import ModelBase
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def _init_hyper_parameters(self):
self.is_distributed = True if envs.get_fleet_mode().upper(
) == "PSLIB" else False
self.sparse_feature_number = envs.get_global_env(
"hyper_parameters.sparse_feature_number")
self.sparse_feature_dim = envs.get_global_env(
"hyper_parameters.sparse_feature_dim")
self.learning_rate = envs.get_global_env(
"hyper_parameters.optimizer.learning_rate")
def _SENETLayer(self, inputs, filed_size, reduction_ratio=3):
reduction_size = max(1, filed_size // reduction_ratio)
Z = fluid.layers.reduce_mean(inputs, dim=-1)
A_1 = fluid.layers.fc(
input=Z,
size=reduction_size,
param_attr=fluid.initializer.Xavier(uniform=False),
act='relu',
name='W_1')
A_2 = fluid.layers.fc(
input=A_1,
size=filed_size,
param_attr=fluid.initializer.Xavier(uniform=False),
act='relu',
name='W_2')
V = fluid.layers.elementwise_mul(
inputs, y=fluid.layers.unsqueeze(
input=A_2, axes=[2]))
return fluid.layers.split(V, num_or_sections=filed_size, dim=1)
def _BilinearInteraction(self,
inputs,
filed_size,
embedding_size,
bilinear_type="interaction"):
if bilinear_type == "all":
p = [
fluid.layers.elementwise_mul(
fluid.layers.fc(
input=v_i,
size=embedding_size,
param_attr=fluid.initializer.Xavier(uniform=False),
act=None,
name=None),
fluid.layers.squeeze(
input=v_j, axes=[1]))
for v_i, v_j in itertools.combinations(inputs, 2)
]
else:
raise NotImplementedError
return fluid.layers.concat(input=p, axis=1)
def _DNNLayer(self, inputs, dropout_rate=0.5):
deep_input = inputs
for i, hidden_unit in enumerate([400, 400, 400]):
fc_out = fluid.layers.fc(
input=deep_input,
size=hidden_unit,
param_attr=fluid.initializer.Xavier(uniform=False),
act='relu',
name='d_' + str(i))
fc_out = fluid.layers.dropout(fc_out, dropout_prob=dropout_rate)
deep_input = fc_out
return deep_input
def net(self, input, is_infer=False):
self.sparse_inputs = self._sparse_data_var[1:]
self.dense_input = self._dense_data_var[0]
self.label_input = self._sparse_data_var[0]
emb = []
for data in self.sparse_inputs:
feat_emb = fluid.embedding(
input=data,
size=[self.sparse_feature_number, self.sparse_feature_dim],
param_attr=fluid.ParamAttr(
name='dis_emb',
learning_rate=5,
initializer=fluid.initializer.Xavier(
fan_in=self.sparse_feature_dim,
fan_out=self.sparse_feature_dim)),
is_sparse=True)
emb.append(feat_emb)
concat_emb = fluid.layers.concat(emb, axis=1)
filed_size = len(self.sparse_inputs)
bilinear_type = envs.get_global_env("hyper_parameters.bilinear_type")
reduction_ratio = envs.get_global_env(
"hyper_parameters.reduction_ratio")
dropout_rate = envs.get_global_env("hyper_parameters.dropout_rate")
senet_output = self._SENETLayer(concat_emb, filed_size,
reduction_ratio)
senet_bilinear_out = self._BilinearInteraction(
senet_output, filed_size, self.sparse_feature_dim, bilinear_type)
concat_emb = fluid.layers.split(
concat_emb, num_or_sections=filed_size, dim=1)
bilinear_out = self._BilinearInteraction(
concat_emb, filed_size, self.sparse_feature_dim, bilinear_type)
dnn_input = fluid.layers.concat(
input=[senet_bilinear_out, bilinear_out, self.dense_input], axis=1)
dnn_output = self._DNNLayer(dnn_input, dropout_rate)
y_pred = fluid.layers.fc(
input=dnn_output,
size=1,
param_attr=fluid.initializer.Xavier(uniform=False),
act='sigmoid',
name='logit')
self.predict = y_pred
auc, batch_auc, _ = fluid.layers.auc(input=self.predict,
label=self.label_input,
num_thresholds=2**12,
slide_steps=20)
if is_infer:
self._infer_results["AUC"] = auc
self._infer_results["BATCH_AUC"] = batch_auc
return
self._metrics["AUC"] = auc
self._metrics["BATCH_AUC"] = batch_auc
cost = fluid.layers.log_loss(
input=self.predict,
label=fluid.layers.cast(
x=self.label_input, dtype='float32'))
avg_cost = fluid.layers.reduce_mean(cost)
self._cost = avg_cost
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册