diff --git a/deep_fm/README.md b/deep_fm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..588ee091db8f71e5b0018f8b61f81b3dc9f4e9b9
--- /dev/null
+++ b/deep_fm/README.md
@@ -0,0 +1,5 @@
+# DeepFM: Click-Through Rate Prediction with Deep Factorization Machines
+
+## Introduction
+
+[TBD]
diff --git a/deep_fm/data/download.sh b/deep_fm/data/download.sh
new file mode 100755
index 0000000000000000000000000000000000000000..1cadfe5a3ef5d266c20d0af3d99398f8a6057d16
--- /dev/null
+++ b/deep_fm/data/download.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+wget https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
+tar zxf dac.tar.gz
+rm -f dac.tar.gz
diff --git a/deep_fm/images/DeepFM.png b/deep_fm/images/DeepFM.png
new file mode 100644
index 0000000000000000000000000000000000000000..31444dbf4db65846209380a3e1eebe49cd1e6a73
Binary files /dev/null and b/deep_fm/images/DeepFM.png differ
diff --git a/deep_fm/images/FM.png b/deep_fm/images/FM.png
new file mode 100644
index 0000000000000000000000000000000000000000..469d636a07c41de68e4dc06513ccb4a5c1a898a3
Binary files /dev/null and b/deep_fm/images/FM.png differ
diff --git a/deep_fm/infer.py b/deep_fm/infer.py
new file mode 100755
index 0000000000000000000000000000000000000000..63c096cdb4dcf370eb256c68f36b678397b4b3fd
--- /dev/null
+++ b/deep_fm/infer.py
@@ -0,0 +1,62 @@
+import gzip
+import argparse
+import itertools
+
+import paddle.v2 as paddle
+
+from network_conf import DeepFM
+import reader
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
+    parser.add_argument(
+        '--model_gz_path',
+        type=str,
+        required=True,
+        help="path of model parameters gz file")
+    parser.add_argument(
+        '--data_path',
+        type=str,
+        required=True,
+        help="path of the dataset to infer")
+    parser.add_argument(
+        '--prediction_output_path',
+        type=str,
+        required=True,
+        help="path to output the prediction")
+    parser.add_argument(
+        '--factor_size',
+        type=int,
+        default=10,
+        help="the factor size for the factorization machine (default:10)")
+
+    return parser.parse_args()
+
+
+def infer():
+    args = parse_args()
+
+    paddle.init(use_gpu=False, trainer_count=1)
+
+    model = DeepFM(args.factor_size, infer=True)
+
+    parameters = paddle.parameters.Parameters.from_tar(
+        gzip.open(args.model_gz_path, 'r'))
+
+    inferer = paddle.inference.Inference(
+        output_layer=model, parameters=parameters)
+
+    dataset = reader.Dataset()
+
+    infer_reader = paddle.batch(dataset.infer(args.data_path), batch_size=1000)
+
+    with open(args.prediction_output_path, 'w') as out:
+        for batch in infer_reader():
+            res = inferer.infer(input=batch)
+            predictions = [x for x in itertools.chain.from_iterable(res)]
+            out.write('\n'.join(map(str, predictions)) + '\n')
+
+
+if __name__ == '__main__':
+    infer()
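Reviewer note on infer.py: `inferer.infer` returns one row per sample, and with a single sigmoid output each row holds exactly one CTR estimate, so the `itertools.chain.from_iterable` call just flattens the batch into one probability per output line. A minimal sketch of that step, with made-up values:

```python
import itertools

# Toy stand-in for the (batch_size, 1) prediction matrix returned by
# inferer.infer; real values come from the sigmoid output layer.
res = [[0.12], [0.87], [0.45]]

# chain.from_iterable concatenates the rows into a flat stream of
# probabilities, which infer.py writes one per line.
predictions = [x for x in itertools.chain.from_iterable(res)]
assert predictions == [0.12, 0.87, 0.45]
```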
diff --git a/deep_fm/network_conf.py b/deep_fm/network_conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..2382c0cdf13dbe12401a3279f95616523e0dd379
--- /dev/null
+++ b/deep_fm/network_conf.py
@@ -0,0 +1,73 @@
+import paddle.v2 as paddle
+
+dense_feature_dim = 13
+sparse_feature_dim = 117568
+
+
+def fm_layer(input, factor_size, fm_param_attr):
+    first_order = paddle.layer.fc(
+        input=input, size=1, act=paddle.activation.Linear())
+    second_order = paddle.layer.factorization_machine(
+        input=input,
+        factor_size=factor_size,
+        act=paddle.activation.Linear(),
+        param_attr=fm_param_attr)
+    out = paddle.layer.addto(
+        input=[first_order, second_order],
+        act=paddle.activation.Sigmoid(),
+        bias_attr=False)
+    return out
+
+
+def DeepFM(factor_size, infer=False):
+    dense_input = paddle.layer.data(
+        name="dense_input",
+        type=paddle.data_type.dense_vector(dense_feature_dim))
+    sparse_input = paddle.layer.data(
+        name="sparse_input",
+        type=paddle.data_type.sparse_binary_vector(sparse_feature_dim))
+    sparse_input_ids = [
+        paddle.layer.data(
+            name="C" + str(i),
+            type=paddle.data_type.integer_value(sparse_feature_dim))
+        for i in range(1, 27)
+    ]
+
+    dense_fm = fm_layer(
+        dense_input,
+        factor_size,
+        fm_param_attr=paddle.attr.Param(name="DenseFeatFactors"))
+    sparse_fm = fm_layer(
+        sparse_input,
+        factor_size,
+        fm_param_attr=paddle.attr.Param(name="SparseFeatFactors"))
+
+    def embedding_layer(input):
+        return paddle.layer.embedding(
+            input=input,
+            size=factor_size,
+            param_attr=paddle.attr.Param(name="SparseFeatFactors"))
+
+    sparse_embed_seq = map(embedding_layer, sparse_input_ids)
+    sparse_embed = paddle.layer.concat(sparse_embed_seq)
+
+    fc1 = paddle.layer.fc(
+        input=[sparse_embed, dense_input],
+        size=400,
+        act=paddle.activation.Relu())
+    fc2 = paddle.layer.fc(input=fc1, size=400, act=paddle.activation.Relu())
+    fc3 = paddle.layer.fc(input=fc2, size=400, act=paddle.activation.Relu())
+
+    predict = paddle.layer.fc(
+        input=[dense_fm, sparse_fm, fc3],
+        size=1,
+        act=paddle.activation.Sigmoid())
+
+    if not infer:
+        label = paddle.layer.data(
+            name="label", type=paddle.data_type.dense_vector(1))
+        cost = paddle.layer.multi_binary_label_cross_entropy_cost(
+            input=predict, label=label)
+        return cost
+    else:
+        return predict
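Reviewer note on network_conf.py: `fm_layer` adds a first-order linear term to the second-order interactions computed by `paddle.layer.factorization_machine`, and the DNN part reuses the `SparseFeatFactors` parameter as its embedding table, which is what makes this DeepFM rather than two separate models. For reference, a NumPy sketch of the standard FM second-order identity that factorization machine layers are built on (names are illustrative, not Paddle API):

```python
import numpy as np


def fm_second_order(x, v):
    """Second-order FM term: sum_{i<j} <v_i, v_j> * x_i * x_j.

    x: feature vector of shape (n,); v: latent factors of shape (n, k).
    Uses the O(n * k) identity 0.5 * sum_k((x.v)_k^2 - (x^2 . v^2)_k).
    """
    linear = x.dot(v)                # shape (k,)
    squared = (x ** 2).dot(v ** 2)   # shape (k,)
    return 0.5 * np.sum(linear ** 2 - squared)


# Sanity check against the naive pairwise sum on a tiny example.
rng = np.random.RandomState(0)
x, v = rng.rand(5), rng.rand(5, 3)
naive = sum(v[i].dot(v[j]) * x[i] * x[j]
            for i in range(5) for j in range(i + 1, 5))
assert np.isclose(fm_second_order(x, v), naive)
```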
+""" +import os +import sys +import click +import collections + +# There are 13 integer features and 26 categorical features +continous_features = range(1, 14) +categorial_features = range(14, 40) + + +class CategoryDictGenerator: + """ + Generate dictionary for each of the categorical features + """ + + def __init__(self, num_feature): + self.dicts = [] + self.num_feature = num_feature + for i in range(0, num_feature): + self.dicts.append(collections.defaultdict(int)) + + def build(self, datafile, categorial_features, cutoff=0): + with open(datafile, 'r') as f: + for line in f: + features = line.rstrip('\n').split('\t') + for i in range(0, self.num_feature): + if features[categorial_features[i]] != '': + self.dicts[i][features[categorial_features[i]]] += 1 + for i in range(0, self.num_feature): + self.dicts[i] = filter(lambda x: x[1] >= cutoff, + self.dicts[i].items()) + self.dicts[i] = sorted(self.dicts[i], key=lambda x: (-x[1], x[0])) + vocabs, _ = list(zip(*self.dicts[i])) + self.dicts[i] = dict(zip(vocabs, range(1, len(vocabs) + 1))) + self.dicts[i][''] = 0 + + def gen(self, idx, key): + if key not in self.dicts[idx]: + res = self.dicts[idx][''] + else: + res = self.dicts[idx][key] + return res + + def dicts_sizes(self): + return map(len, self.dicts) + + +class ContinuousFeatureGenerator: + """ + Normalize the integer features to [0, 1] by min-max normalization + """ + + def __init__(self, num_feature): + self.num_feature = num_feature + self.min = [sys.maxint] * num_feature + self.max = [-sys.maxint] * num_feature + + def build(self, datafile, continous_features): + with open(datafile, 'r') as f: + for line in f: + features = line.rstrip('\n').split('\t') + for i in range(0, self.num_feature): + val = features[continous_features[i]] + if val != '': + val = int(val) + self.min[i] = min(self.min[i], val) + self.max[i] = max(self.max[i], val) + + def gen(self, idx, val): + if val == '': + return 0 + val = float(val) + return (val - self.min[idx]) / (self.max[idx] - self.min[idx]) + + +@click.command("preprocess") +@click.option("--datadir", type=str, help="Path to raw criteo dataset") +@click.option("--outdir", type=str, help="Path to save the processed data") +def preprocess(datadir, outdir): + """ + All the 13 integer features are normalzied to continous values and these + continous features are combined into one vecotr with dimension 13. + + Each of the 26 categorical features are one-hot encoded and all the one-hot + vectors are combined into one sparse binary vector. 
+ """ + dists = ContinuousFeatureGenerator(len(continous_features)) + dists.build(os.path.join(datadir, 'train.txt'), continous_features) + + dicts = CategoryDictGenerator(len(categorial_features)) + dicts.build( + os.path.join(datadir, 'train.txt'), categorial_features, cutoff=200) + + dict_sizes = dicts.dicts_sizes() + categorial_feature_offset = [0] + for i in range(1, len(categorial_features)): + offset = categorial_feature_offset[i - 1] + dict_sizes[i - 1] + categorial_feature_offset.append(offset) + + with open(os.path.join(outdir, 'train.txt'), 'w') as out: + with open(os.path.join(datadir, 'train.txt'), 'r') as f: + for line in f: + features = line.rstrip('\n').split('\t') + + continous_vals = [] + for i in range(0, len(continous_features)): + val = dists.gen(i, features[continous_features[i]]) + continous_vals.append(str(val)) + categorial_vals = [] + for i in range(0, len(categorial_features)): + val = dicts.gen(i, features[categorial_features[ + i]]) + categorial_feature_offset[i] + categorial_vals.append(str(val)) + + continous_vals = ','.join(continous_vals) + categorial_vals = ','.join(categorial_vals) + label = features[0] + out.write('\t'.join([continous_vals, categorial_vals, label]) + + '\n') + + with open(os.path.join(outdir, 'test.txt'), 'w') as out: + with open(os.path.join(datadir, 'test.txt'), 'r') as f: + for line in f: + features = line.rstrip('\n').split('\t') + + continous_vals = [] + for i in range(0, len(continous_features)): + val = dists.gen(i, features[continous_features[i] - 1]) + continous_vals.append(str(val)) + categorial_vals = [] + for i in range(0, len(categorial_features)): + val = dicts.gen(i, + features[categorial_features[i] - + 1]) + categorial_feature_offset[i] + categorial_vals.append(str(val)) + + continous_vals = ','.join(continous_vals) + categorial_vals = ','.join(categorial_vals) + out.write('\t'.join([continous_vals, categorial_vals]) + '\n') + + +if __name__ == "__main__": + preprocess() diff --git a/deep_fm/reader.py b/deep_fm/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac30ecc533f99f6d49a5bf02cdc6b15bc3fe0f9 --- /dev/null +++ b/deep_fm/reader.py @@ -0,0 +1,55 @@ +class Dataset: + def _reader_creator(self, path, is_infer): + def reader(): + with open(path, 'r') as f: + for line in f: + features = line.rstrip('\n').split('\t') + dense_feature = map(float, features[0].split(',')) + sparse_feature = map(int, features[1].split(',')) + if not is_infer: + label = [float(features[2])] + yield [dense_feature, sparse_feature + ] + sparse_feature + [label] + else: + yield [dense_feature, sparse_feature] + sparse_feature + + return reader + + def train(self, path): + return self._reader_creator(path, False) + + def infer(self, path): + return self._reader_creator(path, True) + + +feeding = { + 'dense_input': 0, + 'sparse_input': 1, + 'C1': 2, + 'C2': 3, + 'C3': 4, + 'C4': 5, + 'C5': 6, + 'C6': 7, + 'C7': 8, + 'C8': 9, + 'C9': 10, + 'C10': 11, + 'C11': 12, + 'C12': 13, + 'C13': 14, + 'C14': 15, + 'C15': 16, + 'C16': 17, + 'C17': 18, + 'C18': 19, + 'C19': 20, + 'C20': 21, + 'C21': 22, + 'C22': 23, + 'C23': 24, + 'C24': 25, + 'C25': 26, + 'C26': 27, + 'label': 28 +} diff --git a/deep_fm/train.py b/deep_fm/train.py new file mode 100755 index 0000000000000000000000000000000000000000..2be7e7d990616cfffd4bbfce93e3687cba58c1c1 --- /dev/null +++ b/deep_fm/train.py @@ -0,0 +1,91 @@ +import os +import gzip +import logging +import argparse + +import paddle.v2 as paddle + +from network_conf import DeepFM +import reader + 
diff --git a/deep_fm/train.py b/deep_fm/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..2be7e7d990616cfffd4bbfce93e3687cba58c1c1
--- /dev/null
+++ b/deep_fm/train.py
@@ -0,0 +1,91 @@
+import os
+import gzip
+import logging
+import argparse
+
+import paddle.v2 as paddle
+
+from network_conf import DeepFM
+import reader
+
+logging.basicConfig()
+logger = logging.getLogger("paddle")
+logger.setLevel(logging.INFO)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
+    parser.add_argument(
+        '--train_data_path',
+        type=str,
+        required=True,
+        help="path of training dataset")
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=10000,
+        help="size of mini-batch (default:10000)")
+    parser.add_argument(
+        '--num_passes',
+        type=int,
+        default=10,
+        help="number of passes to train (default: 10)")
+    parser.add_argument(
+        '--factor_size',
+        type=int,
+        default=10,
+        help="the factor size for the factorization machine (default:10)")
+    parser.add_argument(
+        '--model_output_dir',
+        type=str,
+        default='models',
+        help='path for model to store (default: models)')
+
+    return parser.parse_args()
+
+
+def train():
+    args = parse_args()
+
+    if not os.path.isdir(args.model_output_dir):
+        os.mkdir(args.model_output_dir)
+
+    paddle.init(use_gpu=False, trainer_count=1)
+
+    optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
+
+    model = DeepFM(args.factor_size)
+
+    params = paddle.parameters.create(model)
+
+    trainer = paddle.trainer.SGD(
+        cost=model, parameters=params, update_equation=optimizer)
+
+    dataset = reader.Dataset()
+
+    def __event_handler__(event):
+        if isinstance(event, paddle.event.EndIteration):
+            num_samples = event.batch_id * args.batch_size
+            if event.batch_id % 10 == 0:
+                logger.info("Pass %d, Batch %d, Samples %d, Cost %f" % (
+                    event.pass_id, event.batch_id, num_samples, event.cost))
+
+            if event.batch_id % 1000 == 0:
+                path = "{}/model-pass-{}-batch-{}.tar.gz".format(
+                    args.model_output_dir, event.pass_id, event.batch_id)
+                with gzip.open(path, 'w') as f:
+                    trainer.save_parameter_to_tar(f)
+
+    trainer.train(
+        reader=paddle.batch(
+            paddle.reader.shuffle(
+                dataset.train(args.train_data_path),
+                buf_size=args.batch_size * 100),
+            batch_size=args.batch_size),
+        feeding=reader.feeding,
+        event_handler=__event_handler__,
+        num_passes=args.num_passes)
+
+
+if __name__ == '__main__':
+    train()
diff --git a/text_classification/train.py b/text_classification/train.py
index cda04bfc6a33ee9e39298910d724bc716f1b53df..888fde356f3aec1addb5e5fcf35e17d0c82f37c3 100644
--- a/text_classification/train.py
+++ b/text_classification/train.py
@@ -46,10 +46,10 @@ def train(topology,
         word_dict = paddle.dataset.imdb.word_dict()
         train_reader = paddle.batch(
             paddle.reader.shuffle(
-                lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
+                lambda: paddle.dataset.imdb.train(word_dict)(), buf_size=1000),
             batch_size=100)
         test_reader = paddle.batch(
-            lambda: paddle.dataset.imdb.test(word_dict), batch_size=100)
+            lambda: paddle.dataset.imdb.test(word_dict)(), batch_size=100)
 
         class_num = 2
     else:
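Reviewer note on the text_classification change: `paddle.reader.shuffle` and `paddle.batch` expect a reader creator, a callable that returns a fresh iterator on every call. But `paddle.dataset.imdb.train(word_dict)` already returns such a creator, so the old lambda produced a creator of creators, and iterating its result walked over a function object instead of samples. Calling the inner creator inside the lambda restores the protocol. A minimal sketch of the distinction, with a toy reader:

```python
def make_reader():        # a "reader creator" in PaddlePaddle's terms
    def reader():         # calling the creator yields a fresh iterator
        for i in range(3):
            yield i
    return reader


# Broken (old code's shape): calling this lambda returns make_reader()'s
# inner creator, a function object, rather than an iterator over samples.
broken = lambda: make_reader()

# Fixed (mirrors the diff): the extra call unwraps the creator, so the
# lambda itself behaves as a reader creator.
fixed = lambda: make_reader()()
assert callable(broken()) and list(fixed()) == [0, 1, 2]
```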