From c0c909b96c93b196ac8efce0f437a698d361d4b3 Mon Sep 17 00:00:00 2001
From: guru4elephant
Date: Sun, 9 Feb 2020 10:45:25 +0800
Subject: [PATCH] add ctr example on criteo

---
 python/examples/criteo_ctr/args.py          | 85 +++++++++++++++++++++
 python/examples/criteo_ctr/clean.sh         |  1 +
 python/examples/criteo_ctr/criteo_reader.py | 60 +++++++++++++++
 python/examples/criteo_ctr/get_data.sh      |  2 +
 python/examples/criteo_ctr/local_train.py   | 64 ++++++++++++++++
 python/examples/criteo_ctr/network_conf.py  | 37 +++++++++
 python/examples/criteo_ctr/test_client.py   | 27 +++++++
 python/examples/criteo_ctr/test_server.py   | 19 +++++
 8 files changed, 295 insertions(+)
 create mode 100644 python/examples/criteo_ctr/args.py
 create mode 100644 python/examples/criteo_ctr/clean.sh
 create mode 100644 python/examples/criteo_ctr/criteo_reader.py
 create mode 100644 python/examples/criteo_ctr/get_data.sh
 create mode 100644 python/examples/criteo_ctr/local_train.py
 create mode 100644 python/examples/criteo_ctr/network_conf.py
 create mode 100644 python/examples/criteo_ctr/test_client.py
 create mode 100644 python/examples/criteo_ctr/test_server.py

diff --git a/python/examples/criteo_ctr/args.py b/python/examples/criteo_ctr/args.py
new file mode 100644
index 00000000..eb2227a5
--- /dev/null
+++ b/python/examples/criteo_ctr/args.py
@@ -0,0 +1,85 @@
+import argparse
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
+    parser.add_argument(
+        '--train_data_path',
+        type=str,
+        default='./data/raw/train.txt',
+        help="The path of the training dataset")
+    parser.add_argument(
+        '--test_data_path',
+        type=str,
+        default='./data/raw/valid.txt',
+        help="The path of the testing dataset")
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=1000,
+        help="The size of a mini-batch (default: 1000)")
+    parser.add_argument(
+        '--embedding_size',
+        type=int,
+        default=10,
+        help="The size of the embedding layer (default: 10)")
+    parser.add_argument(
+        '--num_passes',
+        type=int,
+        default=10,
+        help="The number of passes to train (default: 10)")
+    parser.add_argument(
+        '--model_output_dir',
+        type=str,
+        default='models',
+        help='The path to store the model (default: models)')
+    parser.add_argument(
+        '--sparse_feature_dim',
+        type=int,
+        default=1000001,
+        help='Sparse feature hashing space for index processing')
+    parser.add_argument(
+        '--is_local',
+        type=int,
+        default=1,
+        help='Local train or distributed train (default: 1)')
+    parser.add_argument(
+        '--cloud_train',
+        type=int,
+        default=0,
+        help='Local train or distributed train on paddlecloud (default: 0)')
+    parser.add_argument(
+        '--async_mode',
+        action='store_true',
+        default=False,
+        help='Whether to start pserver in async mode to support ASGD')
+    parser.add_argument(
+        '--no_split_var',
+        action='store_true',
+        default=False,
+        help='Whether to split variables into blocks when update_method is pserver')
+    parser.add_argument(
+        '--role',
+        type=str,
+        default='pserver',  # trainer or pserver
+        help='The role of the current node: trainer or pserver (default: pserver)')
+    parser.add_argument(
+        '--endpoints',
+        type=str,
+        default='127.0.0.1:6000',
+        help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
+    parser.add_argument(
+        '--current_endpoint',
+        type=str,
+        default='127.0.0.1:6000',
+        help='The endpoint of the current pserver (default: 127.0.0.1:6000)')
+    parser.add_argument(
+        '--trainer_id',
+        type=int,
+        default=0,
+        help='The id of the current trainer (default: 0)')
+    parser.add_argument(
+        '--trainers',
+        type=int,
+        default=1,
+        help='The number of trainers (default: 1)')
+    return parser.parse_args()
diff --git a/python/examples/criteo_ctr/clean.sh b/python/examples/criteo_ctr/clean.sh
new file mode 100644
index 00000000..78703636
--- /dev/null
+++ b/python/examples/criteo_ctr/clean.sh
@@ -0,0 +1 @@
+rm -rf *pyc kvdb raw_data ctr_client_conf ctr_serving_model ctr_data.tar.gz *~
diff --git a/python/examples/criteo_ctr/criteo_reader.py b/python/examples/criteo_ctr/criteo_reader.py
new file mode 100644
index 00000000..06f90d27
--- /dev/null
+++ b/python/examples/criteo_ctr/criteo_reader.py
@@ -0,0 +1,60 @@
+import sys
+import paddle.fluid.incubate.data_generator as dg
+
+class CriteoDataset(dg.MultiSlotDataGenerator):
+    def setup(self, sparse_feature_dim):
+        self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        self.cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
+        self.cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
+        self.hash_dim_ = sparse_feature_dim
+        # here, training data are lines with line_index < train_idx_
+        self.train_idx_ = 41256555
+        self.continuous_range_ = range(1, 14)
+        self.categorical_range_ = range(14, 40)
+
+    def _process_line(self, line):
+        features = line.rstrip('\n').split('\t')
+        dense_feature = []
+        sparse_feature = []
+        for idx in self.continuous_range_:
+            if features[idx] == '':
+                dense_feature.append(0.0)
+            else:
+                dense_feature.append((float(features[idx]) - self.cont_min_[idx - 1]) / \
+                    self.cont_diff_[idx - 1])
+        for idx in self.categorical_range_:
+            sparse_feature.append([hash(str(idx) + features[idx]) % self.hash_dim_])
+
+        return dense_feature, sparse_feature, [int(features[0])]
+
+    def infer_reader(self, filelist, batch, buf_size):
+        def local_iter():
+            for fname in filelist:
+                with open(fname.strip(), "r") as fin:
+                    for line in fin:
+                        dense_feature, sparse_feature, label = self._process_line(line)
+                        #yield dense_feature, sparse_feature, label
+                        yield [dense_feature] + sparse_feature + [label]
+        import paddle
+        batch_iter = paddle.batch(
+            paddle.reader.shuffle(
+                local_iter, buf_size=buf_size),
+            batch_size=batch)
+        return batch_iter
+
+
+    def generate_sample(self, line):
+        def data_iter():
+            dense_feature, sparse_feature, label = self._process_line(line)
+            feature_name = ["dense_input"]
+            for idx in self.categorical_range_:
+                feature_name.append("C" + str(idx - 13))
+            feature_name.append("label")
+            yield zip(feature_name, [dense_feature] + sparse_feature + [label])
+
+        return data_iter
+
+if __name__ == "__main__":
+    criteo_dataset = CriteoDataset()
+    criteo_dataset.setup(int(sys.argv[1]))
+    criteo_dataset.run_from_stdin()
diff --git a/python/examples/criteo_ctr/get_data.sh b/python/examples/criteo_ctr/get_data.sh
new file mode 100644
index 00000000..8630c992
--- /dev/null
+++ b/python/examples/criteo_ctr/get_data.sh
@@ -0,0 +1,2 @@
+wget 10.86.69.44:/home/work/incubate/ctr_data.tar.gz
+tar -zxvf ctr_data.tar.gz
diff --git a/python/examples/criteo_ctr/local_train.py b/python/examples/criteo_ctr/local_train.py
new file mode 100644
index 00000000..ea8d3ad5
--- /dev/null
+++ b/python/examples/criteo_ctr/local_train.py
@@ -0,0 +1,64 @@
+from __future__ import print_function
+
+from args import parse_args
+import os
+import paddle.fluid as fluid
+import sys
+from network_conf import ctr_dnn_model_dataset
+
+dense_feature_dim = 13
+
+def train():
+    args = parse_args()
+    if not os.path.isdir(args.model_output_dir):
+        os.mkdir(args.model_output_dir)
+
+    dense_input = fluid.layers.data(
+        name="dense_input", shape=[dense_feature_dim], dtype='float32')
+    sparse_input_ids = [
+        fluid.layers.data(name="C" + str(i), shape=[1], lod_level=1, dtype="int64")
+        for i in range(1, 27)]
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+    predict_y, loss, auc_var, batch_auc_var = ctr_dnn_model_dataset(
+        dense_input, sparse_input_ids, label,
+        args.embedding_size, args.sparse_feature_dim)
+
+    optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
+    optimizer.minimize(loss)
+
+    exe = fluid.Executor(fluid.CPUPlace())
+    exe.run(fluid.default_startup_program())
+    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+    dataset.set_use_var([dense_input] + sparse_input_ids + [label])
+    python_executable = "python"
+    pipe_command = "{} criteo_reader.py {}".format(python_executable, args.sparse_feature_dim)
+    dataset.set_pipe_command(pipe_command)
+    dataset.set_batch_size(128)
+    thread_num = 10
+    dataset.set_thread(thread_num)
+    whole_filelist = ["raw_data/part-%d" % x for x in range(len(os.listdir("raw_data")))]
+    #dataset.set_filelist(whole_filelist[:(len(whole_filelist)-thread_num)])
+    dataset.set_filelist(whole_filelist[:thread_num])
+    dataset.load_into_memory()
+
+    epochs = 1
+    for i in range(epochs):
+        exe.train_from_dataset(program=fluid.default_main_program(),
+                               dataset=dataset,
+                               debug=True)
+        print("epoch {} finished".format(i))
+
+    import paddle_serving_client.io as server_io
+    feed_var_dict = {}
+    for i, sparse in enumerate(sparse_input_ids):
+        feed_var_dict["sparse_{}".format(i)] = sparse
+    feed_var_dict["dense_0"] = dense_input
+    fetch_var_dict = {"prob": predict_y}
+
+    server_io.save_model(
+        "ctr_serving_model", "ctr_client_conf",
+        feed_var_dict, fetch_var_dict, fluid.default_main_program())
+
+if __name__ == '__main__':
+    train()
diff --git a/python/examples/criteo_ctr/network_conf.py b/python/examples/criteo_ctr/network_conf.py
new file mode 100644
index 00000000..d763b916
--- /dev/null
+++ b/python/examples/criteo_ctr/network_conf.py
@@ -0,0 +1,37 @@
+import paddle.fluid as fluid
+import math
+
+dense_feature_dim = 13
+
+def ctr_dnn_model_dataset(dense_input, sparse_inputs, label,
+                          embedding_size, sparse_feature_dim):
+    def embedding_layer(input):
+        emb = fluid.layers.embedding(
+            input=input,
+            is_sparse=True,
+            is_distributed=False,
+            size=[sparse_feature_dim, embedding_size],
+            param_attr=fluid.ParamAttr(name="SparseFeatFactors",
+                                       initializer=fluid.initializer.Uniform()))
+        return fluid.layers.sequence_pool(input=emb, pool_type='sum')
+
+    sparse_embed_seq = list(map(embedding_layer, sparse_inputs))
+    concated = fluid.layers.concat(sparse_embed_seq + [dense_input], axis=1)
+    fc1 = fluid.layers.fc(input=concated, size=400, act='relu',
+                          param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                              scale=1 / math.sqrt(concated.shape[1]))))
+    fc2 = fluid.layers.fc(input=fc1, size=400, act='relu',
+                          param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                              scale=1 / math.sqrt(fc1.shape[1]))))
+    fc3 = fluid.layers.fc(input=fc2, size=400, act='relu',
+                          param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                              scale=1 / math.sqrt(fc2.shape[1]))))
+    predict = fluid.layers.fc(input=fc3, size=2, act='softmax',
+                              param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                                  scale=1 / math.sqrt(fc3.shape[1]))))
+    cost = fluid.layers.cross_entropy(input=predict, label=label)
+    avg_cost = fluid.layers.reduce_sum(cost)
+    accuracy = fluid.layers.accuracy(input=predict, label=label)
+    auc_var, batch_auc_var, auc_states = \
+        fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
+    return predict, avg_cost, auc_var, batch_auc_var
diff --git a/python/examples/criteo_ctr/test_client.py b/python/examples/criteo_ctr/test_client.py
new file mode 100644
index 00000000..4dc56c0d
--- /dev/null
+++ b/python/examples/criteo_ctr/test_client.py
@@ -0,0 +1,27 @@
+from paddle_serving_client import Client
+import paddle
+import sys
+import os
+import criteo_reader as criteo
+
+client = Client()
+client.load_client_config(sys.argv[1])
+client.connect(["127.0.0.1:9292"])
+
+batch = 1
+buf_size = 100
+dataset = criteo.CriteoDataset()
+dataset.setup(1000001)
+test_filelists = ["{}/part-%d".format(sys.argv[2]) % x
+                  for x in range(len(os.listdir(sys.argv[2])))]
+reader = dataset.infer_reader(test_filelists[len(test_filelists) - 40:], batch, buf_size)
+
+for data in reader():
+    feed_dict = {}
+    feed_dict["dense_0"] = data[0][0]
+    for i in range(1, 27):
+        feed_dict["sparse_{}".format(i - 1)] = data[0][i]
+    feed_dict["label"] = data[0][-1]
+    fetch_map = client.predict(feed=feed_dict, fetch=["prob"])
+    print("{} {}".format(fetch_map["prob"][0], data[0][-1][0]))
+
diff --git a/python/examples/criteo_ctr/test_server.py b/python/examples/criteo_ctr/test_server.py
new file mode 100644
index 00000000..e77c5fb6
--- /dev/null
+++ b/python/examples/criteo_ctr/test_server.py
@@ -0,0 +1,19 @@
+import os
+import sys
+from paddle_serving_server import OpMaker
+from paddle_serving_server import OpSeqMaker
+from paddle_serving_server import Server
+
+op_maker = OpMaker()
+read_op = op_maker.create('general_reader')
+general_infer_op = op_maker.create('general_infer')
+
+op_seq_maker = OpSeqMaker()
+op_seq_maker.add_op(read_op)
+op_seq_maker.add_op(general_infer_op)
+
+server = Server()
+server.set_op_sequence(op_seq_maker.get_op_sequence())
+server.load_model_config(sys.argv[1])
+server.prepare_server(workdir="work_dir1", port=9292, device="cpu")
+server.run_server()
--
GitLab
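
The patch ships no README, so the end-to-end run sequence below is only a sketch inferred
from the scripts above. The raw_data/ directory name is taken from local_train.py; the client
config file name serving_client_conf.prototxt inside ctr_client_conf/ is an assumption about
what paddle_serving_client.io.save_model writes out and should be checked against the actual
exported artifacts.

    # assumed workflow, not verified against this patch
    sh get_data.sh                               # fetch and unpack ctr_data.tar.gz (assumed to yield raw_data/)
    python local_train.py                        # local training; exports ctr_serving_model/ and ctr_client_conf/
    python test_server.py ctr_serving_model &    # start the serving process on 127.0.0.1:9292
    python test_client.py ctr_client_conf/serving_client_conf.prototxt raw_data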