diff --git a/PaddleRec/ctr/fibinet/README.md b/PaddleRec/ctr/fibinet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2519bab968152a1f0a20a8664749cb758b9f9492
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/README.md
@@ -0,0 +1,134 @@
+# FiBiNET
+
+The following is a brief directory structure and description for this example:
+
+```
+├── README.md          # Documentation
+├── requirements.txt   # Required packages
+├── net.py             # FiBiNET network definition
+├── feed_generator.py  # Data reader
+├── args.py            # Argument parsing
+├── get_data.sh        # Script that downloads and prepares the training data
+├── train.py           # Training entry point
+├── infer.py           # Inference entry point
+├── train_gpu.sh       # Shell script for GPU training
+├── train_cpu.sh       # Shell script for CPU training
+├── infer_gpu.sh       # Shell script for GPU inference
+├── infer_cpu.sh       # Shell script for CPU inference
+```
+
+## Introduction
+
+[FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) is a RecSys '19 paper from the Sina Weibo machine-learning team. It observes that much of the prior work on feature-interaction-based CTR prediction computes cross features with the inner product or Hadamard product of feature vectors, which ignores how important each feature itself is. The paper proposes a Squeeze-and-Excitation network (SENET) to dynamically learn feature importance, together with a bilinear function to model cross features more effectively.
+This project implements the FiBiNET network in PaddlePaddle and validates the model on the open Criteo dataset.
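+
+The two ideas fit in a few lines. Below is a minimal NumPy sketch of the SENET re-weighting and the field-all bilinear interaction, assuming `E` is a `[field_num, emb_dim]` matrix of field embeddings for one sample; the names, shapes, and random weights are illustrative stand-ins, not the project API:
+
+```python
+import itertools
+import numpy as np
+
+def senet_bilinear_sketch(E, reduction_ratio=3):
+    field_num, emb_dim = E.shape
+    # SENET "squeeze": one summary statistic per field (mean pooling).
+    z = E.mean(axis=1)                                   # [field_num]
+    # SENET "excitation": two small FC layers yield per-field weights
+    # (random matrices stand in for learned parameters here).
+    r = max(1, field_num // reduction_ratio)
+    w1 = np.random.randn(field_num, r)
+    w2 = np.random.randn(r, field_num)
+    a = np.maximum(0, np.maximum(0, z @ w1) @ w2)        # [field_num]
+    # Re-weight each field embedding by its learned importance.
+    V = E * a[:, None]                                   # [field_num, emb_dim]
+    # Field-all bilinear interaction: one W shared by every field pair.
+    W = np.random.randn(emb_dim, emb_dim)
+    return np.stack([(V[i] @ W) * V[j]
+                     for i, j in itertools.combinations(range(field_num), 2)])
+```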
+
+## Data Download and Preprocessing
+
+Dataset: [Criteo](https://fleet.bj.bcebos.com/ctr_data.tar.gz)
+
+(1) The original training set is split 9:1 into a training set and a validation set.
+
+(2) The numerical (continuous) features are min-max normalized.
+
+## Environment
+
+PaddlePaddle 1.7.0
+
+python3.7
+
+## Single-Machine Training
+
+GPU environment
+
+Set the data paths and hyperparameters in train_gpu.sh:
+
+```sh
+# --use_gpu 1: train on GPU; --train_files_path: full training data;
+# --model_dir: where checkpoints are stored; --reduction_ratio: SENET hyperparameter.
+CUDA_VISIBLE_DEVICES=0 python train.py --use_gpu 1 \
+                       --train_files_path ./train_data_full \
+                       --model_dir ./model_dir \
+                       --learning_rate 0.001 \
+                       --batch_size 1000 \
+                       --epochs 10 \
+                       --reduction_ratio 3 \
+                       --dropout_rate 0.5 \
+                       --embedding_size 10
+```
+
+Make the script executable and run it:
+
+```
+./train_gpu.sh
+```
+
+CPU environment
+
+Set the data paths and hyperparameters in train_cpu.sh:
+
+```sh
+# Same flags as the GPU script; --use_gpu 0 runs on CPU.
+python train.py --use_gpu 0 \
+       --train_files_path ./train_data_full \
+       --model_dir ./model_dir \
+       --learning_rate 0.001 \
+       --batch_size 1000 \
+       --epochs 10 \
+       --reduction_ratio 3 \
+       --dropout_rate 0.5 \
+       --embedding_size 10
+```
+
+Make the script executable and run it:
+
+```
+./train_cpu.sh
+```
+
+## Single-Machine Inference
+
+GPU environment
+
+Set the data paths and parameters in infer_gpu.sh:
+
+```sh
+# --test_epoch: which epoch's saved parameters to evaluate.
+CUDA_VISIBLE_DEVICES=0 python infer.py --use_gpu 1 \
+                       --test_files_path ./test_data_full \
+                       --model_dir ./model_dir \
+                       --test_epoch 10
+```
+
+Make the script executable and run it:
+
+```
+./infer_gpu.sh
+```
+
+CPU environment
+
+Set the data paths and parameters in infer_cpu.sh:
+
+```sh
+python infer.py --use_gpu 0 \
+       --test_files_path ./test_data_full \
+       --model_dir ./model_dir \
+       --test_epoch 10
+```
+
+Make the script executable and run it:
+
+```
+./infer_cpu.sh
+```
+
+## Results
+
+Training:
+
+```
+2020-06-10 23:34:45,195-INFO: epoch_id: 0, batch_id: 33952, batch_time: 1.26086s, loss: 0.44914, auc: 0.79089
+2020-06-10 23:34:46,369-INFO: epoch_id: 0, batch_id: 33953, batch_time: 1.17280s, loss: 0.46410, auc: 0.79089
+2020-06-10 23:34:47,413-INFO: epoch_id: 0, batch_id: 33954, batch_time: 1.04139s, loss: 0.43496, auc: 0.79089
+2020-06-10 23:34:48,248-INFO: epoch_id: 0, batch_id: 33955, batch_time: 0.83510s, loss: 0.45980, auc: 0.79089
+2020-06-10 23:34:49,379-INFO: epoch_id: 0, batch_id: 33956, batch_time: 1.13043s, loss: 0.46738, auc: 0.79089
+2020-06-10 23:34:50,392-INFO: epoch_id: 0, batch_id: 33957, batch_time: 1.01046s, loss: 0.46724, auc: 0.79089
+2020-06-10 23:34:51,440-INFO: epoch_id: 0, batch_id: 33958, batch_time: 1.04752s, loss: 0.44079, auc: 0.79089
+```
+
diff --git a/PaddleRec/ctr/fibinet/args.py b/PaddleRec/ctr/fibinet/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9f4dc1bdd6b1416a629cd34d61796d0e4420cb7
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/args.py
@@ -0,0 +1,92 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+
+def parse_args():
+    parser = argparse.ArgumentParser(description=__doc__)
+    # -------------Data & Model Path-------------
+    parser.add_argument(
+        '--train_files_path',
+        type=str,
+        default='./train_data_full',
+        help="The path of the training dataset")
+    parser.add_argument(
+        '--test_files_path',
+        type=str,
+        default='./test_data_full',
+        help="The path of the testing dataset")
+    parser.add_argument(
+        '--model_dir',
+        type=str,
+        default='./model_dir',
+        help='The path where models are stored (default: ./model_dir)')
+    parser.add_argument(
+        '--test_epoch',
+        type=str,
+        default='10',
+        help="The epoch whose checkpoint is loaded for inference")
+
+    # -------------Training parameter-------------
+    parser.add_argument(
+        '--learning_rate',
+        type=float,
+        default=0.001,
+        help="Initial learning rate for training")
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=1000,
+        help="The size of a mini-batch (default: 1000)")
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=1,
+        help="Number of epochs for training")
+    parser.add_argument(
+        "--reduction_ratio",
+        type=int,
+        default=3,
+        help="Reduction ratio of the SENET layer")
+    parser.add_argument(
+        '--bilinear_type',
+        type=str,
+        default='all',
+        help="Bilinear interaction type; only 'all' (one shared weight matrix) is implemented")
+    parser.add_argument(
+        "--dropout_rate",
+        type=float,
+        default=0.5,
+        help="Dropout rate of the DNN layers")
+
+    # -------------Network parameter-------------
+    parser.add_argument(
+        '--embedding_size',
+        type=int,
+        default=10,
+        help="The size of the embedding layer (default: 10)")
+    parser.add_argument(
+        '--sparse_feature_dim',
+        type=int,
+        default=1000001,
+        help='Sparse feature hashing space for index processing')
+    parser.add_argument(
+        '--dense_feature_dim',
+        type=int,
+        default=13,
+        help='Dense feature shape')
+
+    # -------------Device parameter-------------
+    parser.add_argument(
+        '--use_gpu',
+        type=int,
+        default=0,
+        help='Whether to use GPU (0: CPU, 1: GPU)')
+
+    return parser.parse_args()
diff --git a/PaddleRec/ctr/fibinet/feed_generator.py b/PaddleRec/ctr/fibinet/feed_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3d7945d551b6acfc4c16246120f6dd7356725b5
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/feed_generator.py
@@ -0,0 +1,52 @@
+class CriteoDataset(object):
+    def __init__(self, sparse_feature_dim):
+        # Per-feature min and (max - min) ranges used to min-max normalize the
+        # 13 continuous features; the second range is 603 because that
+        # feature's minimum is -3 and its clip maximum is 600.
+        self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        self.cont_max_ = [
+            20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50
+        ]
+        self.cont_diff_ = [
+            20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50
+        ]
+        self.hash_dim_ = sparse_feature_dim
+        # Columns 1-13 are continuous features, columns 14-39 categorical.
+        self.continuous_range_ = range(1, 14)
+        self.categorical_range_ = range(14, 40)
+
+    def _reader_creator(self, file_list, is_train):
+        # is_train is kept for API symmetry; train and test parse identically.
+        def reader():
+            for file in file_list:
+                with open(file, 'r') as f:
+                    for line in f:
+                        features = line.rstrip('\n').split('\t')
+                        dense_feature = []
+                        sparse_feature = []
+                        for idx in self.continuous_range_:
+                            if features[idx] == '':
+                                dense_feature.append(0.0)
+                            else:
+                                # Min-max normalize each continuous feature.
+                                dense_feature.append(
+                                    (float(features[idx]) -
+                                     self.cont_min_[idx - 1]) /
+                                    self.cont_diff_[idx - 1])
+                        for idx in self.categorical_range_:
+                            # Hash "<column index><raw value>" into the sparse
+                            # embedding space.
+                            sparse_feature.append([
+                                hash(str(idx) + features[idx]) % self.hash_dim_
+                            ])
+
+                        label = [int(features[0])]
+                        yield [dense_feature] + sparse_feature + [label]
+
+        return reader
+
+    def train(self, file_list):
+        return self._reader_creator(file_list, True)
+
+    def test(self, file_list):
+        return self._reader_creator(file_list, False)
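+
+if __name__ == "__main__":
+    # Minimal self-check of the reader -- a sketch assuming get_data.sh has
+    # produced the small sample at ./train_data/part-0 (path is illustrative).
+    dataset = CriteoDataset(sparse_feature_dim=1000001)
+    sample = next(dataset.train(["./train_data/part-0"])())
+    # A sample is [dense_feature] + 26 one-element sparse id lists + [label].
+    print("dense[:3]:", sample[0][:3],
+          "sparse[:3]:", sample[1:4],
+          "label:", sample[-1])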
diff --git a/PaddleRec/ctr/fibinet/get_data.sh b/PaddleRec/ctr/fibinet/get_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c6a5ff7e63fa0e4a2cc29d43680039bd384445e9
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/get_data.sh
@@ -0,0 +1,13 @@
+wget --no-check-certificate https://fleet.bj.bcebos.com/ctr_data.tar.gz
+tar -zxvf ctr_data.tar.gz
+mv ./raw_data ./train_data_full
+mkdir train_data && cd train_data
+cp ../train_data_full/part-0 ../train_data_full/part-1 ./ && cd ..
+mv ./test_data ./test_data_full
+mkdir test_data && cd test_data
+cp ../test_data_full/part-220 ./ && cd ..
+echo "Complete data download."
+echo "Full train data stored in ./train_data_full"
+echo "Full test data stored in ./test_data_full"
+echo "Rapid-verification train data stored in ./train_data"
+echo "Rapid-verification test data stored in ./test_data"
diff --git a/PaddleRec/ctr/fibinet/infer.py b/PaddleRec/ctr/fibinet/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d08d1ba661daa26e1ff44643238bede82868f247
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/infer.py
@@ -0,0 +1,69 @@
+import logging
+import os
+import time
+
+import numpy as np
+import paddle.fluid as fluid
+
+import args
+import feed_generator as generator
+from net import Fibinet
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+
+def set_zero(var_name, scope=fluid.global_scope(), place=fluid.CPUPlace(), param_type="int64"):
+    """
+    Set the tensor of a Variable to zero.
+
+    Args:
+        var_name(str): name of the Variable
+        scope(Scope): Scope object, default is fluid.global_scope()
+        place(Place): Place object, default is fluid.CPUPlace()
+        param_type(str): parameter data type, default is int64
+    """
+    param = scope.var(var_name).get_tensor()
+    param_array = np.zeros(param._get_dims()).astype(param_type)
+    param.set(param_array, place)
+
+def run_infer(args):
+    fibinet_model = Fibinet()
+    data_generator = generator.CriteoDataset(args.sparse_feature_dim)
+
+    file_list = [os.path.join(args.test_files_path, x) for x in os.listdir(args.test_files_path)]
+    test_reader = fluid.io.batch(data_generator.test(file_list), batch_size=args.batch_size)
+
+    inference_scope = fluid.Scope()
+    startup_program = fluid.framework.Program()
+    test_program = fluid.framework.Program()
+
+    cur_model_path = os.path.join(args.model_dir, 'epoch_' + str(args.test_epoch), "checkpoint")
+
+    with fluid.scope_guard(inference_scope):
+        with fluid.framework.program_guard(test_program, startup_program):
+            inputs = fibinet_model.input_data(args.dense_feature_dim)
+            place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+            avg_cost, auc_val, batch_auc, auc_states = fibinet_model.net(inputs, args.sparse_feature_dim, args.embedding_size,
+                                                                         args.reduction_ratio, args.bilinear_type, args.dropout_rate)
+            # Clone the program in test mode so dropout is disabled at inference.
+            test_program = test_program.clone(for_test=True)
+            exe = fluid.Executor(place)
+            fluid.load(test_program, cur_model_path, exe)
+            loader = fluid.io.DataLoader.from_generator(feed_list=inputs, capacity=args.batch_size, iterable=True)
+            loader.set_sample_list_generator(test_reader, places=place)
+
+            for var in auc_states:  # reset AUC states accumulated during training
+                set_zero(var.name, scope=inference_scope, place=place)
+
+            for batch_id, data in enumerate(loader()):
+                begin = time.time()
+                auc = exe.run(program=test_program,
+                              feed=data,
+                              fetch_list=[auc_val.name],
+                              return_numpy=True)
+                end = time.time()
+                logger.info("batch_id: {}, batch_time: {:.5f}s, auc: {:.5f}".format(
+                    batch_id, end - begin, np.array(auc)[0]))
+
+
+if __name__ == "__main__":
+    args = args.parse_args()
+    logger.info("use_gpu: {}, test_files_path: {}, model_dir: {}, test_epoch: {}".format(
+        args.use_gpu, args.test_files_path, args.model_dir, args.test_epoch))
+    run_infer(args)
diff --git a/PaddleRec/ctr/fibinet/infer_cpu.sh b/PaddleRec/ctr/fibinet/infer_cpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8b540bf89d79a47f55ff90a1c02eb38a2fbf9425
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/infer_cpu.sh
@@ -0,0 +1 @@
+python infer.py --use_gpu 0 --test_files_path ./test_data_full --model_dir ./model_dir --test_epoch 10
diff --git a/PaddleRec/ctr/fibinet/infer_gpu.sh b/PaddleRec/ctr/fibinet/infer_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d0a7a3d7c5a821356fd54f96fa205d2aaf4d321
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/infer_gpu.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python infer.py --use_gpu 1 --test_files_path ./test_data_full --model_dir ./model_dir --test_epoch 10
diff --git a/PaddleRec/ctr/fibinet/net.py b/PaddleRec/ctr/fibinet/net.py
new file mode 100644
index 0000000000000000000000000000000000000000..f540019d2470ae500952b4eea44f9247eb261069
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/net.py
@@ -0,0 +1,86 @@
+import itertools
+
+import paddle.fluid as fluid
+
+class Fibinet(object):
+    def input_data(self, dense_feature_dim):
+        dense_input = fluid.data(name="dense_input",
+                                 shape=[-1, dense_feature_dim],
+                                 dtype="float32")
+
+        sparse_input_ids = [
+            fluid.data(name="C" + str(i),
+                       shape=[-1, 1],
+                       lod_level=1,
+                       dtype="int64") for i in range(1, 27)
+        ]
+
+        label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
+
+        inputs = [dense_input] + sparse_input_ids + [label]
+        return inputs
+
+    def fc(self, data, size, active, tag):
+        output = fluid.layers.fc(input=data,
+                                 size=size,
+                                 param_attr=fluid.initializer.Xavier(uniform=False),
+                                 act=active,
+                                 name=tag)
+        return output
+
+    def SENETLayer(self, inputs, field_size, reduction_ratio=3):
+        # Squeeze: mean-pool each field embedding to a single statistic.
+        reduction_size = max(1, field_size // reduction_ratio)
+        Z = fluid.layers.reduce_mean(inputs, dim=-1)
+
+        # Excitation: two FC layers produce one importance weight per field.
+        A_1 = self.fc(Z, reduction_size, 'relu', 'W_1')
+        A_2 = self.fc(A_1, field_size, 'relu', 'W_2')
+
+        # Re-weight: scale each field embedding by its learned importance.
+        V = fluid.layers.elementwise_mul(inputs, y=fluid.layers.unsqueeze(input=A_2, axes=[2]))
+
+        return fluid.layers.split(V, num_or_sections=field_size, dim=1)
+
+    def BilinearInteraction(self, inputs, field_size, embedding_size, bilinear_type="all", tag=""):
+        # "all" (the paper's Field-All type): one weight matrix is shared
+        # across every field pair within a branch; tag keeps the SENET and raw
+        # branches from sharing parameters with each other.
+        if bilinear_type == "all":
+            shared_w = fluid.ParamAttr(
+                name="bilinear_all_w_" + tag,
+                initializer=fluid.initializer.Xavier(uniform=False))
+            p = []
+            for v_i, v_j in itertools.combinations(inputs, 2):
+                v_i_w = fluid.layers.fc(input=v_i,
+                                        size=embedding_size,
+                                        param_attr=shared_w,
+                                        bias_attr=False)
+                p.append(fluid.layers.elementwise_mul(
+                    v_i_w, fluid.layers.squeeze(input=v_j, axes=[1])))
+        else:
+            raise NotImplementedError("only bilinear_type='all' is implemented")
+
+        return fluid.layers.concat(input=p, axis=1)
+
+    def DNNLayer(self, inputs, dropout_rate=0.5):
+        deep_input = inputs
+        for i, hidden_unit in enumerate([400, 400, 400]):
+            fc_out = self.fc(deep_input, hidden_unit, 'relu', 'd_' + str(i))
+            fc_out = fluid.layers.dropout(fc_out, dropout_prob=dropout_rate)
+            deep_input = fc_out
+
+        return deep_input
+
+    def net(self, inputs, sparse_feature_dim, embedding_size, reduction_ratio, bilinear_type, dropout_rate=0.5):
+        field_size = len(inputs[1:-1])
+
+        # Shared embedding table for the 26 categorical fields.
+        emb = []
+        for data in inputs[1:-1]:
+            feat_emb = fluid.embedding(input=data,
+                                       size=[sparse_feature_dim, embedding_size],
+                                       param_attr=fluid.ParamAttr(
+                                           name='dis_emb',
+                                           learning_rate=5,
+                                           initializer=fluid.initializer.Xavier(fan_in=embedding_size, fan_out=embedding_size)),
+                                       is_sparse=True)
+            emb.append(feat_emb)
+        concat_emb = fluid.layers.concat(emb, axis=1)
+
+        # Branch 1: SENET-reweighted embeddings -> bilinear interaction.
+        senet_output = self.SENETLayer(concat_emb, field_size, reduction_ratio)
+        senet_bilinear_out = self.BilinearInteraction(senet_output, field_size, embedding_size, bilinear_type, tag='senet')
+
+        # Branch 2: raw embeddings -> bilinear interaction.
+        concat_emb = fluid.layers.split(concat_emb, num_or_sections=field_size, dim=1)
+        bilinear_out = self.BilinearInteraction(concat_emb, field_size, embedding_size, bilinear_type, tag='raw')
+
+        # Concatenate both branches with the dense features and feed the DNN.
+        dnn_input = fluid.layers.concat(input=[senet_bilinear_out, bilinear_out, inputs[0]], axis=1)
+        dnn_output = self.DNNLayer(dnn_input, dropout_rate)
+
+        label = inputs[-1]
+        y_pred = self.fc(dnn_output, 1, 'sigmoid', 'logit')
+        cost = fluid.layers.log_loss(input=y_pred, label=fluid.layers.cast(x=label, dtype='float32'))
+        avg_cost = fluid.layers.mean(cost)
+        auc_val, batch_auc, auc_states = fluid.layers.auc(input=y_pred, label=label)
+
+        return avg_cost, auc_val, batch_auc, auc_states
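+
+# Rough shape bookkeeping for the default configuration (26 sparse fields,
+# embedding_size=10, dense_feature_dim=13); illustrative arithmetic only:
+# each bilinear branch concatenates C(26, 2) = 325 pairwise products of
+# size 10, i.e. 3250 dims per branch, so the DNN input has
+# 3250 (SENET branch) + 3250 (raw branch) + 13 (dense) = 6513 dims.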
diff --git a/PaddleRec/ctr/fibinet/train.py b/PaddleRec/ctr/fibinet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dddb00d2ad3c66341a446acd20eb5d1af92ce93
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/train.py
@@ -0,0 +1,70 @@
+import logging
+import os
+import time
+
+import numpy as np
+import paddle.fluid as fluid
+
+import args
+import feed_generator as generator
+from net import Fibinet
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+
+def train(args):
+    fibinet_model = Fibinet()
+    inputs = fibinet_model.input_data(args.dense_feature_dim)
+    data_generator = generator.CriteoDataset(args.sparse_feature_dim)
+
+    file_list = [os.path.join(args.train_files_path, x) for x in os.listdir(args.train_files_path)]
+    train_reader = fluid.io.batch(data_generator.train(file_list), batch_size=args.batch_size)
+
+    avg_cost, auc_val, batch_auc, auc_states = fibinet_model.net(inputs, args.sparse_feature_dim, args.embedding_size,
+                                                                 args.reduction_ratio, args.bilinear_type, args.dropout_rate)
+
+    optimizer = fluid.optimizer.Adam(args.learning_rate)
+    optimizer.minimize(avg_cost)
+
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
+    loader = fluid.io.DataLoader.from_generator(feed_list=inputs, capacity=args.batch_size, iterable=True)
+    loader.set_sample_list_generator(train_reader, places=place)
+
+    for epoch in range(args.epochs):
+        for batch_id, data in enumerate(loader()):
+            begin = time.time()
+            loss_data, auc = exe.run(program=fluid.default_main_program(),
+                                     feed=data,
+                                     fetch_list=[avg_cost.name, auc_val.name],
+                                     return_numpy=True)
+            end = time.time()
+            logger.info("epoch_id: {}, batch_id: {}, batch_time: {:.5f}s, loss: {:.5f}, auc: {:.5f}".format(
+                epoch, batch_id, end - begin, float(np.array(loss_data)), np.array(auc)[0]))
+
+        # Save a checkpoint per epoch; infer.py loads from the same layout.
+        model_dir = os.path.join(args.model_dir, 'epoch_' + str(epoch + 1), "checkpoint")
+        main_program = fluid.default_main_program()
+        fluid.io.save(main_program, model_dir)
+
+if __name__ == '__main__':
+    args = args.parse_args()
+    logger.info("use_gpu: {}, train_files_path: {}, model_dir: {}, learning_rate: {}, batch_size: {}, epochs: {}, reduction_ratio: {}, dropout_rate: {}, embedding_size: {}".format(
+        args.use_gpu, args.train_files_path, args.model_dir, args.learning_rate, args.batch_size, args.epochs, args.reduction_ratio, args.dropout_rate, args.embedding_size))
+
+    train(args)
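+
+# Quick sanity run on the small sample produced by get_data.sh (assumes
+# ./train_data exists; the path and epoch count here are just an example):
+#   python train.py --use_gpu 0 --train_files_path ./train_data --epochs 1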
diff --git a/PaddleRec/ctr/fibinet/train_cpu.sh b/PaddleRec/ctr/fibinet/train_cpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ee938090caad44a26671f31c627d3882f2ac55f9
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/train_cpu.sh
@@ -0,0 +1 @@
+python train.py --use_gpu 0 --train_files_path ./train_data_full --model_dir ./model_dir --learning_rate 0.001 --batch_size 1000 --epochs 10 --reduction_ratio 3 --dropout_rate 0.5 --embedding_size 10
diff --git a/PaddleRec/ctr/fibinet/train_gpu.sh b/PaddleRec/ctr/fibinet/train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..25b807499b41871db594d1defb3c2eaab01ed46d
--- /dev/null
+++ b/PaddleRec/ctr/fibinet/train_gpu.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python train.py --use_gpu 1 --train_files_path ./train_data_full --model_dir ./model_dir --learning_rate 0.001 --batch_size 1000 --epochs 10 --reduction_ratio 3 --dropout_rate 0.5 --embedding_size 10