未验证 提交 1e469e11 编写于 作者: Y yaoxuefeng 提交者: GitHub

add KDD 2019 Context-Aware Multi-Modal Transportation Recommendation PaddlePaddle baseline (#2243)

上级 f183b3fc
# Paddle_baseline_KDD2019
Paddle baseline for KDD2019 "Context-Aware Multi-Modal Transportation Recommendation"(https://dianshi.baidu.com/competition/29/question)
This repository is the demo codes for the KDD2019 "Context-Aware Multi-Modal Transportation Recommendation" competition using PaddlePaddle. It is written by python and uses PaddlePaddle to solve the task. Note that this repository is on developing and welcome everyone to contribute. The current baseline solution codes can get 0.68 - 0.69 score of online submission. As an example, my submission based on these networks programmed by PaddlePaddle is 0.6898.
The reason of the publication of this baseline codes is to encourage us to use PaddlePaddle and build the most powerful recommendation model via PaddlePaddle.
The example codes are ran on Linux, python2.7, single machine with CPU . Note that distributed train options are not provided here, if you want to learn more about this, please check more modes examples on https://github.com/PaddlePaddle/models. About the speed of training, for one epoch, 1000 batch size, it would take about 8 mins to train the whole training instances generated from raw data using SGD optimizer (it would take relatively longer using Adam optimizer).
The configuration and process of all the networks are fundamental, a lot of optimizations can be done based on them to achieve better results e.g. better cost function, more powerful feature engineering, designed model validation, NN optimization tricks...
The code is rough and from my daily use. They will be trimmed these days...
## Install PaddlePaddle
please visit the official site of PaddlePaddle(http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html)
## preprocess feature
```python
python preprocess_dense.py # change for different feature strategy
python pre_test_dense.py
```
preprocess.py and preprocess_dense.py is the code for preprocessing the raw data. Two versions are provided to deal with all sparse features and sparse plus dense features. Correspondingly, pre_process_test.py and pre_test_dense.py are the codes to preproccess test raw data. The training instances are saved in json. It is very easy to add new features. In our demo, all features are generated from provided raw data except for weather feature, which is gengerated from open weather records.
Note that the feature generated in this step need to fit in the input of the model input. Make sure we use the right version. In demo codes, The sparse plus dense features are used for network_confv6.
## build the network
main network logic is in network_confv?.py. The networks are base on fm & deep related algorithms. I try several networks and public some of them. There may be some defects in the networks but all of them are functional.
## train the network
```python
python local_train.py
```
In local_train.py and map_reader.py, I use dataset API, so we need to download the corresponding .whl package or clone codes on develop branch of PaddlePaddle. The reason to use this is the speed of feeding data is much faster.
Note that the input format feed into the network is self-defined. make sure we build the same format between training and test.
## test results
```python
python generate_test.py
python build_submit.py
```
In generate_test.py and build_submit, for convenience, I use the whole train data to train the network and test the network with provided data without label
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
'--train_data_path',
type=str,
default='./data/raw/train.txt',
help="The path of training dataset")
parser.add_argument(
'--test_data_path',
type=str,
default='./data/raw/valid.txt',
help="The path of testing dataset")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
parser.add_argument(
'--embedding_size',
type=int,
default=16,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--num_passes',
type=int,
default=10,
help="The number of passes to train (default: 10)")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help='The path for model to store (default: models)')
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help='sparse feature hashing space for index processing')
parser.add_argument(
'--is_local',
type=int,
default=1,
help='Local train or distributed train (default: 1)')
parser.add_argument(
'--cloud_train',
type=int,
default=0,
help='Local train or distributed train on paddlecloud (default: 0)')
parser.add_argument(
'--async_mode',
action='store_true',
default=False,
help='Whether start pserver in async mode to support ASGD')
parser.add_argument(
'--no_split_var',
action='store_true',
default=False,
help='Whether split variables into blocks when update_method is pserver')
parser.add_argument(
'--role',
type=str,
default='pserver', # trainer or pserver
help='The path for model to store (default: models)')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The path for model to store (default: 127.0.0.1:6000)')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='The path for model to store (default: models)')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
return parser.parse_args()
import json
import csv
import io
def build():
submit_map = {}
with io.open('./submit/submit.csv', 'wb') as csv_file:
writer = csv.writer(csv_file, delimiter=',')
writer.writerow(['sid', 'recommend_mode'])
with open('./out/normed_test_session.txt', 'r') as f1:
with open('./testres/res8', 'r') as f2:
cur_session =''
for x, y in zip(f1.readlines(), f2.readlines()):
m1 = json.loads(x)
session_id = m1["session_id"]
if cur_session == '':
cur_session = session_id
transport_mode = m1["plan"]["transport_mode"]
if cur_session != session_id:
writer.writerow([str(cur_session), str(submit_map[cur_session]["transport_mode"])])
cur_session = session_id
if session_id not in submit_map:
submit_map[session_id] = {}
submit_map[session_id]["transport_mode"] = transport_mode
submit_map[session_id]["probability"] = y
#if int(submit_map[session_id]["transport_mode"]) == 0 and submit_map[session_id]["probability"] > 0.02:
#submit_map[session_id]["probability"] = 0.99
else:
if float(y) > float(submit_map[session_id]["probability"]):
submit_map[session_id]["transport_mode"] = transport_mode
submit_map[session_id]["probability"] = y
#if int(submit_map[session_id]["transport_mode"]) == 0 and submit_map[session_id]["probability"] > 0.02:
#submit_map[session_id]["transport_mode"] = 0
#submit_map[session_id]["probability"] = 0.99
writer.writerow([cur_session, submit_map[cur_session]["transport_mode"]])
if __name__ == "__main__":
build()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import numpy as np
# disable gpu training for this example
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
num_context_feature = 22
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument(
'--model_path',
type=str,
#required=True,
default='models',
help="The path of model parameters gz file")
parser.add_argument(
'--data_path',
type=str,
required=False,
help="The path of the dataset to infer")
parser.add_argument(
'--embedding_size',
type=int,
default=16,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help="The size for embedding layer (default:1000001)")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
return parser.parse_args()
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def data2tensor(data, place):
feed_dict = {}
dense = data[0]
sparse = data[1:-1]
y = data[-1]
#user_data = np.array([x[0] for x in data]).astype("float32")
#user_data = user_data.reshape([-1, 10])
#feed_dict["user_profile"] = user_data
dense_data = np.array([x[0] for x in data]).astype("float32")
dense_data = dense_data.reshape([-1, 3])
feed_dict["dense_feature"] = dense_data
for i in range(num_context_feature):
sparse_data = to_lodtensor([x[1 + i] for x in data], place)
feed_dict["context" + str(i)] = sparse_data
context_fm = to_lodtensor(np.array([x[-2] for x in data]).astype("float32"), place)
feed_dict["context_fm"] = context_fm
y_data = np.array([x[-1] for x in data]).astype("int64")
y_data = y_data.reshape([-1, 1])
feed_dict["label"] = y_data
return feed_dict
def test():
args = parse_args()
place = fluid.CPUPlace()
test_scope = fluid.core.Scope()
# filelist = ["%s/%s" % (args.data_path, x) for x in os.listdir(args.data_path)]
from map_reader import MapDataset
map_dataset = MapDataset()
map_dataset.setup(args.sparse_feature_dim)
exe = fluid.Executor(place)
whole_filelist = ["./out/normed_test_session.txt"]
test_files = whole_filelist[int(0.0 * len(whole_filelist)):int(1.0 * len(whole_filelist))]
epochs = 1
for i in range(epochs):
cur_model_path = args.model_path + "/epoch" + str(1) + ".model"
with open("./testres/res" + str(i), 'w') as r:
with fluid.scope_guard(test_scope):
[inference_program, feed_target_names, fetch_targets] = \
fluid.io.load_inference_model(cur_model_path, exe)
test_reader = map_dataset.test_reader(test_files, 1000, 100000)
k = 0
for batch_id, data in enumerate(test_reader()):
print(len(data[0]))
feed_dict = data2tensor(data, place)
loss_val, auc_val, accuracy, predict, _ = exe.run(inference_program,
feed=feed_dict,
fetch_list=fetch_targets, return_numpy=False)
x = np.array(predict)
for j in range(x.shape[0]):
r.write(str(x[j][1]))
r.write("\n")
if __name__ == '__main__':
test()
import argparse
import logging
import numpy as np
# disable gpu training for this example
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
import map_reader
from network_conf import ctr_deepfm_dataset
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument(
'--model_path',
type=str,
#required=True,
default='models',
help="The path of model parameters gz file")
parser.add_argument(
'--data_path',
type=str,
required=False,
help="The path of the dataset to infer")
parser.add_argument(
'--embedding_size',
type=int,
default=16,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help="The size for embedding layer (default:1000001)")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
return parser.parse_args()
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def data2tensor(data, place):
feed_dict = {}
test_dict = {}
dense = data[0]
sparse = data[1:-1]
y = data[-1]
dense_data = np.array([x[0] for x in data]).astype("float32")
dense_data = dense_data.reshape([-1, 65])
feed_dict["user_profile"] = dense_data
for i in range(10):
sparse_data = to_lodtensor([x[1 + i] for x in data], place)
feed_dict["context" + str(i)] = sparse_data
y_data = np.array([x[-1] for x in data]).astype("int64")
y_data = y_data.reshape([-1, 1])
feed_dict["label"] = y_data
test_dict["test"] = [1]
return feed_dict, test_dict
def infer():
args = parse_args()
place = fluid.CPUPlace()
inference_scope = fluid.core.Scope()
filelist = ["%s/%s" % (args.data_path, x) for x in os.listdir(args.data_path)]
from map_reader import MapDataset
map_dataset = MapDataset()
map_dataset.setup(args.sparse_feature_dim)
exe = fluid.Executor(place)
whole_filelist = ["raw_data/part-%d" % x for x in range(len(os.listdir("raw_data")))]
#whole_filelist = ["./out/normed_train09", "./out/normed_train10", "./out/normed_train11"]
test_files = whole_filelist[int(0.0 * len(whole_filelist)):int(1.0 * len(whole_filelist))]
# file_groups = [whole_filelist[i:i+train_thread_num] for i in range(0, len(whole_filelist), train_thread_num)]
def set_zero(var_name):
param = inference_scope.var(var_name).get_tensor()
param_array = np.zeros(param._get_dims()).astype("int64")
param.set(param_array, place)
epochs = 2
for i in range(epochs):
cur_model_path = args.model_path + "/epoch" + str(i + 1) + ".model"
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names, fetch_targets] = \
fluid.io.load_inference_model(cur_model_path, exe)
auc_states_names = ['_generated_var_2', '_generated_var_3']
for name in auc_states_names:
set_zero(name)
test_reader = map_dataset.infer_reader(test_files, 1000, 100000)
for batch_id, data in enumerate(test_reader()):
loss_val, auc_val, accuracy, predict, label = exe.run(inference_program,
feed=data2tensor(data, place),
fetch_list=fetch_targets, return_numpy=False)
#print(np.array(predict))
#x = np.array(predict)
#print(.shape)x
#print("train_pass_%d, test_pass_%d\t%f\t" % (i - 1, i, auc_val))
if __name__ == '__main__':
infer()
from __future__ import print_function
from args import parse_args
import os
import paddle.fluid as fluid
import sys
from network_confv6 import ctr_deepfm_dataset
NUM_CONTEXT_FEATURE = 22
DIM_USER_PROFILE = 10
DIM_DENSE_FEATURE = 3
PYTHON_PATH = "/home/yaoxuefeng/whls/paddle_release_home/python/bin/python" # this is mine change yours
def train():
args = parse_args()
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
#set the input format for our model. Note that you need to carefully modify them when you define a new network
#user_profile = fluid.layers.data(
#name="user_profile", shape=[DIM_USER_PROFILE], dtype='int64', lod_level=1)
dense_feature = fluid.layers.data(
name="dense_feature", shape=[DIM_DENSE_FEATURE], dtype='float32')
context_feature = [
fluid.layers.data(name="context" + str(i), shape=[1], lod_level=1, dtype="int64")
for i in range(0, NUM_CONTEXT_FEATURE)]
context_feature_fm = fluid.layers.data(
name="context_fm", shape=[1], dtype='int64', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
print("ready to network")
#self define network
loss, auc_var, batch_auc_var, accuracy, predict = ctr_deepfm_dataset(dense_feature, context_feature, context_feature_fm, label,
args.embedding_size, args.sparse_feature_dim)
print("ready to optimize")
optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
optimizer.minimize(loss)
#single machine CPU training. more options on trainig please visit PaddlePaddle site
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
#use dataset api for much faster speed
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var([dense_feature] + context_feature + [context_feature_fm] + [label])
#self define how to process generated training insatnces in map_reader.py
pipe_command = PYTHON_PATH + " map_reader.py %d" % args.sparse_feature_dim
dataset.set_pipe_command(pipe_command)
dataset.set_batch_size(args.batch_size)
thread_num = 1
dataset.set_thread(thread_num)
#self define how to split training files for example:"split -a 2 -d -l 200000 normed_train.txt normed_train"
whole_filelist = ["./out/normed_train%d" % x for x in range(len(os.listdir("out")))]
whole_filelist = ["./out/normed_train00", "./out/normed_train01", "./out/normed_train02", "./out/normed_train03",
"./out/normed_train04", "./out/normed_train05", "./out/normed_train06", "./out/normed_train07",
"./out/normed_train08",
"./out/normed_train09", "./out/normed_train10", "./out/normed_train11"]
print("ready to epochs")
epochs = 10
for i in range(epochs):
print("start %dth epoch" % i)
dataset.set_filelist(whole_filelist[:int(len(whole_filelist))])
#print the informations you want by setting fetch_list and fetch_info
exe.train_from_dataset(program=fluid.default_main_program(),
dataset=dataset,
fetch_list=[auc_var, accuracy, predict, label],
fetch_info=["auc", "accuracy", "predict", "label"],
debug=False)
model_dir = args.model_output_dir + '/epoch' + str(i + 1) + ".model"
sys.stderr.write("epoch%d finished" % (i + 1))
#save model
fluid.io.save_inference_model(model_dir, [dense_feature.name] + [x.name for x in context_feature] + [context_feature_fm.name] + [label.name],
[loss, auc_var, accuracy, predict, label], exe)
if __name__ == '__main__':
train()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
import paddle.fluid.incubate.data_generator as dg
class MapDataset(dg.MultiSlotDataGenerator):
def setup(self, sparse_feature_dim):
self.profile_length = 65
self.dense_length = 3
#feature names
self.dense_feature_list = ["distance", "price", "eta"]
self.pid_list = ["pid"]
self.query_feature_list = ["weekday", "hour", "o1", "o2", "d1", "d2"]
self.plan_feature_list = ["transport_mode"]
self.rank_feature_list = ["plan_rank", "whole_rank", "price_rank", "eta_rank", "distance_rank"]
self.rank_whole_pic_list = ["mode_rank1", "mode_rank2", "mode_rank3", "mode_rank4",
"mode_rank5"]
self.weather_feature_list = ["max_temp", "min_temp", "wea", "wind"]
self.hash_dim = 1000001
self.train_idx_ = 2000000
#carefully set if you change the features
self.categorical_range_ = range(0, 22)
#process one instance
def _process_line(self, line):
instance = json.loads(line)
"""
profile = instance["profile"]
len_profile = len(profile)
if len_profile >= 10:
user_profile_feature = profile[0:10]
else:
profile.extend([0]*(10-len_profile))
user_profile_feature = profile
if len(profile) > 1 or (len(profile) == 1 and profile[0] != 0):
for p in profile:
if p >= 1 and p <= 65:
user_profile_feature[p - 1] = 1
"""
context_feature = []
context_feature_fm = []
dense_feature = [0] * self.dense_length
plan = instance["plan"]
for i, val in enumerate(self.dense_feature_list):
dense_feature[i] = plan[val]
if (instance["pid"] == ""):
instance["pid"] = 0
query = instance["query"]
weather_dic = instance["weather"]
for fea in self.pid_list:
context_feature.append([hash(fea + str(instance[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(instance[fea])) % self.hash_dim)
for fea in self.query_feature_list:
context_feature.append([hash(fea + str(query[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(query[fea])) % self.hash_dim)
for fea in self.plan_feature_list:
context_feature.append([hash(fea + str(plan[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(plan[fea])) % self.hash_dim)
for fea in self.rank_feature_list:
context_feature.append([hash(fea + str(instance[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(instance[fea])) % self.hash_dim)
for fea in self.rank_whole_pic_list:
context_feature.append([hash(fea + str(instance[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(instance[fea])) % self.hash_dim)
for fea in self.weather_feature_list:
context_feature.append([hash(fea + str(weather_dic[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(weather_dic[fea])) % self.hash_dim)
label = [int(instance["label"])]
return dense_feature, context_feature, context_feature_fm, label
def infer_reader(self, filelist, batch, buf_size):
print(filelist)
def local_iter():
for fname in filelist:
with open(fname.strip(), "r") as fin:
for line in fin:
dense_feature, sparse_feature, sparse_feature_fm, label = self._process_line(line)
yield [dense_feature] + sparse_feature + [sparse_feature_fm] + [label]
import paddle
batch_iter = paddle.batch(
paddle.reader.shuffle(
local_iter, buf_size=buf_size),
batch_size=batch)
return batch_iter
#generat inputs for testing
def test_reader(self, filelist, batch, buf_size):
print(filelist)
def local_iter():
for fname in filelist:
with open(fname.strip(), "r") as fin:
for line in fin:
dense_feature, sparse_feature, sparse_feature_fm, label = self._process_line(line)
yield [dense_feature] + sparse_feature + [sparse_feature_fm] + [label]
import paddle
batch_iter = paddle.batch(
paddle.reader.buffered(
local_iter, size=buf_size),
batch_size=batch)
return batch_iter
#generate inputs for trainig
def generate_sample(self, line):
def data_iter():
dense_feature, sparse_feature, sparse_feature_fm, label = self._process_line(line)
#feature_name = ["user_profile"]
feature_name = []
feature_name.append("dense_feature")
for idx in self.categorical_range_:
feature_name.append("context" + str(idx))
feature_name.append("context_fm")
feature_name.append("label")
yield zip(feature_name, [dense_feature] + sparse_feature + [sparse_feature_fm] + [label])
return data_iter
if __name__ == "__main__":
map_dataset = MapDataset()
map_dataset.setup(int(sys.argv[1]))
map_dataset.run_from_stdin()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
dense_feature_dim = 3
def ctr_deepfm_dataset(dense_feature, context_feature, context_feature_fm, label,
embedding_size, sparse_feature_dim):
def dense_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_order = fluid.layers.fc(input=input, size=1)
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32', attr=fm_param_attr)
input_mul_factor = fluid.layers.matmul(input, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(input)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_order = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return first_order, second_order
dense_fm_param_attr = fluid.param_attr.ParamAttr(name="DenseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(dense_feature_dim)))
dense_fm_first, dense_fm_second = dense_fm_layer(
dense_feature, dense_feature_dim, 16, dense_fm_param_attr)
def sparse_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, 1], is_sparse=True)
first_order = fluid.layers.sequence_pool(input=first_embeddings, pool_type='sum')
nonzero_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, factor_size],
param_attr=fm_param_attr, is_sparse=True)
summed_features_emb = fluid.layers.sequence_pool(input=nonzero_embeddings, pool_type='sum')
summed_features_emb_square = fluid.layers.square(summed_features_emb)
squared_features_emb = fluid.layers.square(nonzero_embeddings)
squared_sum_features_emb = fluid.layers.sequence_pool(
input=squared_features_emb, pool_type='sum')
second_order = 0.5 * (summed_features_emb_square - squared_sum_features_emb)
return first_order, second_order
sparse_fm_param_attr = fluid.param_attr.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(sparse_feature_dim)))
#data = fluid.layers.data(name='ids', shape=[1], dtype='float32')
sparse_fm_first, sparse_fm_second = sparse_fm_layer(
context_feature_fm, sparse_feature_dim, 16, sparse_fm_param_attr)
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
concated_ori = fluid.layers.concat(sparse_embed_seq + [dense_feature], axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
predict = fluid.layers.fc(input=[deep, sparse_fm_first, sparse_fm_second, dense_fm_first, dense_fm_second], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [400, 400, 400]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
#w_res = fluid.layers.create_parameter(shape=[353, 16], dtype='float32', name="w_res")
#high_path = fluid.layers.matmul(concated, w_res)
#return fluid.layers.elementwise_add(high_path, fc_layers_input[-1])
return fc_layers_input[-1]
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
num_context = 25
dim_fm_vector = 16
dim_concated = user_profile_dim + dim_fm_vector * (num_context)
def ctr_deepfm_dataset(user_profile, context_feature, label,
embedding_size, sparse_feature_dim):
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
w = fluid.layers.create_parameter(
shape=[65, 65], dtype='float32',
name="w_fm")
user_profile_emb = fluid.layers.matmul(user_profile, w)
concated_ori = fluid.layers.concat(sparse_embed_seq + [user_profile_emb], axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
linear_term, second_term = fm(concated, dim_concated, 8) #depend on the number of context feature
predict = fluid.layers.fc(input=[deep, linear_term, second_term], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [128, 64, 32, 16]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
return fc_layers_input[-1]
def fm(concated, emb_dict_size, factor_size, lr_x=0.0001):
linear_term = fluid.layers.fc(input=concated, size=8, act=None, param_attr=fluid.ParamAttr(learning_rate=lr_x))
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32')
input_mul_factor = fluid.layers.matmul(concated, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(concated)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_term = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return linear_term, second_term
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
slot_1 = [0, 1, 2, 3, 4, 5]
slot_2 = [6]
slot_3 = [7, 8, 9, 10, 11]
slot_4 = [12, 13, 14, 15, 16]
slot_5 = [17, 18, 19, 20]
num_context = 25
num_slots_pair = 5
dim_fm_vector = 16
dim_concated = user_profile_dim + dim_fm_vector * (num_context + num_slots_pair)
def ctr_deepfm_dataset(user_profile, dense_feature, context_feature, label,
embedding_size, sparse_feature_dim):
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
w = fluid.layers.create_parameter(
shape=[65, 65], dtype='float32',
name="w_fm")
user_emb_list = []
user_profile_emb = fluid.layers.matmul(user_profile, w)
user_emb_list.append(user_profile_emb)
user_emb_list.append(dense_feature)
w1 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_1")
w2 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_2")
w3 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_3")
w4 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_4")
w5 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_5")
user_profile_emb_1 = fluid.layers.matmul(user_profile, w1)
user_profile_emb_2 = fluid.layers.matmul(user_profile, w2)
user_profile_emb_3 = fluid.layers.matmul(user_profile, w3)
user_profile_emb_4 = fluid.layers.matmul(user_profile, w4)
user_profile_emb_5 = fluid.layers.matmul(user_profile, w5)
sparse_embed_seq_1 = embedding_layer(context_feature[slot_1[0]])
sparse_embed_seq_2 = embedding_layer(context_feature[slot_2[0]])
sparse_embed_seq_3 = embedding_layer(context_feature[slot_3[0]])
sparse_embed_seq_4 = embedding_layer(context_feature[slot_4[0]])
sparse_embed_seq_5 = embedding_layer(context_feature[slot_5[0]])
for i in slot_1[1:-1]:
sparse_embed_seq_1 = fluid.layers.elementwise_add(sparse_embed_seq_1, embedding_layer(context_feature[i]))
for i in slot_2[1:-1]:
sparse_embed_seq_2 = fluid.layers.elementwise_add(sparse_embed_seq_2, embedding_layer(context_feature[i]))
for i in slot_3[1:-1]:
sparse_embed_seq_3 = fluid.layers.elementwise_add(sparse_embed_seq_3, embedding_layer(context_feature[i]))
for i in slot_4[1:-1]:
sparse_embed_seq_4 = fluid.layers.elementwise_add(sparse_embed_seq_4, embedding_layer(context_feature[i]))
for i in slot_5[1:-1]:
sparse_embed_seq_5 = fluid.layers.elementwise_add(sparse_embed_seq_5, embedding_layer(context_feature[i]))
ele_product_1 = fluid.layers.elementwise_mul(user_profile_emb_1, sparse_embed_seq_1)
user_emb_list.append(ele_product_1)
ele_product_2 = fluid.layers.elementwise_mul(user_profile_emb_2, sparse_embed_seq_2)
user_emb_list.append(ele_product_2)
ele_product_3 = fluid.layers.elementwise_mul(user_profile_emb_3, sparse_embed_seq_3)
user_emb_list.append(ele_product_3)
ele_product_4 = fluid.layers.elementwise_mul(user_profile_emb_4, sparse_embed_seq_4)
user_emb_list.append(ele_product_4)
ele_product_5 = fluid.layers.elementwise_mul(user_profile_emb_5, sparse_embed_seq_5)
user_emb_list.append(ele_product_5)
ffm_1 = fluid.layers.reduce_sum(ele_product_1, dim=1, keep_dim=True)
ffm_2 = fluid.layers.reduce_sum(ele_product_2, dim=1, keep_dim=True)
ffm_3 = fluid.layers.reduce_sum(ele_product_3, dim=1, keep_dim=True)
ffm_4 = fluid.layers.reduce_sum(ele_product_4, dim=1, keep_dim=True)
ffm_5 = fluid.layers.reduce_sum(ele_product_5, dim=1, keep_dim=True)
concated_ori = fluid.layers.concat(sparse_embed_seq + user_emb_list, axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
linear_term, second_term = fm(concated, dim_concated, 8) #depend on the number of context feature
predict = fluid.layers.fc(input=[deep, linear_term, second_term, ffm_1, ffm_2, ffm_3, ffm_4, ffm_5], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [256, 128, 64, 32, 16]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
w_res = fluid.layers.create_parameter(shape=[dim_concated, 16], dtype='float32', name="w_res")
high_path = fluid.layers.matmul(concated, w_res)
return fluid.layers.elementwise_add(high_path, fc_layers_input[-1])
#return fc_layers_input[-1]
def fm(concated, emb_dict_size, factor_size, lr_x=0.0001):
linear_term = fluid.layers.fc(input=concated, size=8, act=None, param_attr=fluid.ParamAttr(learning_rate=lr_x))
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32')
input_mul_factor = fluid.layers.matmul(concated, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(concated)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_term = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return linear_term, second_term
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
dense_feature_dim = 3
def ctr_deepfm_dataset(dense_feature, context_feature, context_feature_fm, label,
embedding_size, sparse_feature_dim):
def dense_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_order = fluid.layers.fc(input=input, size=1)
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32', attr=fm_param_attr)
input_mul_factor = fluid.layers.matmul(input, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(input)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_order = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return first_order, second_order
dense_fm_param_attr = fluid.param_attr.ParamAttr(name="DenseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(dense_feature_dim)))
dense_fm_first, dense_fm_second = dense_fm_layer(
dense_feature, dense_feature_dim, 16, dense_fm_param_attr)
def sparse_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, 1], is_sparse=True)
first_order = fluid.layers.sequence_pool(input=first_embeddings, pool_type='sum')
nonzero_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, factor_size],
param_attr=fm_param_attr, is_sparse=True)
summed_features_emb = fluid.layers.sequence_pool(input=nonzero_embeddings, pool_type='sum')
summed_features_emb_square = fluid.layers.square(summed_features_emb)
squared_features_emb = fluid.layers.square(nonzero_embeddings)
squared_sum_features_emb = fluid.layers.sequence_pool(
input=squared_features_emb, pool_type='sum')
second_order = 0.5 * (summed_features_emb_square - squared_sum_features_emb)
return first_order, second_order
sparse_fm_param_attr = fluid.param_attr.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(sparse_feature_dim)))
#data = fluid.layers.data(name='ids', shape=[1], dtype='float32')
sparse_fm_first, sparse_fm_second = sparse_fm_layer(
context_feature_fm, sparse_feature_dim, 16, sparse_fm_param_attr)
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
concated_ori = fluid.layers.concat(sparse_embed_seq + [dense_feature], axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
predict = fluid.layers.fc(input=[deep, sparse_fm_first, sparse_fm_second, dense_fm_first, dense_fm_second], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [400, 400, 400]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
#w_res = fluid.layers.create_parameter(shape=[353, 16], dtype='float32', name="w_res")
#high_path = fluid.layers.matmul(concated, w_res)
#return fluid.layers.elementwise_add(high_path, fc_layers_input[-1])
return fc_layers_input[-1]
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys, time, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TEST_QUERIES_PATH = "./data_set_phase1/test_queries.csv"
TEST_PLANS_PATH = "./data_set_phase1/test_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
OUT_NORM_TEST_PATH = "./out/normed_test_session.txt"
OUT_RAW_TEST_PATH = "./out/test_session.txt"
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
SCALE_OD = 0.02
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 40000.0
SCALE_DIS = 500
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
SCALE_PRICE = 100
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
SCALE_ETA = 120
def build_norm_feature():
with open(OUT_NORM_TEST_PATH, 'w') as nf:
with open(OUT_RAW_TEST_PATH, 'r') as f:
for line in f:
cur_map = json.loads(line)
if cur_map["plan"]["distance"] > THRESHOLD_DIS:
cur_map["plan"]["distance"] = int(THRESHOLD_DIS)
elif cur_map["plan"]["distance"] > 0:
cur_map["plan"]["distance"] = int(cur_map["plan"]["distance"] / SCALE_DIS)
if cur_map["plan"]["price"] and cur_map["plan"]["price"] > THRESHOLD_PRICE:
cur_map["plan"]["price"] = int(THRESHOLD_PRICE)
elif not cur_map["plan"]["price"] or cur_map["plan"]["price"] < 0:
cur_map["plan"]["price"] = 0
else:
cur_map["plan"]["price"] = int(cur_map["plan"]["price"] / SCALE_PRICE)
if cur_map["plan"]["eta"] > THRESHOLD_ETA:
cur_map["plan"]["eta"] = int(THRESHOLD_ETA)
elif cur_map["plan"]["eta"] > 0:
cur_map["plan"]["eta"] = int(cur_map["plan"]["eta"] / SCALE_ETA)
# o1
if cur_map["query"]["o1"] > O1_MAX:
cur_map["query"]["o1"] = int((O1_MAX - O1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o1"] < O1_MIN:
cur_map["query"]["o1"] = 0
else:
cur_map["query"]["o1"] = int((cur_map["query"]["o1"] - O1_MIN) / 0.02)
# o2
if cur_map["query"]["o2"] > O2_MAX:
cur_map["query"]["o2"] = int((O2_MAX - O2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o2"] < O2_MIN:
cur_map["query"]["o2"] = 0
else:
cur_map["query"]["o2"] = int((cur_map["query"]["o2"] - O2_MIN) / 0.02)
# d1
if cur_map["query"]["d1"] > D1_MAX:
cur_map["query"]["d1"] = int((D1_MAX - D1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d1"] < D1_MIN:
cur_map["query"]["d1"] = 0
else:
cur_map["query"]["d1"] = int((cur_map["query"]["d1"] - D1_MIN) / SCALE_OD)
# d2
if cur_map["query"]["d2"] > D2_MAX:
cur_map["query"]["d2"] = int((D2_MAX - D2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d2"] < D2_MIN:
cur_map["query"]["d2"] = 0
else:
cur_map["query"]["d2"] = int((cur_map["query"]["d2"] - D2_MIN) / SCALE_OD)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert some of the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FM algorithm.
Note that the user profile is already provided by one-hot encoded form, we convert it back to the
ids for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not.
:return:
"""
train_data_dict = {}
with open("./weather.json", 'r') as f:
weather_dict = json.load(f)
with open(TEST_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
date_key = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%m-%d")
train_data_dict[train_index]["weather"] = {}
train_data_dict[train_index]["weather"].update({"max_temp": weather_dict[date_key]["max_temp"]})
train_data_dict[train_index]["weather"].update({"min_temp": weather_dict[date_key]["min_temp"]})
train_data_dict[train_index]["weather"].update({"wea": weather_dict[date_key]["weather"]})
train_data_dict[train_index]["weather"].update({"wind": weather_dict[date_key]["wind"]})
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TEST_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir("./out/"):
os.mkdir("./out/")
with open(os.path.join("./out/", "test_session.txt"), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
cur_map["session_id"] = session_id
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
del cur_map["pid"]
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["mode_rank" + str(whole_rank)] = plan["transport_mode"]
if whole_rank < 5:
for r in range(whole_rank + 1, 6):
cur_map["mode_rank" + str(r)] = -1
cur_map["whole_rank"] = whole_rank
flag_click = False
rank = 1
price_list = []
eta_list = []
distance_list = []
for plan in plan_list:
if not plan["price"]:
price_list.append(0)
else:
price_list.append(int(plan["price"]))
eta_list.append(int(plan["eta"]))
distance_list.append(int(plan["distance"]))
price_list.sort(reverse=False)
eta_list.sort(reverse=False)
distance_list.sort(reverse=False)
for plan in plan_list:
if plan["price"] and int(plan["price"]) == price_list[0]:
cur_map["mode_min_price"] = plan["transport_mode"]
if plan["price"] and int(plan["price"]) == price_list[-1]:
cur_map["mode_max_price"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[0]:
cur_map["mode_min_eta"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[-1]:
cur_map["mode_max_eta"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[0]:
cur_map["mode_min_distance"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[-1]:
cur_map["mode_max_distance"] = plan["transport_mode"]
if "mode_min_price" not in cur_map:
cur_map["mode_min_price"] = -1
if "mode_max_price" not in cur_map:
cur_map["mode_max_price"] = -1
for plan in plan_list:
cur_price = int(plan["price"]) if plan["price"] else 0
cur_eta = int(plan["eta"])
cur_distance = int(plan["distance"])
cur_map["price_rank"] = price_list.index(cur_price) + 1
cur_map["eta_rank"] = eta_list.index(cur_eta) + 1
cur_map["distance_rank"] = distance_list.index(cur_distance) + 1
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
flag_click = True
# print("label is 1")
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["plan_rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["plan_rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys, time, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TRAIN_QUERIES_PATH = "./data_set_phase1/test_queries.csv"
TRAIN_PLANS_PATH = "./data_set_phase1/test_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 200000.0
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
def build_norm_feature():
with open("./out/normed_test_session.txt", 'w') as nf:
with open("./out/test_session.txt", 'r') as f:
for line in f:
cur_map = json.loads(line)
cur_map["plan"]["distance"] = (cur_map["plan"]["distance"] - DISTANCE_MIN) / (DISTANCE_MAX - DISTANCE_MIN)
if cur_map["plan"]["price"]:
cur_map["plan"]["price"] = (cur_map["plan"]["price"] - PRICE_MIN) / (PRICE_MAX - PRICE_MIN)
else:
cur_map["plan"]["price"] = 0.0
cur_map["plan"]["eta"] = (cur_map["plan"]["eta"] - ETA_MIN) / (ETA_MAX - ETA_MIN)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert all the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FFM algorithm.
Note that the user profile is already provided by one-hot encoded form, we convert it back to the
ids for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not.
:return:
"""
#args = parse_args()
train_data_dict = {}
with open("./weather.json", 'r') as f:
weather_dict = json.load(f)
with open(TRAIN_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
date_key = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%m-%d")
train_data_dict[train_index]["weather"] = {}
train_data_dict[train_index]["weather"].update({"max_temp": weather_dict[date_key]["max_temp"]})
train_data_dict[train_index]["weather"].update({"min_temp": weather_dict[date_key]["min_temp"]})
train_data_dict[train_index]["weather"].update({"wea": weather_dict[date_key]["weather"]})
train_data_dict[train_index]["weather"].update({"wind": weather_dict[date_key]["wind"]})
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TRAIN_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir("./out/"):
os.mkdir("./out/")
with open(os.path.join("./out/", "test_session.txt"), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
cur_map["session_id"] = session_id
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
# del cur_map["pid"]
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["mode_rank" + str(whole_rank)] = plan["transport_mode"]
if whole_rank < 5:
for r in range(whole_rank + 1, 6):
cur_map["mode_rank" + str(r)] = -1
cur_map["whole_rank"] = whole_rank
rank = 1
price_list = []
eta_list = []
distance_list = []
for plan in plan_list:
if not plan["price"]:
price_list.append(0)
else:
price_list.append(int(plan["price"]))
eta_list.append(int(plan["eta"]))
distance_list.append(int(plan["distance"]))
price_list.sort(reverse=False)
eta_list.sort(reverse=False)
distance_list.sort(reverse=False)
for plan in plan_list:
if plan["price"] and int(plan["price"]) == price_list[0]:
cur_map["mode_min_price"] = plan["transport_mode"]
if plan["price"] and int(plan["price"]) == price_list[-1]:
cur_map["mode_max_price"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[0]:
cur_map["mode_min_eta"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[-1]:
cur_map["mode_max_eta"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[0]:
cur_map["mode_min_distance"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[-1]:
cur_map["mode_max_distance"] = plan["transport_mode"]
if "mode_min_price" not in cur_map:
cur_map["mode_min_price"] = -1
if "mode_max_price" not in cur_map:
cur_map["mode_max_price"] = -1
for plan in plan_list:
cur_price = int(plan["price"]) if plan["price"] else 0
cur_eta = int(plan["eta"])
cur_distance = int(plan["distance"])
cur_map["price_rank"] = price_list.index(cur_price) + 1
cur_map["eta_rank"] = eta_list.index(cur_eta) + 1
cur_map["distance_rank"] = distance_list.index(cur_distance) + 1
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["plan_rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["plan_rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys, time, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TRAIN_QUERIES_PATH = "./data_set_phase1/train_queries.csv"
TRAIN_PLANS_PATH = "./data_set_phase1/train_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
OUT_NORM_TRAIN_PATH = "./out/normed_train.txt"
OUT_RAW_TRAIN_PATH = "./out/train.txt"
OUT_DIR = "./out"
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
SCALE_OD = 0.02
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 40000.0
SCALE_DIS = 500
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
SCALE_PRICE = 100
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
SCALE_ETA = 120
def build_norm_feature():
with open(OUT_NORM_TRAIN_PATH, 'w') as nf:
with open(OUT_RAW_TRAIN_PATH, 'r') as f:
for line in f:
cur_map = json.loads(line)
if cur_map["plan"]["distance"] > THRESHOLD_DIS:
cur_map["plan"]["distance"] = int(THRESHOLD_DIS)
elif cur_map["plan"]["distance"] > 0:
cur_map["plan"]["distance"] = int(cur_map["plan"]["distance"] / SCALE_DIS)
if cur_map["plan"]["price"] and cur_map["plan"]["price"] > THRESHOLD_PRICE:
cur_map["plan"]["price"] = int(THRESHOLD_PRICE)
elif not cur_map["plan"]["price"] or cur_map["plan"]["price"] < 0:
cur_map["plan"]["price"] = 0
else:
cur_map["plan"]["price"] = int(cur_map["plan"]["price"] / SCALE_PRICE)
if cur_map["plan"]["eta"] > THRESHOLD_ETA:
cur_map["plan"]["eta"] = int(THRESHOLD_ETA)
elif cur_map["plan"]["eta"] > 0:
cur_map["plan"]["eta"] = int(cur_map["plan"]["eta"] / SCALE_ETA)
# o1
if cur_map["query"]["o1"] > O1_MAX:
cur_map["query"]["o1"] = int((O1_MAX - O1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o1"] < O1_MIN:
cur_map["query"]["o1"] = 0
else:
cur_map["query"]["o1"] = int((cur_map["query"]["o1"] - O1_MIN) / 0.02)
# o2
if cur_map["query"]["o2"] > O2_MAX:
cur_map["query"]["o2"] = int((O2_MAX - O2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o2"] < O2_MIN:
cur_map["query"]["o2"] = 0
else:
cur_map["query"]["o2"] = int((cur_map["query"]["o2"] - O2_MIN) / 0.02)
# d1
if cur_map["query"]["d1"] > D1_MAX:
cur_map["query"]["d1"] = int((D1_MAX - D1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d1"] < D1_MIN:
cur_map["query"]["d1"] = 0
else:
cur_map["query"]["d1"] = int((cur_map["query"]["d1"] - D1_MIN) / SCALE_OD)
# d2
if cur_map["query"]["d2"] > D2_MAX:
cur_map["query"]["d2"] = int((D2_MAX - D2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d2"] < D2_MIN:
cur_map["query"]["d2"] = 0
else:
cur_map["query"]["d2"] = int((cur_map["query"]["d2"] - D2_MIN) / SCALE_OD)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert all the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FM algorithm.
Note that the user profile is already provided by one-hot encoded form, we treat it as embedded vector
for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not include non-click case.
:return:
"""
train_data_dict = {}
with open(TRAIN_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TRAIN_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir(OUT_DIR):
os.mkdir(OUT_DIR)
with open(os.path.join("./out/", "train.txt"), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
del cur_map["pid"]
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["whole_rank"] = whole_rank
flag_click = False
rank = 1
for plan in plan_list:
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
flag_click = True
# print("label is 1")
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
if not flag_click:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
else:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["rank"] = 0
cur_map["label"] = 0
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TRAIN_QUERIES_PATH = "./data_set_phase1/train_queries.csv"
TRAIN_PLANS_PATH = "./data_set_phase1/train_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
OUT_DIR = "./out"
ORI_TRAIN_PATH = "train.txt"
NORM_TRAIN_PATH = "normed_train.txt"
#variable to control the ratio of positive and negative instance of transmode 0 which is original label of no click
THRESHOLD_LABEL = 0.5
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 200000.0
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
def build_norm_feature():
with open(os.path.join(OUT_DIR, NORM_TRAIN_PATH), 'w') as nf:
with open(os.path.join(OUT_DIR, ORI_TRAIN_PATH), 'r') as f:
for line in f:
cur_map = json.loads(line)
cur_map["plan"]["distance"] = (cur_map["plan"]["distance"] - DISTANCE_MIN) / (DISTANCE_MAX - DISTANCE_MIN)
if cur_map["plan"]["price"]:
cur_map["plan"]["price"] = (cur_map["plan"]["price"] - PRICE_MIN) / (PRICE_MAX - PRICE_MIN)
else:
cur_map["plan"]["price"] = 0.0
cur_map["plan"]["eta"] = (cur_map["plan"]["eta"] - ETA_MIN) / (ETA_MAX - ETA_MIN)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert all the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FM algorithm.
Note that the user profile is already provided by one-hot encoded form, we treat it as embedded vector
for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not include non-click case. To Be Changed
:return:
"""
train_data_dict = {}
with open("./weather.json", 'r') as f:
weather_dict = json.load(f)
with open(TRAIN_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
train_data_dict[train_index]["weather"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
# weather related features, no big use, maybe more detailed weather information is better
date_key = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%m-%d")
train_data_dict[train_index]["weather"] = {}
train_data_dict[train_index]["weather"].update({"max_temp": weather_dict[date_key]["max_temp"]})
train_data_dict[train_index]["weather"].update({"min_temp": weather_dict[date_key]["min_temp"]})
train_data_dict[train_index]["weather"].update({"wea": weather_dict[date_key]["weather"]})
train_data_dict[train_index]["weather"].update({"wind": weather_dict[date_key]["wind"]})
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TRAIN_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir(OUT_DIR):
os.mkdir(OUT_DIR)
with open(os.path.join(OUT_DIR, ORI_TRAIN_PATH), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
#rank information related feature
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["mode_rank" + str(whole_rank)] = plan["transport_mode"]
if whole_rank < 5:
for r in range(whole_rank + 1, 6):
cur_map["mode_rank" + str(r)] = -1
cur_map["whole_rank"] = whole_rank
flag_click = False
rank = 1
price_list = []
eta_list = []
distance_list = []
for plan in plan_list:
if not plan["price"]:
price_list.append(0)
else:
price_list.append(int(plan["price"]))
eta_list.append(int(plan["eta"]))
distance_list.append(int(plan["distance"]))
price_list.sort(reverse=False)
eta_list.sort(reverse=False)
distance_list.sort(reverse=False)
for plan in plan_list:
if plan["price"] and int(plan["price"]) == price_list[0]:
cur_map["mode_min_price"] = plan["transport_mode"]
if plan["price"] and int(plan["price"]) == price_list[-1]:
cur_map["mode_max_price"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[0]:
cur_map["mode_min_eta"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[-1]:
cur_map["mode_max_eta"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[0]:
cur_map["mode_min_distance"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[-1]:
cur_map["mode_max_distance"] = plan["transport_mode"]
if "mode_min_price" not in cur_map:
cur_map["mode_min_price"] = -1
if "mode_max_price" not in cur_map:
cur_map["mode_max_price"] = -1
for plan in plan_list:
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
flag_click = True
if flag_click:
for plan in plan_list:
cur_price = int(plan["price"]) if plan["price"] else 0
cur_eta = int(plan["eta"])
cur_distance = int(plan["distance"])
cur_map["price_rank"] = price_list.index(cur_price) + 1
cur_map["eta_rank"] = eta_list.index(cur_eta) + 1
cur_map["distance_rank"] = distance_list.index(cur_distance) + 1
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["plan_rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
cur_map["plan"] = {}
#since we define a new ctr task from original task, we use a basic way to generate instances of transport mode 0.
#There should be a optimal strategy to generate instances of transport mode 0
if not flag_click:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["distance_rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
else:
if random.random() < THRESHOLD_LABEL:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["distance_rank"] = 0
cur_map["label"] = 0
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
{"10-01": {"max_temp": "24", "min_temp": "12", "weather": "q", "wind": "45"}, "10-02": {"max_temp": "24", "min_temp": "11", "weather": "q", "wind": "12"}, "10-03": {"max_temp": "25", "min_temp": "10", "weather": "q", "wind": "12"}, "10-04": {"max_temp": "25", "min_temp": "12", "weather": "q", "wind": "12"}, "10-05": {"max_temp": "24", "min_temp": "14", "weather": "dy", "wind": "12"}, "10-06": {"max_temp": "20", "min_temp": "8", "weather": "q", "wind": "45"}, "10-07": {"max_temp": "21", "min_temp": "7", "weather": "q", "wind": "12"}, "10-08": {"max_temp": "21", "min_temp": "8", "weather": "dy", "wind": "12"}, "10-09": {"max_temp": "15", "min_temp": "4", "weather": "dyq", "wind": "45"}, "10-10": {"max_temp": "17", "min_temp": "4", "weather": "dyq", "wind": "12"}, "10-11": {"max_temp": "18", "min_temp": "5", "weather": "qdy", "wind": "12"}, "10-12": {"max_temp": "20", "min_temp": "5", "weather": "dyq", "wind": "12"}, "10-13": {"max_temp": "20", "min_temp": "8", "weather": "dy", "wind": "12"}, "10-14": {"max_temp": "21", "min_temp": "10", "weather": "dy", "wind": "12"}, "10-15": {"max_temp": "17", "min_temp": "11", "weather": "xq", "wind": "12"}, "10-16": {"max_temp": "17", "min_temp": "7", "weather": "dyq", "wind": "12"}, "10-17": {"max_temp": "17", "min_temp": "5", "weather": "q", "wind": "12"}, "10-18": {"max_temp": "18", "min_temp": "5", "weather": "q", "wind": "12"}, "10-19": {"max_temp": "19", "min_temp": "7", "weather": "dy", "wind": "12"}, "10-20": {"max_temp": "18", "min_temp": "7", "weather": "dy", "wind": "12"}, "10-21": {"max_temp": "18", "min_temp": "7", "weather": "dy", "wind": "12"}, "10-22": {"max_temp": "19", "min_temp": "5", "weather": "dyq", "wind": "12"}, "10-23": {"max_temp": "19", "min_temp": "4", "weather": "q", "wind": "34"}, "10-24": {"max_temp": "20", "min_temp": "6", "weather": "qdy", "wind": "12"}, "10-25": {"max_temp": "15", "min_temp": "8", "weather": "dy", "wind": "12"}, "10-26": {"max_temp": "14", "min_temp": "3", "weather": "q", "wind": "45"}, "10-27": {"max_temp": "17", "min_temp": "5", "weather": "dy", "wind": "12"}, "10-28": {"max_temp": "17", "min_temp": "4", "weather": "dyq", "wind": "45"}, "10-29": {"max_temp": "15", "min_temp": "3", "weather": "q", "wind": "34"}, "10-30": {"max_temp": "16", "min_temp": "1", "weather": "q", "wind": "12"}, "10-31": {"max_temp": "17", "min_temp": "3", "weather": "q", "wind": "12"}, "11-01": {"max_temp": "17", "min_temp": "3", "weather": "q", "wind": "12"}, "11-02": {"max_temp": "18", "min_temp": "4", "weather": "q", "wind": "12"}, "11-03": {"max_temp": "16", "min_temp": "6", "weather": "dy", "wind": "12"}, "11-04": {"max_temp": "10", "min_temp": "2", "weather": "xydy", "wind": "34"}, "11-05": {"max_temp": "10", "min_temp": "2", "weather": "dy", "wind": "12"}, "11-06": {"max_temp": "12", "min_temp": "0", "weather": "dy", "wind": "12"}, "11-07": {"max_temp": "13", "min_temp": "3", "weather": "dy", "wind": "12"}, "11-08": {"max_temp": "14", "min_temp": "2", "weather": "dy", "wind": "12"}, "11-09": {"max_temp": "15", "min_temp": "1", "weather": "qdy", "wind": "34"}, "11-10": {"max_temp": "11", "min_temp": "0", "weather": "dy", "wind": "12"}, "11-11": {"max_temp": "13", "min_temp": "1", "weather": "dyq", "wind": "12"}, "11-12": {"max_temp": "14", "min_temp": "2", "weather": "q", "wind": "12"}, "11-13": {"max_temp": "13", "min_temp": "5", "weather": "dy", "wind": "12"}, "11-14": {"max_temp": "13", "min_temp": "5", "weather": "dy", "wind": "12"}, "11-15": {"max_temp": "8", "min_temp": "1", "weather": "xydy", "wind": "34"}, "11-16": {"max_temp": "8", "min_temp": "-1", "weather": "q", "wind": "12"}, "11-17": {"max_temp": "9", "min_temp": "-2", "weather": "dyq", "wind": "12"}, "11-18": {"max_temp": "11", "min_temp": "-3", "weather": "q", "wind": "34"}, "11-19": {"max_temp": "10", "min_temp": "-2", "weather": "qdy", "wind": "12"}, "11-20": {"max_temp": "9", "min_temp": "-1", "weather": "dy", "wind": "12"}, "11-21": {"max_temp": "9", "min_temp": "-3", "weather": "q", "wind": "2"}, "11-22": {"max_temp": "8", "min_temp": "-3", "weather": "qdy", "wind": "1"}, "11-23": {"max_temp": "7", "min_temp": "0", "weather": "dy", "wind": "2"}, "11-24": {"max_temp": "9", "min_temp": "-3", "weather": "qdy", "wind": "2"}, "11-25": {"max_temp": "10", "min_temp": "-3", "weather": "q", "wind": "1"}, "11-26": {"max_temp": "10", "min_temp": "0", "weather": "dy", "wind": "1"}, "11-27": {"max_temp": "9", "min_temp": "-3", "weather": "qdy", "wind": "2"}, "11-28": {"max_temp": "8", "min_temp": "-3", "weather": "q", "wind": "1"}, "11-29": {"max_temp": "7", "min_temp": "-4", "weather": "q", "wind": "1"}, "11-30": {"max_temp": "8", "min_temp": "-3", "weather": "q", "wind": "1"}, "12-01": {"max_temp": "7", "min_temp": "0", "weather": "dy", "wind": "1"}, "12-02": {"max_temp": "9", "min_temp": "2", "weather": "dy", "wind": "1"}, "12-03": {"max_temp": "8", "min_temp": "-3", "weather": "dyq", "wind": "3"}, "12-04": {"max_temp": "4", "min_temp": "-6", "weather": "qdy", "wind": "2"}, "12-05": {"max_temp": "1", "min_temp": "-4", "weather": "dy", "wind": "1"}, "12-06": {"max_temp": "-2", "min_temp": "-9", "weather": "q", "wind": "3"}, "12-07": {"max_temp": "-4", "min_temp": "-10", "weather": "q", "wind": "3"}, "12-08": {"max_temp": "-2", "min_temp": "-10", "weather": "qdy", "wind": "2"}, "12-09": {"max_temp": "-1", "min_temp": "-10", "weather": "dyq", "wind": "1"}}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册