未验证 提交 a8a58316 编写于 作者: Z zhang wenhui 提交者: GitHub

Update paddlerec 1.8 (#4622)

* update api 1.8

* update model 1.8
上级 57c0bf10
# Paddle_baseline_KDD2019
Paddle baseline for KDD2019 "Context-Aware Multi-Modal Transportation Recommendation"(https://dianshi.baidu.com/competition/29/question)
This repository is the demo codes for the KDD2019 "Context-Aware Multi-Modal Transportation Recommendation" competition using PaddlePaddle. It is written by python and uses PaddlePaddle to solve the task. Note that this repository is on developing and welcome everyone to contribute. The current baseline solution codes can get 0.68 - 0.69 score of online submission. As an example, my submission based on these networks programmed by PaddlePaddle is 0.6898.
The reason of the publication of this baseline codes is to encourage us to use PaddlePaddle and build the most powerful recommendation model via PaddlePaddle.
The example codes are ran on Linux, python2.7, single machine with CPU . Note that distributed train options are not provided here, if you want to learn more about this, please check more modes examples on https://github.com/PaddlePaddle/models. About the speed of training, for one epoch, 1000 batch size, it would take about 8 mins to train the whole training instances generated from raw data using SGD optimizer (it would take relatively longer using Adam optimizer).
The configuration and process of all the networks are fundamental, a lot of optimizations can be done based on them to achieve better results e.g. better cost function, more powerful feature engineering, designed model validation, NN optimization tricks...
The code is rough and from my daily use. They will be trimmed these days...
## Install PaddlePaddle
please visit the official site of PaddlePaddle(http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html)
## preprocess feature
```python
python preprocess_dense.py # change for different feature strategy
python pre_test_dense.py
```
preprocess.py and preprocess_dense.py is the code for preprocessing the raw data. Two versions are provided to deal with all sparse features and sparse plus dense features. Correspondingly, pre_process_test.py and pre_test_dense.py are the codes to preproccess test raw data. The training instances are saved in json. It is very easy to add new features. In our demo, all features are generated from provided raw data except for weather feature, which is gengerated from open weather records.
Note that the feature generated in this step need to fit in the input of the model input. Make sure we use the right version. In demo codes, The sparse plus dense features are used for network_confv6.
## build the network
main network logic is in network_confv?.py. The networks are base on fm & deep related algorithms. I try several networks and public some of them. There may be some defects in the networks but all of them are functional.
## train the network
```python
python local_train.py
```
In local_train.py and map_reader.py, I use dataset API, so we need to download the corresponding .whl package or clone codes on develop branch of PaddlePaddle. The reason to use this is the speed of feeding data is much faster.
Note that the input format feed into the network is self-defined. make sure we build the same format between training and test.
## test results
```python
python generate_test.py
python build_submit.py
```
In generate_test.py and build_submit, for convenience, I use the whole train data to train the network and test the network with provided data without label
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
'--train_data_path',
type=str,
default='./data/raw/train.txt',
help="The path of training dataset")
parser.add_argument(
'--test_data_path',
type=str,
default='./data/raw/valid.txt',
help="The path of testing dataset")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
parser.add_argument(
'--embedding_size',
type=int,
default=16,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--num_passes',
type=int,
default=10,
help="The number of passes to train (default: 10)")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help='The path for model to store (default: models)')
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help='sparse feature hashing space for index processing')
parser.add_argument(
'--is_local',
type=int,
default=1,
help='Local train or distributed train (default: 1)')
parser.add_argument(
'--cloud_train',
type=int,
default=0,
help='Local train or distributed train on paddlecloud (default: 0)')
parser.add_argument(
'--async_mode',
action='store_true',
default=False,
help='Whether start pserver in async mode to support ASGD')
parser.add_argument(
'--no_split_var',
action='store_true',
default=False,
help='Whether split variables into blocks when update_method is pserver')
parser.add_argument(
'--role',
type=str,
default='pserver', # trainer or pserver
help='The path for model to store (default: models)')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The path for model to store (default: 127.0.0.1:6000)')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='The path for model to store (default: models)')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
return parser.parse_args()
import json
import csv
import io
def build():
submit_map = {}
with io.open('./submit/submit.csv', 'wb') as csv_file:
writer = csv.writer(csv_file, delimiter=',')
writer.writerow(['sid', 'recommend_mode'])
with open('./out/normed_test_session.txt', 'r') as f1:
with open('./testres/res8', 'r') as f2:
cur_session =''
for x, y in zip(f1.readlines(), f2.readlines()):
m1 = json.loads(x)
session_id = m1["session_id"]
if cur_session == '':
cur_session = session_id
transport_mode = m1["plan"]["transport_mode"]
if cur_session != session_id:
writer.writerow([str(cur_session), str(submit_map[cur_session]["transport_mode"])])
cur_session = session_id
if session_id not in submit_map:
submit_map[session_id] = {}
submit_map[session_id]["transport_mode"] = transport_mode
submit_map[session_id]["probability"] = y
#if int(submit_map[session_id]["transport_mode"]) == 0 and submit_map[session_id]["probability"] > 0.02:
#submit_map[session_id]["probability"] = 0.99
else:
if float(y) > float(submit_map[session_id]["probability"]):
submit_map[session_id]["transport_mode"] = transport_mode
submit_map[session_id]["probability"] = y
#if int(submit_map[session_id]["transport_mode"]) == 0 and submit_map[session_id]["probability"] > 0.02:
#submit_map[session_id]["transport_mode"] = 0
#submit_map[session_id]["probability"] = 0.99
writer.writerow([cur_session, submit_map[cur_session]["transport_mode"]])
if __name__ == "__main__":
build()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import numpy as np
# disable gpu training for this example
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
num_context_feature = 22
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument(
'--model_path',
type=str,
#required=True,
default='models',
help="The path of model parameters gz file")
parser.add_argument(
'--data_path',
type=str,
required=False,
help="The path of the dataset to infer")
parser.add_argument(
'--embedding_size',
type=int,
default=16,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help="The size for embedding layer (default:1000001)")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
return parser.parse_args()
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def data2tensor(data, place):
feed_dict = {}
dense = data[0]
sparse = data[1:-1]
y = data[-1]
#user_data = np.array([x[0] for x in data]).astype("float32")
#user_data = user_data.reshape([-1, 10])
#feed_dict["user_profile"] = user_data
dense_data = np.array([x[0] for x in data]).astype("float32")
dense_data = dense_data.reshape([-1, 3])
feed_dict["dense_feature"] = dense_data
for i in range(num_context_feature):
sparse_data = to_lodtensor([x[1 + i] for x in data], place)
feed_dict["context" + str(i)] = sparse_data
context_fm = to_lodtensor(
np.array([x[-2] for x in data]).astype("float32"), place)
feed_dict["context_fm"] = context_fm
y_data = np.array([x[-1] for x in data]).astype("int64")
y_data = y_data.reshape([-1, 1])
feed_dict["label"] = y_data
return feed_dict
def test():
args = parse_args()
place = fluid.CPUPlace()
test_scope = fluid.core.Scope()
# filelist = ["%s/%s" % (args.data_path, x) for x in os.listdir(args.data_path)]
from map_reader import MapDataset
map_dataset = MapDataset()
map_dataset.setup(args.sparse_feature_dim)
exe = fluid.Executor(place)
whole_filelist = ["./out/normed_test_session.txt"]
test_files = whole_filelist[int(0.0 * len(whole_filelist)):int(1.0 * len(
whole_filelist))]
epochs = 1
for i in range(epochs):
cur_model_path = os.path.join(args.model_path,
"epoch" + str(1) + ".model")
with open("./testres/res" + str(i), 'w') as r:
with fluid.scope_guard(test_scope):
[inference_program, feed_target_names, fetch_targets] = \
fluid.io.load_inference_model(cur_model_path, exe)
test_reader = map_dataset.test_reader(test_files, 1000, 100000)
k = 0
for batch_id, data in enumerate(test_reader()):
print(len(data[0]))
feed_dict = data2tensor(data, place)
loss_val, auc_val, accuracy, predict, _ = exe.run(
inference_program,
feed=feed_dict,
fetch_list=fetch_targets,
return_numpy=False)
x = np.array(predict)
for j in range(x.shape[0]):
r.write(str(x[j][1]))
r.write("\n")
if __name__ == '__main__':
test()
import argparse
import logging
import numpy as np
# disable gpu training for this example
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import paddle
import paddle.fluid as fluid
import map_reader
from network_conf import ctr_deepfm_dataset
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
parser.add_argument(
'--model_path',
type=str,
#required=True,
default='models',
help="The path of model parameters gz file")
parser.add_argument(
'--data_path',
type=str,
required=False,
help="The path of the dataset to infer")
parser.add_argument(
'--embedding_size',
type=int,
default=16,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--sparse_feature_dim',
type=int,
default=1000001,
help="The size for embedding layer (default:1000001)")
parser.add_argument(
'--batch_size',
type=int,
default=1000,
help="The size of mini-batch (default:1000)")
return parser.parse_args()
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def data2tensor(data, place):
feed_dict = {}
test_dict = {}
dense = data[0]
sparse = data[1:-1]
y = data[-1]
dense_data = np.array([x[0] for x in data]).astype("float32")
dense_data = dense_data.reshape([-1, 65])
feed_dict["user_profile"] = dense_data
for i in range(10):
sparse_data = to_lodtensor([x[1 + i] for x in data], place)
feed_dict["context" + str(i)] = sparse_data
y_data = np.array([x[-1] for x in data]).astype("int64")
y_data = y_data.reshape([-1, 1])
feed_dict["label"] = y_data
test_dict["test"] = [1]
return feed_dict, test_dict
def infer():
args = parse_args()
place = fluid.CPUPlace()
inference_scope = fluid.core.Scope()
filelist = [
"%s/%s" % (args.data_path, x) for x in os.listdir(args.data_path)
]
from map_reader import MapDataset
map_dataset = MapDataset()
map_dataset.setup(args.sparse_feature_dim)
exe = fluid.Executor(place)
whole_filelist = [
"raw_data/part-%d" % x for x in range(len(os.listdir("raw_data")))
]
#whole_filelist = ["./out/normed_train09", "./out/normed_train10", "./out/normed_train11"]
test_files = whole_filelist[int(0.0 * len(whole_filelist)):int(1.0 * len(
whole_filelist))]
# file_groups = [whole_filelist[i:i+train_thread_num] for i in range(0, len(whole_filelist), train_thread_num)]
def set_zero(var_name):
param = inference_scope.var(var_name).get_tensor()
param_array = np.zeros(param._get_dims()).astype("int64")
param.set(param_array, place)
epochs = 2
for i in range(epochs):
cur_model_path = os.path.join(args.model_path,
"epoch" + str(i + 1) + ".model")
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names, fetch_targets] = \
fluid.io.load_inference_model(cur_model_path, exe)
auc_states_names = ['_generated_var_2', '_generated_var_3']
for name in auc_states_names:
set_zero(name)
test_reader = map_dataset.infer_reader(test_files, 1000, 100000)
for batch_id, data in enumerate(test_reader()):
loss_val, auc_val, accuracy, predict, label = exe.run(
inference_program,
feed=data2tensor(data, place),
fetch_list=fetch_targets,
return_numpy=False)
#print(np.array(predict))
#x = np.array(predict)
#print(.shape)x
#print("train_pass_%d, test_pass_%d\t%f\t" % (i - 1, i, auc_val))
if __name__ == '__main__':
infer()
from __future__ import print_function
from args import parse_args
import os
import paddle.fluid as fluid
import sys
from network_confv6 import ctr_deepfm_dataset
NUM_CONTEXT_FEATURE = 22
DIM_USER_PROFILE = 10
DIM_DENSE_FEATURE = 3
PYTHON_PATH = "/home/yaoxuefeng/whls/paddle_release_home/python/bin/python" # this is mine change yours
def train():
args = parse_args()
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
#set the input format for our model. Note that you need to carefully modify them when you define a new network
#user_profile = fluid.layers.data(
#name="user_profile", shape=[DIM_USER_PROFILE], dtype='int64', lod_level=1)
dense_feature = fluid.layers.data(
name="dense_feature", shape=[DIM_DENSE_FEATURE], dtype='float32')
context_feature = [
fluid.layers.data(
name="context" + str(i), shape=[1], lod_level=1, dtype="int64")
for i in range(0, NUM_CONTEXT_FEATURE)
]
context_feature_fm = fluid.layers.data(
name="context_fm", shape=[1], dtype='int64', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
print("ready to network")
#self define network
loss, auc_var, batch_auc_var, accuracy, predict = ctr_deepfm_dataset(
dense_feature, context_feature, context_feature_fm, label,
args.embedding_size, args.sparse_feature_dim)
print("ready to optimize")
optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
optimizer.minimize(loss)
#single machine CPU training. more options on trainig please visit PaddlePaddle site
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
#use dataset api for much faster speed
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var([dense_feature] + context_feature +
[context_feature_fm] + [label])
#self define how to process generated training insatnces in map_reader.py
pipe_command = PYTHON_PATH + " map_reader.py %d" % args.sparse_feature_dim
dataset.set_pipe_command(pipe_command)
dataset.set_batch_size(args.batch_size)
thread_num = 1
dataset.set_thread(thread_num)
#self define how to split training files for example:"split -a 2 -d -l 200000 normed_train.txt normed_train"
whole_filelist = [
"./out/normed_train%d" % x for x in range(len(os.listdir("out")))
]
whole_filelist = [
"./out/normed_train00", "./out/normed_train01", "./out/normed_train02",
"./out/normed_train03", "./out/normed_train04", "./out/normed_train05",
"./out/normed_train06", "./out/normed_train07", "./out/normed_train08",
"./out/normed_train09", "./out/normed_train10", "./out/normed_train11"
]
print("ready to epochs")
epochs = 10
for i in range(epochs):
print("start %dth epoch" % i)
dataset.set_filelist(whole_filelist[:int(len(whole_filelist))])
#print the informations you want by setting fetch_list and fetch_info
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=dataset,
fetch_list=[auc_var, accuracy, predict, label],
fetch_info=["auc", "accuracy", "predict", "label"],
debug=False)
model_dir = os.path.join(args.model_output_dir,
'/epoch' + str(i + 1) + ".model")
sys.stderr.write("epoch%d finished" % (i + 1))
#save model
fluid.io.save_inference_model(
model_dir,
[dense_feature.name] + [x.name for x in context_feature] +
[context_feature_fm.name] + [label.name],
[loss, auc_var, accuracy, predict, label], exe)
if __name__ == '__main__':
train()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
import paddle.fluid.incubate.data_generator as dg
class MapDataset(dg.MultiSlotDataGenerator):
def setup(self, sparse_feature_dim):
self.profile_length = 65
self.dense_length = 3
#feature names
self.dense_feature_list = ["distance", "price", "eta"]
self.pid_list = ["pid"]
self.query_feature_list = ["weekday", "hour", "o1", "o2", "d1", "d2"]
self.plan_feature_list = ["transport_mode"]
self.rank_feature_list = ["plan_rank", "whole_rank", "price_rank", "eta_rank", "distance_rank"]
self.rank_whole_pic_list = ["mode_rank1", "mode_rank2", "mode_rank3", "mode_rank4",
"mode_rank5"]
self.weather_feature_list = ["max_temp", "min_temp", "wea", "wind"]
self.hash_dim = 1000001
self.train_idx_ = 2000000
#carefully set if you change the features
self.categorical_range_ = range(0, 22)
#process one instance
def _process_line(self, line):
instance = json.loads(line)
"""
profile = instance["profile"]
len_profile = len(profile)
if len_profile >= 10:
user_profile_feature = profile[0:10]
else:
profile.extend([0]*(10-len_profile))
user_profile_feature = profile
if len(profile) > 1 or (len(profile) == 1 and profile[0] != 0):
for p in profile:
if p >= 1 and p <= 65:
user_profile_feature[p - 1] = 1
"""
context_feature = []
context_feature_fm = []
dense_feature = [0] * self.dense_length
plan = instance["plan"]
for i, val in enumerate(self.dense_feature_list):
dense_feature[i] = plan[val]
if (instance["pid"] == ""):
instance["pid"] = 0
query = instance["query"]
weather_dic = instance["weather"]
for fea in self.pid_list:
context_feature.append([hash(fea + str(instance[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(instance[fea])) % self.hash_dim)
for fea in self.query_feature_list:
context_feature.append([hash(fea + str(query[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(query[fea])) % self.hash_dim)
for fea in self.plan_feature_list:
context_feature.append([hash(fea + str(plan[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(plan[fea])) % self.hash_dim)
for fea in self.rank_feature_list:
context_feature.append([hash(fea + str(instance[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(instance[fea])) % self.hash_dim)
for fea in self.rank_whole_pic_list:
context_feature.append([hash(fea + str(instance[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(instance[fea])) % self.hash_dim)
for fea in self.weather_feature_list:
context_feature.append([hash(fea + str(weather_dic[fea])) % self.hash_dim])
context_feature_fm.append(hash(fea + str(weather_dic[fea])) % self.hash_dim)
label = [int(instance["label"])]
return dense_feature, context_feature, context_feature_fm, label
def infer_reader(self, filelist, batch, buf_size):
print(filelist)
def local_iter():
for fname in filelist:
with open(fname.strip(), "r") as fin:
for line in fin:
dense_feature, sparse_feature, sparse_feature_fm, label = self._process_line(line)
yield [dense_feature] + sparse_feature + [sparse_feature_fm] + [label]
import paddle
batch_iter = paddle.batch(
paddle.reader.shuffle(
local_iter, buf_size=buf_size),
batch_size=batch)
return batch_iter
#generat inputs for testing
def test_reader(self, filelist, batch, buf_size):
print(filelist)
def local_iter():
for fname in filelist:
with open(fname.strip(), "r") as fin:
for line in fin:
dense_feature, sparse_feature, sparse_feature_fm, label = self._process_line(line)
yield [dense_feature] + sparse_feature + [sparse_feature_fm] + [label]
import paddle
batch_iter = paddle.batch(
paddle.reader.buffered(
local_iter, size=buf_size),
batch_size=batch)
return batch_iter
#generate inputs for trainig
def generate_sample(self, line):
def data_iter():
dense_feature, sparse_feature, sparse_feature_fm, label = self._process_line(line)
#feature_name = ["user_profile"]
feature_name = []
feature_name.append("dense_feature")
for idx in self.categorical_range_:
feature_name.append("context" + str(idx))
feature_name.append("context_fm")
feature_name.append("label")
yield zip(feature_name, [dense_feature] + sparse_feature + [sparse_feature_fm] + [label])
return data_iter
if __name__ == "__main__":
map_dataset = MapDataset()
map_dataset.setup(int(sys.argv[1]))
map_dataset.run_from_stdin()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
dense_feature_dim = 3
def ctr_deepfm_dataset(dense_feature, context_feature, context_feature_fm, label,
embedding_size, sparse_feature_dim):
def dense_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_order = fluid.layers.fc(input=input, size=1)
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32', attr=fm_param_attr)
input_mul_factor = fluid.layers.matmul(input, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(input)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_order = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return first_order, second_order
dense_fm_param_attr = fluid.param_attr.ParamAttr(name="DenseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(dense_feature_dim)))
dense_fm_first, dense_fm_second = dense_fm_layer(
dense_feature, dense_feature_dim, 16, dense_fm_param_attr)
def sparse_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, 1], is_sparse=True)
first_order = fluid.layers.sequence_pool(input=first_embeddings, pool_type='sum')
nonzero_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, factor_size],
param_attr=fm_param_attr, is_sparse=True)
summed_features_emb = fluid.layers.sequence_pool(input=nonzero_embeddings, pool_type='sum')
summed_features_emb_square = fluid.layers.square(summed_features_emb)
squared_features_emb = fluid.layers.square(nonzero_embeddings)
squared_sum_features_emb = fluid.layers.sequence_pool(
input=squared_features_emb, pool_type='sum')
second_order = 0.5 * (summed_features_emb_square - squared_sum_features_emb)
return first_order, second_order
sparse_fm_param_attr = fluid.param_attr.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(sparse_feature_dim)))
#data = fluid.layers.data(name='ids', shape=[1], dtype='float32')
sparse_fm_first, sparse_fm_second = sparse_fm_layer(
context_feature_fm, sparse_feature_dim, 16, sparse_fm_param_attr)
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
concated_ori = fluid.layers.concat(sparse_embed_seq + [dense_feature], axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
predict = fluid.layers.fc(input=[deep, sparse_fm_first, sparse_fm_second, dense_fm_first, dense_fm_second], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [400, 400, 400]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
#w_res = fluid.layers.create_parameter(shape=[353, 16], dtype='float32', name="w_res")
#high_path = fluid.layers.matmul(concated, w_res)
#return fluid.layers.elementwise_add(high_path, fc_layers_input[-1])
return fc_layers_input[-1]
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
num_context = 25
dim_fm_vector = 16
dim_concated = user_profile_dim + dim_fm_vector * (num_context)
def ctr_deepfm_dataset(user_profile, context_feature, label,
embedding_size, sparse_feature_dim):
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
w = fluid.layers.create_parameter(
shape=[65, 65], dtype='float32',
name="w_fm")
user_profile_emb = fluid.layers.matmul(user_profile, w)
concated_ori = fluid.layers.concat(sparse_embed_seq + [user_profile_emb], axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
linear_term, second_term = fm(concated, dim_concated, 8) #depend on the number of context feature
predict = fluid.layers.fc(input=[deep, linear_term, second_term], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [128, 64, 32, 16]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
return fc_layers_input[-1]
def fm(concated, emb_dict_size, factor_size, lr_x=0.0001):
linear_term = fluid.layers.fc(input=concated, size=8, act=None, param_attr=fluid.ParamAttr(learning_rate=lr_x))
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32')
input_mul_factor = fluid.layers.matmul(concated, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(concated)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_term = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return linear_term, second_term
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
slot_1 = [0, 1, 2, 3, 4, 5]
slot_2 = [6]
slot_3 = [7, 8, 9, 10, 11]
slot_4 = [12, 13, 14, 15, 16]
slot_5 = [17, 18, 19, 20]
num_context = 25
num_slots_pair = 5
dim_fm_vector = 16
dim_concated = user_profile_dim + dim_fm_vector * (num_context + num_slots_pair)
def ctr_deepfm_dataset(user_profile, dense_feature, context_feature, label,
embedding_size, sparse_feature_dim):
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
w = fluid.layers.create_parameter(
shape=[65, 65], dtype='float32',
name="w_fm")
user_emb_list = []
user_profile_emb = fluid.layers.matmul(user_profile, w)
user_emb_list.append(user_profile_emb)
user_emb_list.append(dense_feature)
w1 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_1")
w2 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_2")
w3 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_3")
w4 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_4")
w5 = fluid.layers.create_parameter(shape=[65, dim_fm_vector], dtype='float32', name="w_5")
user_profile_emb_1 = fluid.layers.matmul(user_profile, w1)
user_profile_emb_2 = fluid.layers.matmul(user_profile, w2)
user_profile_emb_3 = fluid.layers.matmul(user_profile, w3)
user_profile_emb_4 = fluid.layers.matmul(user_profile, w4)
user_profile_emb_5 = fluid.layers.matmul(user_profile, w5)
sparse_embed_seq_1 = embedding_layer(context_feature[slot_1[0]])
sparse_embed_seq_2 = embedding_layer(context_feature[slot_2[0]])
sparse_embed_seq_3 = embedding_layer(context_feature[slot_3[0]])
sparse_embed_seq_4 = embedding_layer(context_feature[slot_4[0]])
sparse_embed_seq_5 = embedding_layer(context_feature[slot_5[0]])
for i in slot_1[1:-1]:
sparse_embed_seq_1 = fluid.layers.elementwise_add(sparse_embed_seq_1, embedding_layer(context_feature[i]))
for i in slot_2[1:-1]:
sparse_embed_seq_2 = fluid.layers.elementwise_add(sparse_embed_seq_2, embedding_layer(context_feature[i]))
for i in slot_3[1:-1]:
sparse_embed_seq_3 = fluid.layers.elementwise_add(sparse_embed_seq_3, embedding_layer(context_feature[i]))
for i in slot_4[1:-1]:
sparse_embed_seq_4 = fluid.layers.elementwise_add(sparse_embed_seq_4, embedding_layer(context_feature[i]))
for i in slot_5[1:-1]:
sparse_embed_seq_5 = fluid.layers.elementwise_add(sparse_embed_seq_5, embedding_layer(context_feature[i]))
ele_product_1 = fluid.layers.elementwise_mul(user_profile_emb_1, sparse_embed_seq_1)
user_emb_list.append(ele_product_1)
ele_product_2 = fluid.layers.elementwise_mul(user_profile_emb_2, sparse_embed_seq_2)
user_emb_list.append(ele_product_2)
ele_product_3 = fluid.layers.elementwise_mul(user_profile_emb_3, sparse_embed_seq_3)
user_emb_list.append(ele_product_3)
ele_product_4 = fluid.layers.elementwise_mul(user_profile_emb_4, sparse_embed_seq_4)
user_emb_list.append(ele_product_4)
ele_product_5 = fluid.layers.elementwise_mul(user_profile_emb_5, sparse_embed_seq_5)
user_emb_list.append(ele_product_5)
ffm_1 = fluid.layers.reduce_sum(ele_product_1, dim=1, keep_dim=True)
ffm_2 = fluid.layers.reduce_sum(ele_product_2, dim=1, keep_dim=True)
ffm_3 = fluid.layers.reduce_sum(ele_product_3, dim=1, keep_dim=True)
ffm_4 = fluid.layers.reduce_sum(ele_product_4, dim=1, keep_dim=True)
ffm_5 = fluid.layers.reduce_sum(ele_product_5, dim=1, keep_dim=True)
concated_ori = fluid.layers.concat(sparse_embed_seq + user_emb_list, axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
linear_term, second_term = fm(concated, dim_concated, 8) #depend on the number of context feature
predict = fluid.layers.fc(input=[deep, linear_term, second_term, ffm_1, ffm_2, ffm_3, ffm_4, ffm_5], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [256, 128, 64, 32, 16]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
w_res = fluid.layers.create_parameter(shape=[dim_concated, 16], dtype='float32', name="w_res")
high_path = fluid.layers.matmul(concated, w_res)
return fluid.layers.elementwise_add(high_path, fc_layers_input[-1])
#return fc_layers_input[-1]
def fm(concated, emb_dict_size, factor_size, lr_x=0.0001):
linear_term = fluid.layers.fc(input=concated, size=8, act=None, param_attr=fluid.ParamAttr(learning_rate=lr_x))
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32')
input_mul_factor = fluid.layers.matmul(concated, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(concated)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_term = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return linear_term, second_term
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import math
user_profile_dim = 65
dense_feature_dim = 3
def ctr_deepfm_dataset(dense_feature, context_feature, context_feature_fm, label,
embedding_size, sparse_feature_dim):
def dense_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_order = fluid.layers.fc(input=input, size=1)
emb_table = fluid.layers.create_parameter(shape=[emb_dict_size, factor_size],
dtype='float32', attr=fm_param_attr)
input_mul_factor = fluid.layers.matmul(input, emb_table)
input_mul_factor_square = fluid.layers.square(input_mul_factor)
input_square = fluid.layers.square(input)
factor_square = fluid.layers.square(emb_table)
input_square_mul_factor_square = fluid.layers.matmul(input_square, factor_square)
second_order = 0.5 * (input_mul_factor_square - input_square_mul_factor_square)
return first_order, second_order
dense_fm_param_attr = fluid.param_attr.ParamAttr(name="DenseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(dense_feature_dim)))
dense_fm_first, dense_fm_second = dense_fm_layer(
dense_feature, dense_feature_dim, 16, dense_fm_param_attr)
def sparse_fm_layer(input, emb_dict_size, factor_size, fm_param_attr):
first_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, 1], is_sparse=True)
first_order = fluid.layers.sequence_pool(input=first_embeddings, pool_type='sum')
nonzero_embeddings = fluid.layers.embedding(
input=input, dtype='float32', size=[emb_dict_size, factor_size],
param_attr=fm_param_attr, is_sparse=True)
summed_features_emb = fluid.layers.sequence_pool(input=nonzero_embeddings, pool_type='sum')
summed_features_emb_square = fluid.layers.square(summed_features_emb)
squared_features_emb = fluid.layers.square(nonzero_embeddings)
squared_sum_features_emb = fluid.layers.sequence_pool(
input=squared_features_emb, pool_type='sum')
second_order = 0.5 * (summed_features_emb_square - squared_sum_features_emb)
return first_order, second_order
sparse_fm_param_attr = fluid.param_attr.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(sparse_feature_dim)))
#data = fluid.layers.data(name='ids', shape=[1], dtype='float32')
sparse_fm_first, sparse_fm_second = sparse_fm_layer(
context_feature_fm, sparse_feature_dim, 16, sparse_fm_param_attr)
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
sparse_embed_seq = list(map(embedding_layer, context_feature))
concated_ori = fluid.layers.concat(sparse_embed_seq + [dense_feature], axis=1)
concated = fluid.layers.batch_norm(input=concated_ori, name="bn", epsilon=1e-4)
deep = deep_net(concated)
predict = fluid.layers.fc(input=[deep, sparse_fm_first, sparse_fm_second, dense_fm_first, dense_fm_second], size=2, act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(deep.shape[1])), learning_rate=0.01))
#similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(predict, min=-15.0, max=15.0), name="similarity_norm")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=label)
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=label, num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, accuracy, predict
def deep_net(concated, lr_x=0.0001):
fc_layers_input = [concated]
fc_layers_size = [400, 400, 400]
fc_layers_act = ["relu"] * (len(fc_layers_size))
for i in range(len(fc_layers_size)):
fc = fluid.layers.fc(
input=fc_layers_input[-1],
size=fc_layers_size[i],
act=fc_layers_act[i],
param_attr=fluid.ParamAttr(learning_rate=lr_x * 0.5))
fc_layers_input.append(fc)
#w_res = fluid.layers.create_parameter(shape=[353, 16], dtype='float32', name="w_res")
#high_path = fluid.layers.matmul(concated, w_res)
#return fluid.layers.elementwise_add(high_path, fc_layers_input[-1])
return fc_layers_input[-1]
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys, time, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TEST_QUERIES_PATH = "./data_set_phase1/test_queries.csv"
TEST_PLANS_PATH = "./data_set_phase1/test_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
OUT_NORM_TEST_PATH = "./out/normed_test_session.txt"
OUT_RAW_TEST_PATH = "./out/test_session.txt"
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
SCALE_OD = 0.02
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 40000.0
SCALE_DIS = 500
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
SCALE_PRICE = 100
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
SCALE_ETA = 120
def build_norm_feature():
with open(OUT_NORM_TEST_PATH, 'w') as nf:
with open(OUT_RAW_TEST_PATH, 'r') as f:
for line in f:
cur_map = json.loads(line)
if cur_map["plan"]["distance"] > THRESHOLD_DIS:
cur_map["plan"]["distance"] = int(THRESHOLD_DIS)
elif cur_map["plan"]["distance"] > 0:
cur_map["plan"]["distance"] = int(cur_map["plan"]["distance"] / SCALE_DIS)
if cur_map["plan"]["price"] and cur_map["plan"]["price"] > THRESHOLD_PRICE:
cur_map["plan"]["price"] = int(THRESHOLD_PRICE)
elif not cur_map["plan"]["price"] or cur_map["plan"]["price"] < 0:
cur_map["plan"]["price"] = 0
else:
cur_map["plan"]["price"] = int(cur_map["plan"]["price"] / SCALE_PRICE)
if cur_map["plan"]["eta"] > THRESHOLD_ETA:
cur_map["plan"]["eta"] = int(THRESHOLD_ETA)
elif cur_map["plan"]["eta"] > 0:
cur_map["plan"]["eta"] = int(cur_map["plan"]["eta"] / SCALE_ETA)
# o1
if cur_map["query"]["o1"] > O1_MAX:
cur_map["query"]["o1"] = int((O1_MAX - O1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o1"] < O1_MIN:
cur_map["query"]["o1"] = 0
else:
cur_map["query"]["o1"] = int((cur_map["query"]["o1"] - O1_MIN) / 0.02)
# o2
if cur_map["query"]["o2"] > O2_MAX:
cur_map["query"]["o2"] = int((O2_MAX - O2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o2"] < O2_MIN:
cur_map["query"]["o2"] = 0
else:
cur_map["query"]["o2"] = int((cur_map["query"]["o2"] - O2_MIN) / 0.02)
# d1
if cur_map["query"]["d1"] > D1_MAX:
cur_map["query"]["d1"] = int((D1_MAX - D1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d1"] < D1_MIN:
cur_map["query"]["d1"] = 0
else:
cur_map["query"]["d1"] = int((cur_map["query"]["d1"] - D1_MIN) / SCALE_OD)
# d2
if cur_map["query"]["d2"] > D2_MAX:
cur_map["query"]["d2"] = int((D2_MAX - D2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d2"] < D2_MIN:
cur_map["query"]["d2"] = 0
else:
cur_map["query"]["d2"] = int((cur_map["query"]["d2"] - D2_MIN) / SCALE_OD)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert some of the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FM algorithm.
Note that the user profile is already provided by one-hot encoded form, we convert it back to the
ids for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not.
:return:
"""
train_data_dict = {}
with open("./weather.json", 'r') as f:
weather_dict = json.load(f)
with open(TEST_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
date_key = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%m-%d")
train_data_dict[train_index]["weather"] = {}
train_data_dict[train_index]["weather"].update({"max_temp": weather_dict[date_key]["max_temp"]})
train_data_dict[train_index]["weather"].update({"min_temp": weather_dict[date_key]["min_temp"]})
train_data_dict[train_index]["weather"].update({"wea": weather_dict[date_key]["weather"]})
train_data_dict[train_index]["weather"].update({"wind": weather_dict[date_key]["wind"]})
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TEST_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir("./out/"):
os.mkdir("./out/")
with open(os.path.join("./out/", "test_session.txt"), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
cur_map["session_id"] = session_id
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
del cur_map["pid"]
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["mode_rank" + str(whole_rank)] = plan["transport_mode"]
if whole_rank < 5:
for r in range(whole_rank + 1, 6):
cur_map["mode_rank" + str(r)] = -1
cur_map["whole_rank"] = whole_rank
flag_click = False
rank = 1
price_list = []
eta_list = []
distance_list = []
for plan in plan_list:
if not plan["price"]:
price_list.append(0)
else:
price_list.append(int(plan["price"]))
eta_list.append(int(plan["eta"]))
distance_list.append(int(plan["distance"]))
price_list.sort(reverse=False)
eta_list.sort(reverse=False)
distance_list.sort(reverse=False)
for plan in plan_list:
if plan["price"] and int(plan["price"]) == price_list[0]:
cur_map["mode_min_price"] = plan["transport_mode"]
if plan["price"] and int(plan["price"]) == price_list[-1]:
cur_map["mode_max_price"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[0]:
cur_map["mode_min_eta"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[-1]:
cur_map["mode_max_eta"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[0]:
cur_map["mode_min_distance"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[-1]:
cur_map["mode_max_distance"] = plan["transport_mode"]
if "mode_min_price" not in cur_map:
cur_map["mode_min_price"] = -1
if "mode_max_price" not in cur_map:
cur_map["mode_max_price"] = -1
for plan in plan_list:
cur_price = int(plan["price"]) if plan["price"] else 0
cur_eta = int(plan["eta"])
cur_distance = int(plan["distance"])
cur_map["price_rank"] = price_list.index(cur_price) + 1
cur_map["eta_rank"] = eta_list.index(cur_eta) + 1
cur_map["distance_rank"] = distance_list.index(cur_distance) + 1
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
flag_click = True
# print("label is 1")
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["plan_rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["plan_rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys, time, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TRAIN_QUERIES_PATH = "./data_set_phase1/test_queries.csv"
TRAIN_PLANS_PATH = "./data_set_phase1/test_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 200000.0
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
def build_norm_feature():
with open("./out/normed_test_session.txt", 'w') as nf:
with open("./out/test_session.txt", 'r') as f:
for line in f:
cur_map = json.loads(line)
cur_map["plan"]["distance"] = (cur_map["plan"]["distance"] - DISTANCE_MIN) / (DISTANCE_MAX - DISTANCE_MIN)
if cur_map["plan"]["price"]:
cur_map["plan"]["price"] = (cur_map["plan"]["price"] - PRICE_MIN) / (PRICE_MAX - PRICE_MIN)
else:
cur_map["plan"]["price"] = 0.0
cur_map["plan"]["eta"] = (cur_map["plan"]["eta"] - ETA_MIN) / (ETA_MAX - ETA_MIN)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert all the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FFM algorithm.
Note that the user profile is already provided by one-hot encoded form, we convert it back to the
ids for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not.
:return:
"""
#args = parse_args()
train_data_dict = {}
with open("./weather.json", 'r') as f:
weather_dict = json.load(f)
with open(TRAIN_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
date_key = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%m-%d")
train_data_dict[train_index]["weather"] = {}
train_data_dict[train_index]["weather"].update({"max_temp": weather_dict[date_key]["max_temp"]})
train_data_dict[train_index]["weather"].update({"min_temp": weather_dict[date_key]["min_temp"]})
train_data_dict[train_index]["weather"].update({"wea": weather_dict[date_key]["weather"]})
train_data_dict[train_index]["weather"].update({"wind": weather_dict[date_key]["wind"]})
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TRAIN_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir("./out/"):
os.mkdir("./out/")
with open(os.path.join("./out/", "test_session.txt"), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
cur_map["session_id"] = session_id
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
# del cur_map["pid"]
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["mode_rank" + str(whole_rank)] = plan["transport_mode"]
if whole_rank < 5:
for r in range(whole_rank + 1, 6):
cur_map["mode_rank" + str(r)] = -1
cur_map["whole_rank"] = whole_rank
rank = 1
price_list = []
eta_list = []
distance_list = []
for plan in plan_list:
if not plan["price"]:
price_list.append(0)
else:
price_list.append(int(plan["price"]))
eta_list.append(int(plan["eta"]))
distance_list.append(int(plan["distance"]))
price_list.sort(reverse=False)
eta_list.sort(reverse=False)
distance_list.sort(reverse=False)
for plan in plan_list:
if plan["price"] and int(plan["price"]) == price_list[0]:
cur_map["mode_min_price"] = plan["transport_mode"]
if plan["price"] and int(plan["price"]) == price_list[-1]:
cur_map["mode_max_price"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[0]:
cur_map["mode_min_eta"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[-1]:
cur_map["mode_max_eta"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[0]:
cur_map["mode_min_distance"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[-1]:
cur_map["mode_max_distance"] = plan["transport_mode"]
if "mode_min_price" not in cur_map:
cur_map["mode_min_price"] = -1
if "mode_max_price" not in cur_map:
cur_map["mode_max_price"] = -1
for plan in plan_list:
cur_price = int(plan["price"]) if plan["price"] else 0
cur_eta = int(plan["eta"])
cur_distance = int(plan["distance"])
cur_map["price_rank"] = price_list.index(cur_price) + 1
cur_map["eta_rank"] = eta_list.index(cur_eta) + 1
cur_map["distance_rank"] = distance_list.index(cur_distance) + 1
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["plan_rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["plan_rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
\ No newline at end of file
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys, time, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TRAIN_QUERIES_PATH = "./data_set_phase1/train_queries.csv"
TRAIN_PLANS_PATH = "./data_set_phase1/train_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
OUT_NORM_TRAIN_PATH = "./out/normed_train.txt"
OUT_RAW_TRAIN_PATH = "./out/train.txt"
OUT_DIR = "./out"
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
SCALE_OD = 0.02
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 40000.0
SCALE_DIS = 500
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
SCALE_PRICE = 100
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
SCALE_ETA = 120
def build_norm_feature():
with open(OUT_NORM_TRAIN_PATH, 'w') as nf:
with open(OUT_RAW_TRAIN_PATH, 'r') as f:
for line in f:
cur_map = json.loads(line)
if cur_map["plan"]["distance"] > THRESHOLD_DIS:
cur_map["plan"]["distance"] = int(THRESHOLD_DIS)
elif cur_map["plan"]["distance"] > 0:
cur_map["plan"]["distance"] = int(cur_map["plan"]["distance"] / SCALE_DIS)
if cur_map["plan"]["price"] and cur_map["plan"]["price"] > THRESHOLD_PRICE:
cur_map["plan"]["price"] = int(THRESHOLD_PRICE)
elif not cur_map["plan"]["price"] or cur_map["plan"]["price"] < 0:
cur_map["plan"]["price"] = 0
else:
cur_map["plan"]["price"] = int(cur_map["plan"]["price"] / SCALE_PRICE)
if cur_map["plan"]["eta"] > THRESHOLD_ETA:
cur_map["plan"]["eta"] = int(THRESHOLD_ETA)
elif cur_map["plan"]["eta"] > 0:
cur_map["plan"]["eta"] = int(cur_map["plan"]["eta"] / SCALE_ETA)
# o1
if cur_map["query"]["o1"] > O1_MAX:
cur_map["query"]["o1"] = int((O1_MAX - O1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o1"] < O1_MIN:
cur_map["query"]["o1"] = 0
else:
cur_map["query"]["o1"] = int((cur_map["query"]["o1"] - O1_MIN) / 0.02)
# o2
if cur_map["query"]["o2"] > O2_MAX:
cur_map["query"]["o2"] = int((O2_MAX - O2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["o2"] < O2_MIN:
cur_map["query"]["o2"] = 0
else:
cur_map["query"]["o2"] = int((cur_map["query"]["o2"] - O2_MIN) / 0.02)
# d1
if cur_map["query"]["d1"] > D1_MAX:
cur_map["query"]["d1"] = int((D1_MAX - D1_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d1"] < D1_MIN:
cur_map["query"]["d1"] = 0
else:
cur_map["query"]["d1"] = int((cur_map["query"]["d1"] - D1_MIN) / SCALE_OD)
# d2
if cur_map["query"]["d2"] > D2_MAX:
cur_map["query"]["d2"] = int((D2_MAX - D2_MIN) / SCALE_OD + 1)
elif cur_map["query"]["d2"] < D2_MIN:
cur_map["query"]["d2"] = 0
else:
cur_map["query"]["d2"] = int((cur_map["query"]["d2"] - D2_MIN) / SCALE_OD)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert all the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FM algorithm.
Note that the user profile is already provided by one-hot encoded form, we treat it as embedded vector
for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not include non-click case.
:return:
"""
train_data_dict = {}
with open(TRAIN_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TRAIN_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir(OUT_DIR):
os.mkdir(OUT_DIR)
with open(os.path.join("./out/", "train.txt"), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
del cur_map["pid"]
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["whole_rank"] = whole_rank
flag_click = False
rank = 1
for plan in plan_list:
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
flag_click = True
# print("label is 1")
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
if not flag_click:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
else:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["rank"] = 0
cur_map["label"] = 0
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, random, csv, datetime, json
import pandas as pd
import numpy as np
import argparse
import logging
import time
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("preprocess")
logger.setLevel(logging.INFO)
TRAIN_QUERIES_PATH = "./data_set_phase1/train_queries.csv"
TRAIN_PLANS_PATH = "./data_set_phase1/train_plans.csv"
TRAIN_CLICK_PATH = "./data_set_phase1/train_clicks.csv"
PROFILES_PATH = "./data_set_phase1/profiles.csv"
OUT_DIR = "./out"
ORI_TRAIN_PATH = "train.txt"
NORM_TRAIN_PATH = "normed_train.txt"
#variable to control the ratio of positive and negative instance of transmode 0 which is original label of no click
THRESHOLD_LABEL = 0.5
O1_MIN = 115.47
O1_MAX = 117.29
O2_MIN = 39.46
O2_MAX = 40.97
D1_MIN = 115.44
D1_MAX = 117.37
D2_MIN = 39.46
D2_MAX = 40.96
DISTANCE_MIN = 1.0
DISTANCE_MAX = 225864.0
THRESHOLD_DIS = 200000.0
PRICE_MIN = 200.0
PRICE_MAX = 92300.0
THRESHOLD_PRICE = 20000
ETA_MIN = 1.0
ETA_MAX = 72992.0
THRESHOLD_ETA = 10800.0
def build_norm_feature():
with open(os.path.join(OUT_DIR, NORM_TRAIN_PATH), 'w') as nf:
with open(os.path.join(OUT_DIR, ORI_TRAIN_PATH), 'r') as f:
for line in f:
cur_map = json.loads(line)
cur_map["plan"]["distance"] = (cur_map["plan"]["distance"] - DISTANCE_MIN) / (DISTANCE_MAX - DISTANCE_MIN)
if cur_map["plan"]["price"]:
cur_map["plan"]["price"] = (cur_map["plan"]["price"] - PRICE_MIN) / (PRICE_MAX - PRICE_MIN)
else:
cur_map["plan"]["price"] = 0.0
cur_map["plan"]["eta"] = (cur_map["plan"]["eta"] - ETA_MIN) / (ETA_MAX - ETA_MIN)
cur_json_instance = json.dumps(cur_map)
nf.write(cur_json_instance + '\n')
def preprocess():
"""
Construct the train data indexed by session id and mode id jointly. Convert all the raw features (user profile,
od pair, req time, click time, eta, price, distance, transport mode) to one-hot ids used for
embedding. We split the one-hot features into two categories: user feature and context feature for
better understanding of FM algorithm.
Note that the user profile is already provided by one-hot encoded form, we treat it as embedded vector
for unity with the context feature and easily using of PaddlePaddle embedding layer. Given the
train clicks data, we label each train instance with 1 or 0 depend on if this instance is clicked or
not include non-click case. To Be Changed
:return:
"""
train_data_dict = {}
with open("./weather.json", 'r') as f:
weather_dict = json.load(f)
with open(TRAIN_QUERIES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
train_index_list = []
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "": continue
if line[1] == "":
train_index_list.append(line[0] + "_0")
else:
train_index_list.append(line[0] + "_" + line[1])
train_index = line[0]
train_data_dict[train_index] = {}
train_data_dict[train_index]["pid"] = line[1]
train_data_dict[train_index]["query"] = {}
train_data_dict[train_index]["weather"] = {}
reqweekday = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%w")
reqhour = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%H")
# weather related features, no big use, maybe more detailed weather information is better
date_key = datetime.datetime.strptime(line[2], '%Y-%m-%d %H:%M:%S').strftime("%m-%d")
train_data_dict[train_index]["weather"] = {}
train_data_dict[train_index]["weather"].update({"max_temp": weather_dict[date_key]["max_temp"]})
train_data_dict[train_index]["weather"].update({"min_temp": weather_dict[date_key]["min_temp"]})
train_data_dict[train_index]["weather"].update({"wea": weather_dict[date_key]["weather"]})
train_data_dict[train_index]["weather"].update({"wind": weather_dict[date_key]["wind"]})
train_data_dict[train_index]["query"].update({"weekday":reqweekday})
train_data_dict[train_index]["query"].update({"hour":reqhour})
o = line[3].split(',')
o_first = o[0]
o_second = o[1]
train_data_dict[train_index]["query"].update({"o1":float(o_first)})
train_data_dict[train_index]["query"].update({"o2":float(o_second)})
d = line[4].split(',')
d_first = d[0]
d_second = d[1]
train_data_dict[train_index]["query"].update({"d1":float(d_first)})
train_data_dict[train_index]["query"].update({"d2":float(d_second)})
plan_map = {}
plan_data = pd.read_csv(TRAIN_PLANS_PATH)
for index, row in plan_data.iterrows():
plans_str = row['plans']
plans_list = json.loads(plans_str)
session_id = str(row['sid'])
# train_data_dict[session_id]["plans"] = []
plan_map[session_id] = plans_list
profile_map = {}
with open(PROFILES_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
profile_map[line[0]] = [i for i in range(len(line)) if line[i] == "1.0"]
session_click_map = {}
with open(TRAIN_CLICK_PATH, 'r') as f:
csv_reader = csv.reader(f, delimiter=',')
for k, line in enumerate(csv_reader):
if k == 0: continue
if line[0] == "" or line[1] == "" or line[2] == "":
continue
session_click_map[line[0]] = line[2]
#return train_data_dict, profile_map, session_click_map, plan_map
generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map)
def generate_sparse_features(train_data_dict, profile_map, session_click_map, plan_map):
if not os.path.isdir(OUT_DIR):
os.mkdir(OUT_DIR)
with open(os.path.join(OUT_DIR, ORI_TRAIN_PATH), 'w') as f_train:
for session_id, plan_list in plan_map.items():
if session_id not in train_data_dict:
continue
cur_map = train_data_dict[session_id]
if cur_map["pid"] != "":
cur_map["profile"] = profile_map[cur_map["pid"]]
else:
cur_map["profile"] = [0]
#rank information related feature
whole_rank = 0
for plan in plan_list:
whole_rank += 1
cur_map["mode_rank" + str(whole_rank)] = plan["transport_mode"]
if whole_rank < 5:
for r in range(whole_rank + 1, 6):
cur_map["mode_rank" + str(r)] = -1
cur_map["whole_rank"] = whole_rank
flag_click = False
rank = 1
price_list = []
eta_list = []
distance_list = []
for plan in plan_list:
if not plan["price"]:
price_list.append(0)
else:
price_list.append(int(plan["price"]))
eta_list.append(int(plan["eta"]))
distance_list.append(int(plan["distance"]))
price_list.sort(reverse=False)
eta_list.sort(reverse=False)
distance_list.sort(reverse=False)
for plan in plan_list:
if plan["price"] and int(plan["price"]) == price_list[0]:
cur_map["mode_min_price"] = plan["transport_mode"]
if plan["price"] and int(plan["price"]) == price_list[-1]:
cur_map["mode_max_price"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[0]:
cur_map["mode_min_eta"] = plan["transport_mode"]
if int(plan["eta"]) == eta_list[-1]:
cur_map["mode_max_eta"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[0]:
cur_map["mode_min_distance"] = plan["transport_mode"]
if int(plan["distance"]) == distance_list[-1]:
cur_map["mode_max_distance"] = plan["transport_mode"]
if "mode_min_price" not in cur_map:
cur_map["mode_min_price"] = -1
if "mode_max_price" not in cur_map:
cur_map["mode_max_price"] = -1
for plan in plan_list:
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
flag_click = True
if flag_click:
for plan in plan_list:
cur_price = int(plan["price"]) if plan["price"] else 0
cur_eta = int(plan["eta"])
cur_distance = int(plan["distance"])
cur_map["price_rank"] = price_list.index(cur_price) + 1
cur_map["eta_rank"] = eta_list.index(cur_eta) + 1
cur_map["distance_rank"] = distance_list.index(cur_distance) + 1
if ("transport_mode" in plan) and (session_id in session_click_map) and (
int(plan["transport_mode"]) == int(session_click_map[session_id])):
cur_map["plan"] = plan
cur_map["label"] = 1
else:
cur_map["plan"] = plan
cur_map["label"] = 0
cur_map["plan_rank"] = rank
rank += 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
cur_map["plan"] = {}
#since we define a new ctr task from original task, we use a basic way to generate instances of transport mode 0.
#There should be a optimal strategy to generate instances of transport mode 0
if not flag_click:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["distance_rank"] = 0
cur_map["label"] = 1
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
else:
if random.random() < THRESHOLD_LABEL:
cur_map["plan"]["distance"] = -1
cur_map["plan"]["price"] = -1
cur_map["plan"]["eta"] = -1
cur_map["plan"]["transport_mode"] = 0
cur_map["plan_rank"] = 0
cur_map["price_rank"] = 0
cur_map["eta_rank"] = 0
cur_map["distance_rank"] = 0
cur_map["label"] = 0
cur_json_instance = json.dumps(cur_map)
f_train.write(cur_json_instance + '\n')
build_norm_feature()
if __name__ == "__main__":
preprocess()
{"10-01": {"max_temp": "24", "min_temp": "12", "weather": "q", "wind": "45"}, "10-02": {"max_temp": "24", "min_temp": "11", "weather": "q", "wind": "12"}, "10-03": {"max_temp": "25", "min_temp": "10", "weather": "q", "wind": "12"}, "10-04": {"max_temp": "25", "min_temp": "12", "weather": "q", "wind": "12"}, "10-05": {"max_temp": "24", "min_temp": "14", "weather": "dy", "wind": "12"}, "10-06": {"max_temp": "20", "min_temp": "8", "weather": "q", "wind": "45"}, "10-07": {"max_temp": "21", "min_temp": "7", "weather": "q", "wind": "12"}, "10-08": {"max_temp": "21", "min_temp": "8", "weather": "dy", "wind": "12"}, "10-09": {"max_temp": "15", "min_temp": "4", "weather": "dyq", "wind": "45"}, "10-10": {"max_temp": "17", "min_temp": "4", "weather": "dyq", "wind": "12"}, "10-11": {"max_temp": "18", "min_temp": "5", "weather": "qdy", "wind": "12"}, "10-12": {"max_temp": "20", "min_temp": "5", "weather": "dyq", "wind": "12"}, "10-13": {"max_temp": "20", "min_temp": "8", "weather": "dy", "wind": "12"}, "10-14": {"max_temp": "21", "min_temp": "10", "weather": "dy", "wind": "12"}, "10-15": {"max_temp": "17", "min_temp": "11", "weather": "xq", "wind": "12"}, "10-16": {"max_temp": "17", "min_temp": "7", "weather": "dyq", "wind": "12"}, "10-17": {"max_temp": "17", "min_temp": "5", "weather": "q", "wind": "12"}, "10-18": {"max_temp": "18", "min_temp": "5", "weather": "q", "wind": "12"}, "10-19": {"max_temp": "19", "min_temp": "7", "weather": "dy", "wind": "12"}, "10-20": {"max_temp": "18", "min_temp": "7", "weather": "dy", "wind": "12"}, "10-21": {"max_temp": "18", "min_temp": "7", "weather": "dy", "wind": "12"}, "10-22": {"max_temp": "19", "min_temp": "5", "weather": "dyq", "wind": "12"}, "10-23": {"max_temp": "19", "min_temp": "4", "weather": "q", "wind": "34"}, "10-24": {"max_temp": "20", "min_temp": "6", "weather": "qdy", "wind": "12"}, "10-25": {"max_temp": "15", "min_temp": "8", "weather": "dy", "wind": "12"}, "10-26": {"max_temp": "14", "min_temp": "3", "weather": "q", "wind": "45"}, "10-27": {"max_temp": "17", "min_temp": "5", "weather": "dy", "wind": "12"}, "10-28": {"max_temp": "17", "min_temp": "4", "weather": "dyq", "wind": "45"}, "10-29": {"max_temp": "15", "min_temp": "3", "weather": "q", "wind": "34"}, "10-30": {"max_temp": "16", "min_temp": "1", "weather": "q", "wind": "12"}, "10-31": {"max_temp": "17", "min_temp": "3", "weather": "q", "wind": "12"}, "11-01": {"max_temp": "17", "min_temp": "3", "weather": "q", "wind": "12"}, "11-02": {"max_temp": "18", "min_temp": "4", "weather": "q", "wind": "12"}, "11-03": {"max_temp": "16", "min_temp": "6", "weather": "dy", "wind": "12"}, "11-04": {"max_temp": "10", "min_temp": "2", "weather": "xydy", "wind": "34"}, "11-05": {"max_temp": "10", "min_temp": "2", "weather": "dy", "wind": "12"}, "11-06": {"max_temp": "12", "min_temp": "0", "weather": "dy", "wind": "12"}, "11-07": {"max_temp": "13", "min_temp": "3", "weather": "dy", "wind": "12"}, "11-08": {"max_temp": "14", "min_temp": "2", "weather": "dy", "wind": "12"}, "11-09": {"max_temp": "15", "min_temp": "1", "weather": "qdy", "wind": "34"}, "11-10": {"max_temp": "11", "min_temp": "0", "weather": "dy", "wind": "12"}, "11-11": {"max_temp": "13", "min_temp": "1", "weather": "dyq", "wind": "12"}, "11-12": {"max_temp": "14", "min_temp": "2", "weather": "q", "wind": "12"}, "11-13": {"max_temp": "13", "min_temp": "5", "weather": "dy", "wind": "12"}, "11-14": {"max_temp": "13", "min_temp": "5", "weather": "dy", "wind": "12"}, "11-15": {"max_temp": "8", "min_temp": "1", "weather": "xydy", "wind": "34"}, "11-16": {"max_temp": "8", "min_temp": "-1", "weather": "q", "wind": "12"}, "11-17": {"max_temp": "9", "min_temp": "-2", "weather": "dyq", "wind": "12"}, "11-18": {"max_temp": "11", "min_temp": "-3", "weather": "q", "wind": "34"}, "11-19": {"max_temp": "10", "min_temp": "-2", "weather": "qdy", "wind": "12"}, "11-20": {"max_temp": "9", "min_temp": "-1", "weather": "dy", "wind": "12"}, "11-21": {"max_temp": "9", "min_temp": "-3", "weather": "q", "wind": "2"}, "11-22": {"max_temp": "8", "min_temp": "-3", "weather": "qdy", "wind": "1"}, "11-23": {"max_temp": "7", "min_temp": "0", "weather": "dy", "wind": "2"}, "11-24": {"max_temp": "9", "min_temp": "-3", "weather": "qdy", "wind": "2"}, "11-25": {"max_temp": "10", "min_temp": "-3", "weather": "q", "wind": "1"}, "11-26": {"max_temp": "10", "min_temp": "0", "weather": "dy", "wind": "1"}, "11-27": {"max_temp": "9", "min_temp": "-3", "weather": "qdy", "wind": "2"}, "11-28": {"max_temp": "8", "min_temp": "-3", "weather": "q", "wind": "1"}, "11-29": {"max_temp": "7", "min_temp": "-4", "weather": "q", "wind": "1"}, "11-30": {"max_temp": "8", "min_temp": "-3", "weather": "q", "wind": "1"}, "12-01": {"max_temp": "7", "min_temp": "0", "weather": "dy", "wind": "1"}, "12-02": {"max_temp": "9", "min_temp": "2", "weather": "dy", "wind": "1"}, "12-03": {"max_temp": "8", "min_temp": "-3", "weather": "dyq", "wind": "3"}, "12-04": {"max_temp": "4", "min_temp": "-6", "weather": "qdy", "wind": "2"}, "12-05": {"max_temp": "1", "min_temp": "-4", "weather": "dy", "wind": "1"}, "12-06": {"max_temp": "-2", "min_temp": "-9", "weather": "q", "wind": "3"}, "12-07": {"max_temp": "-4", "min_temp": "-10", "weather": "q", "wind": "3"}, "12-08": {"max_temp": "-2", "min_temp": "-10", "weather": "qdy", "wind": "2"}, "12-09": {"max_temp": "-1", "min_temp": "-10", "weather": "dyq", "wind": "1"}}
\ No newline at end of file
import argparse
import os
import sys
import time
from collections import OrderedDict
import paddle.fluid as fluid
from network import DCN
import utils
def boolean_string(s):
if s.lower() not in {'false', 'true'}:
raise ValueError('Not a valid boolean string')
return s.lower() == 'true'
def parse_args():
parser = argparse.ArgumentParser("dcn cluster train.")
parser.add_argument(
'--train_data_dir',
type=str,
default='dist_data/dist_train_data',
help='The path of train data')
parser.add_argument(
'--test_valid_data_dir',
type=str,
default='dist_data/dist_test_valid_data',
help='The path of test and valid data')
parser.add_argument(
'--vocab_dir',
type=str,
default='dist_data/vocab',
help='The path of generated vocabs')
parser.add_argument(
'--cat_feat_num',
type=str,
default='dist_data/cat_feature_num.txt',
help='The path of generated cat_feature_num.txt')
parser.add_argument(
'--batch_size', type=int, default=512, help="Batch size")
parser.add_argument('--num_epoch', type=int, default=10, help="train epoch")
parser.add_argument(
'--model_output_dir',
type=str,
default='models',
help='The path for model to store')
parser.add_argument(
'--num_thread', type=int, default=1, help='The number of threads')
parser.add_argument('--test_epoch', type=str, default='1')
parser.add_argument(
'--dnn_hidden_units',
nargs='+',
type=int,
default=[1024, 1024],
help='DNN layers and hidden units')
parser.add_argument(
'--cross_num',
type=int,
default=6,
help='The number of Cross network layers')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument(
'--l2_reg_cross',
type=float,
default=1e-5,
help='Cross net l2 regularizer coefficient')
parser.add_argument(
'--use_bn',
type=boolean_string,
default=True,
help='Whether use batch norm in dnn part')
parser.add_argument(
'--is_sparse',
action='store_true',
required=False,
default=False,
help='embedding will use sparse or not, (default: False)')
parser.add_argument(
'--clip_by_norm', type=float, default=100.0, help="gradient clip norm")
parser.add_argument('--print_steps', type=int, default=5)
parser.add_argument('--use_gpu', type=int, default=1)
# dist params
parser.add_argument('--is_local', type=int, default=1, help='whether local')
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
args = parser.parse_args()
return args
def train():
""" do training """
args = parse_args()
print(args)
if args.trainer_id == 0 and not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
cat_feat_dims_dict = OrderedDict()
for line in open(args.cat_feat_num):
spls = line.strip().split()
assert len(spls) == 2
cat_feat_dims_dict[spls[0]] = int(spls[1])
dcn_model = DCN(args.cross_num, args.dnn_hidden_units, args.l2_reg_cross,
args.use_bn, args.clip_by_norm, cat_feat_dims_dict,
args.is_sparse)
dcn_model.build_network()
optimizer = fluid.optimizer.Adam(learning_rate=args.lr)
optimizer.minimize(dcn_model.loss)
def train_loop(main_program):
""" train network """
start_time = time.time()
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var(dcn_model.data_list)
pipe_command = 'python reader.py {}'.format(args.vocab_dir)
dataset.set_pipe_command(pipe_command)
dataset.set_batch_size(args.batch_size)
dataset.set_thread(args.num_thread)
train_filelist = [
os.path.join(args.train_data_dir, fname)
for fname in next(os.walk(args.train_data_dir))[2]
]
dataset.set_filelist(train_filelist)
if args.use_gpu == 1:
exe = fluid.Executor(fluid.CUDAPlace(0))
dataset.set_thread(1)
else:
exe = fluid.Executor(fluid.CPUPlace())
dataset.set_thread(args.num_thread)
exe.run(fluid.default_startup_program())
for epoch_id in range(args.num_epoch):
start = time.time()
sys.stderr.write('\nepoch%d start ...\n' % (epoch_id + 1))
exe.train_from_dataset(
program=main_program,
dataset=dataset,
fetch_list=[
dcn_model.loss, dcn_model.avg_logloss, dcn_model.auc_var
],
fetch_info=['total_loss', 'avg_logloss', 'auc'],
debug=False,
print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
if args.trainer_id == 0: # only trainer 0 save model
print("save model in {}".format(model_dir))
fluid.save(main_program, model_dir)
print("train time cost {:.4f}".format(time.time() - start_time))
print("finish training")
if args.is_local:
print("run local training")
train_loop(fluid.default_main_program())
else:
print("run distribute training")
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
print("run psever")
pserver_prog, pserver_startup = t.get_pserver_programs(
args.current_endpoint)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)
exe.run(pserver_prog)
elif args.role == "trainer":
print("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
utils.check_version()
train()
#!/bin/bash
#export GLOG_v=30
#export GLOG_logtostderr=1
# start pserver0
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--is_local 0 \
--is_sparse \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--is_local 0 \
--is_sparse \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
#CUDA_VISIBLE_DEVICES=1 python cluster_train.py \
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--use_gpu 0 \
--is_local 0 \
--is_sparse \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
#CUDA_VISIBLE_DEVICES=2 python cluster_train.py \
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--use_gpu 0 \
--is_local 0 \
--is_sparse \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
echo "2 pservers and 2 trainers started."
\ No newline at end of file
......@@ -72,8 +72,8 @@ class CriteoDataset(dg.MultiSlotDataGenerator):
yield label_feat_list
import paddle
batch_iter = paddle.batch(
paddle.reader.buffered(
batch_iter = fluid.io.batch(
fluid.io.buffered(
local_iter, size=buf_size), batch_size=batch)
return batch_iter
......
import argparse
import os
import sys
import time
from network_conf import ctr_deepfm_model
import paddle.fluid as fluid
import utils
def parse_args():
parser = argparse.ArgumentParser("deepfm cluster train.")
parser.add_argument(
'--train_data_dir',
type=str,
default='dist_data/dist_train_data',
help='The path of train data (default: data/train_data)')
parser.add_argument(
'--test_data_dir',
type=str,
default='dist_data/dist_test_data',
help='The path of test data (default: models)')
parser.add_argument(
'--feat_dict',
type=str,
default='dist_data/aid_data/feat_dict_10.pkl2',
help='The path of feat_dict')
parser.add_argument(
'--batch_size',
type=int,
default=100,
help="The size of mini-batch (default:100)")
parser.add_argument(
'--embedding_size',
type=int,
default=10,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--num_epoch',
type=int,
default=10,
help="The number of epochs to train (default: 50)")
parser.add_argument(
'--model_output_dir',
type=str,
required=True,
help='The path for model to store (default: models)')
parser.add_argument(
'--num_thread',
type=int,
default=1,
help='The number of threads (default: 1)')
parser.add_argument('--test_epoch', type=str, default='1')
parser.add_argument(
'--layer_sizes',
nargs='+',
type=int,
default=[400, 400, 400],
help='The size of each layers (default: [10, 10, 10])')
parser.add_argument(
'--act',
type=str,
default='relu',
help='The activation of each layers (default: relu)')
parser.add_argument(
'--is_sparse',
action='store_true',
required=False,
default=False,
help='embedding will use sparse or not, (default: False)')
parser.add_argument(
'--lr', type=float, default=1e-4, help='Learning rate (default: 1e-4)')
parser.add_argument(
'--reg', type=float, default=1e-4, help=' (default: 1e-4)')
parser.add_argument('--num_field', type=int, default=39)
parser.add_argument('--num_feat', type=int, default=141443)
parser.add_argument('--use_gpu', type=int, default=1)
# dist params
parser.add_argument('--is_local', type=int, default=1, help='whether local')
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
args = parser.parse_args()
return args
def train():
""" do training """
args = parse_args()
print(args)
if args.trainer_id == 0 and not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
loss, auc, data_list, auc_states = ctr_deepfm_model(
args.embedding_size, args.num_field, args.num_feat, args.layer_sizes,
args.act, args.reg, args.is_sparse)
optimizer = fluid.optimizer.SGD(
learning_rate=args.lr,
regularization=fluid.regularizer.L2DecayRegularizer(args.reg))
optimizer.minimize(loss)
def train_loop(main_program):
""" train network """
start_time = time.time()
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var(data_list)
pipe_command = 'python criteo_reader.py {}'.format(args.feat_dict)
dataset.set_pipe_command(pipe_command)
dataset.set_batch_size(args.batch_size)
dataset.set_thread(args.num_thread)
train_filelist = [
os.path.join(args.train_data_dir, x)
for x in os.listdir(args.train_data_dir)
]
if args.use_gpu == 1:
exe = fluid.Executor(fluid.CUDAPlace(0))
dataset.set_thread(1)
else:
exe = fluid.Executor(fluid.CPUPlace())
dataset.set_thread(args.num_thread)
exe.run(fluid.default_startup_program())
for epoch_id in range(args.num_epoch):
start = time.time()
sys.stderr.write('\nepoch%d start ...\n' % (epoch_id + 1))
dataset.set_filelist(train_filelist)
exe.train_from_dataset(
program=main_program,
dataset=dataset,
fetch_list=[loss, auc],
fetch_info=['epoch %d batch loss' % (epoch_id + 1), "auc"],
print_period=5,
debug=False)
model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1))
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
if args.trainer_id == 0: # only trainer 0 save model
print("save model in {}".format(model_dir))
fluid.save(main_program, model_dir)
print("train time cost {:.4f}".format(time.time() - start_time))
print("finish training")
if args.is_local:
print("run local training")
train_loop(fluid.default_main_program())
else:
print("run distribute training")
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
print("run psever")
pserver_prog, pserver_startup = t.get_pserver_programs(
args.current_endpoint)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)
exe.run(pserver_prog)
elif args.role == "trainer":
print("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
utils.check_version()
train()
#!/bin/bash
#export GLOG_v=30
#export GLOG_logtostderr=1
# start pserver0
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--is_local 0 \
--is_sparse \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--is_local 0 \
--is_sparse \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
#CUDA_VISIBLE_DEVICES=1 python cluster_train.py \
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--use_gpu 0 \
--is_local 0 \
--is_sparse \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
#CUDA_VISIBLE_DEVICES=2 python cluster_train.py \
python -u cluster_train.py \
--train_data_dir dist_data/dist_train_data \
--model_output_dir cluster_model \
--use_gpu 0 \
--is_local 0 \
--is_sparse \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
echo "2 pservers and 2 trainers started."
\ No newline at end of file
......@@ -30,7 +30,7 @@ def infer():
]
criteo_dataset = CriteoDataset()
criteo_dataset.setup(args.feat_dict)
test_reader = paddle.batch(
test_reader = fluid.io.batch(
criteo_dataset.test(test_files), batch_size=args.batch_size)
startup_program = fluid.framework.Program()
......
......@@ -6,6 +6,7 @@ import pickle
import random
import paddle
import paddle.fluid as fluid
class DataGenerator(object):
......@@ -58,7 +59,7 @@ class DataGenerator(object):
if not cycle:
break
return paddle.batch(_reader, batch_size=batch_size)
return fluid.io.batch(_reader, batch_size=batch_size)
def data_reader(batch_size,
......
......@@ -8,8 +8,6 @@
├── train.py # 训练脚本
├── infer.py # 预测脚本
├── network.py # 网络结构
├── cluster_train.py # 多机训练
├── cluster_train.sh # 多机训练脚本
├── reader.py # 和读取数据相关的函数
├── data/
├── build_dataset.py # 文本数据转化为paddle数据
......@@ -129,12 +127,3 @@ CUDA_VISIBLE_DEVICES=3 python infer.py --model_path 'din_amazon/global_step_4000
```text
2019-02-22 11:22:58,804 - INFO - TEST --> loss: [0.47005194] auc:0.863794952818
```
## 多机训练
可参考cluster_train.py 配置多机环境
运行命令本地模拟多机场景
```
sh cluster_train.sh
```
import sys
import logging
import time
import numpy as np
import argparse
import paddle.fluid as fluid
import paddle
import time
import network
import reader
import random
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("din")
parser.add_argument(
'--config_path',
type=str,
default='data/config.txt',
help='dir of config')
parser.add_argument(
'--train_dir',
type=str,
default='data/paddle_train.txt',
help='dir of train file')
parser.add_argument(
'--model_dir',
type=str,
default='din_amazon/',
help='dir of saved model')
parser.add_argument(
'--batch_size', type=int, default=16, help='number of batch size')
parser.add_argument(
'--epoch_num', type=int, default=200, help='number of epoch')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether to use gpu')
parser.add_argument(
'--parallel',
type=int,
default=0,
help='whether to use parallel executor')
parser.add_argument(
'--base_lr', type=float, default=0.85, help='based learning rate')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
args = parser.parse_args()
return args
def train():
args = parse_args()
config_path = args.config_path
train_path = args.train_dir
epoch_num = args.epoch_num
use_cuda = True if args.use_cuda else False
use_parallel = True if args.parallel else False
logger.info("reading data begins")
user_count, item_count, cat_count = reader.config_read(config_path)
#data_reader, max_len = reader.prepare_reader(train_path, args.batch_size)
logger.info("reading data completes")
avg_cost, pred = network.network(item_count, cat_count, 433)
base_lr = args.base_lr
boundaries = [410000]
values = [base_lr, 0.2]
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.piecewise_decay(
boundaries=boundaries, values=values))
sgd_optimizer.minimize(avg_cost)
def train_loop(main_program):
data_reader, max_len = reader.prepare_reader(train_path,
args.batch_size)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feed_list = [
"hist_item_seq", "hist_cat_seq", "target_item", "target_cat",
"label", "mask", "target_item_seq", "target_cat_seq"
]
loader = fluid.io.DataLoader.from_generator(
feed_list=feed_list, capacity=10000, iterable=True)
loader.set_sample_list_generator(data_reader, places=place)
if use_parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
loss_name=avg_cost.name,
main_program=main_program)
else:
train_exe = exe
logger.info("train begins")
global_step = 0
PRINT_STEP = 1000
start_time = time.time()
loss_sum = 0.0
for id in range(epoch_num):
epoch = id + 1
for data in loader():
global_step += 1
results = train_exe.run(main_program,
feed=data,
fetch_list=[avg_cost.name, pred.name],
return_numpy=True)
loss_sum += results[0].mean()
if global_step % PRINT_STEP == 0:
logger.info(
"epoch: %d\tglobal_step: %d\ttrain_loss: %.4f\t\ttime: %.2f"
% (epoch, global_step, loss_sum / PRINT_STEP,
time.time() - start_time))
start_time = time.time()
loss_sum = 0.0
if (global_step > 400000 and
global_step % PRINT_STEP == 0) or (
global_step < 400000 and
global_step % 50000 == 0):
save_dir = args.model_dir + "/global_step_" + str(
global_step)
feed_var_name = [
"hist_item_seq", "hist_cat_seq", "target_item",
"target_cat", "label", "mask", "target_item_seq",
"target_cat_seq"
]
fetch_vars = [avg_cost, pred]
fluid.io.save_inference_model(save_dir, feed_var_name,
fetch_vars, exe)
train_exe.close()
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
logger.info("run psever")
prog, startup = t.get_pserver_programs(args.current_endpoint)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif args.role == "trainer":
logger.info("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
train()
#!/bin/bash
#export GLOG_v=30
#export GLOG_logtostderr=1
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
......@@ -458,7 +458,7 @@ elif fleet.is_worker():
# 默认使用0号节点保存模型
if params.test and fleet.is_first_worker():
model_path = (str(params.model_path) + "/"+"epoch_" + str(epoch))
fluid.io.save_persistables(executor=exe, dirname=model_path)
fleet.save_persistables(executor=exe, dirname=model_path)
# 训练结束,调用stop_worker()通知pserver
fleet.stop_worker()
......@@ -504,7 +504,7 @@ sh local_cluster.sh
便可以开启分布式模拟训练,默认启用2x2的训练模式。Trainer与Pserver的运行日志,存放于`./log/`文件夹,保存的模型位于`./models/`,使用默认配置运行后,理想输出为:
- pserver.0.log
```bash
get_pserver_program() is deprecated, call get_pserver_programs() to get pserver main and startup in a single call.
I1126 07:37:49.952580 15056 grpc_server.cc:477] Server listening on 127.0.0.1:36011 successful, selected port: 36011
```
......
......@@ -30,8 +30,7 @@ logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(
description="PaddlePaddle CTR-DNN example")
parser = argparse.ArgumentParser(description="PaddlePaddle CTR-DNN example")
# -------------Data & Model Path-------------
parser.add_argument(
'--test_files_path',
......@@ -54,8 +53,7 @@ def parse_args():
'--infer_epoch',
type=int,
default=0,
help='Specify which epoch to run infer'
)
help='Specify which epoch to run infer')
# -------------Network parameter-------------
parser.add_argument(
'--embedding_size',
......@@ -68,10 +66,7 @@ def parse_args():
default=1000001,
help='sparse feature hashing space for index processing')
parser.add_argument(
'--dense_feature_dim',
type=int,
default=13,
help='dense feature shape')
'--dense_feature_dim', type=int, default=13, help='dense feature shape')
# -------------device parameter-------------
parser.add_argument(
......@@ -102,10 +97,11 @@ def run_infer(args, model_path):
place = fluid.CPUPlace()
train_generator = generator.CriteoDataset(args.sparse_feature_dim)
file_list = [
os.path.join(args.test_files_path, x) for x in os.listdir(args.test_files_path)
os.path.join(args.test_files_path, x)
for x in os.listdir(args.test_files_path)
]
test_reader = paddle.batch(train_generator.test(file_list),
batch_size=args.batch_size)
test_reader = fluid.io.batch(
train_generator.test(file_list), batch_size=args.batch_size)
startup_program = fluid.framework.Program()
test_program = fluid.framework.Program()
ctr_model = CTR()
......@@ -171,13 +167,15 @@ if __name__ == "__main__":
model_list = []
for _, dir, _ in os.walk(args.model_path):
for model in dir:
if "epoch" in model and args.infer_epoch == int(model.split('_')[-1]):
if "epoch" in model and args.infer_epoch == int(
model.split('_')[-1]):
path = os.path.join(args.model_path, model)
model_list.append(path)
if len(model_list) == 0:
logger.info("There is no satisfactory model {} at path {}, please check your start command & env. ".format(
str("epoch_")+str(args.infer_epoch), args.model_path))
logger.info(
"There is no satisfactory model {} at path {}, please check your start command & env. ".
format(str("epoch_") + str(args.infer_epoch), args.model_path))
for model in model_list:
logger.info("Test model {}".format(model))
......
import argparse
import os
import sys
import time
import network_conf
import paddle.fluid as fluid
import utils
def parse_args():
parser = argparse.ArgumentParser("xdeepfm cluster train.")
parser.add_argument(
'--train_data_dir',
type=str,
default='data/train_data',
help='The path of train data (default: data/train_data)')
parser.add_argument(
'--test_data_dir',
type=str,
default='data/test_data',
help='The path of test data (default: models)')
parser.add_argument(
'--batch_size',
type=int,
default=100,
help="The size of mini-batch (default:100)")
parser.add_argument(
'--embedding_size',
type=int,
default=10,
help="The size for embedding layer (default:10)")
parser.add_argument(
'--num_epoch',
type=int,
default=10,
help="The number of epochs to train (default: 10)")
parser.add_argument(
'--model_output_dir',
type=str,
required=True,
help='The path for model to store (default: models)')
parser.add_argument(
'--num_thread',
type=int,
default=1,
help='The number of threads (default: 1)')
parser.add_argument('--test_epoch', type=str, default='1')
parser.add_argument(
'--layer_sizes_dnn',
nargs='+',
type=int,
default=[10, 10, 10],
help='The size of each layers')
parser.add_argument(
'--layer_sizes_cin',
nargs='+',
type=int,
default=[10, 10],
help='The size of each layers')
parser.add_argument(
'--act',
type=str,
default='relu',
help='The activation of each layers (default: relu)')
parser.add_argument(
'--lr', type=float, default=1e-1, help='Learning rate (default: 1e-4)')
parser.add_argument(
'--reg', type=float, default=1e-4, help=' (default: 1e-4)')
parser.add_argument('--num_field', type=int, default=39)
parser.add_argument('--num_feat', type=int, default=28651)
parser.add_argument(
'--model_name',
type=str,
default='ctr_xdeepfm_model',
help='The name of model (default: ctr_xdeepfm_model)')
parser.add_argument('--use_gpu', type=int, default=1)
parser.add_argument('--print_steps', type=int, default=50)
parser.add_argument('--is_local', type=int, default=1, help='whether local')
parser.add_argument(
'--is_sparse',
action='store_true',
required=False,
default=False,
help='embedding will use sparse or not, (default: False)')
# dist params
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
args = parser.parse_args()
return args
def train():
""" do training """
args = parse_args()
print(args)
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
loss, auc, data_list, auc_states = eval('network_conf.' + args.model_name)(
args.embedding_size, args.num_field, args.num_feat,
args.layer_sizes_dnn, args.act, args.reg, args.layer_sizes_cin,
args.is_sparse)
optimizer = fluid.optimizer.SGD(
learning_rate=args.lr,
regularization=fluid.regularizer.L2DecayRegularizer(args.reg))
optimizer.minimize(loss)
def train_loop(main_program):
""" train network """
start_time = time.time()
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_use_var(data_list)
dataset.set_pipe_command('python criteo_reader.py')
dataset.set_batch_size(args.batch_size)
dataset.set_filelist([
os.path.join(args.train_data_dir, x)
for x in os.listdir(args.train_data_dir)
])
if args.use_gpu == 1:
exe = fluid.Executor(fluid.CUDAPlace(0))
dataset.set_thread(1)
else:
exe = fluid.Executor(fluid.CPUPlace())
dataset.set_thread(args.num_thread)
exe.run(fluid.default_startup_program())
for epoch_id in range(args.num_epoch):
start = time.time()
sys.stderr.write('\nepoch%d start ...\n' % (epoch_id + 1))
exe.train_from_dataset(
program=main_program,
dataset=dataset,
fetch_list=[loss, auc],
fetch_info=['loss', 'auc'],
debug=False,
print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
if args.trainer_id == 0: # only trainer 0 save model
print("save model in {}".format(model_dir))
fluid.save(main_program, model_dir)
print("train time cost {:.4f}".format(time.time() - start_time))
print("finish training")
if args.is_local:
print("run local training")
train_loop(fluid.default_main_program())
else:
print("run distribute training")
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
print("run psever")
pserver_prog, pserver_startup = t.get_pserver_programs(
args.current_endpoint)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)
exe.run(pserver_prog)
elif args.role == "trainer":
print("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
utils.check_version()
train()
#!/bin/bash
#export GLOG_v=30
#export GLOG_logtostderr=1
# start pserver0
python -u cluster_train.py \
--train_data_dir data/train_data \
--model_output_dir cluster_model \
--is_local 0 \
--is_sparse \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python -u cluster_train.py \
--train_data_dir data/train_data \
--model_output_dir cluster_model \
--is_local 0 \
--is_sparse \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
#CUDA_VISIBLE_DEVICES=1 python cluster_train.py \
python -u cluster_train.py \
--train_data_dir data/train_data \
--model_output_dir cluster_model \
--use_gpu 0 \
--is_local 0 \
--is_sparse \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
#CUDA_VISIBLE_DEVICES=2 python cluster_train.py \
python -u cluster_train.py \
--train_data_dir data/train_data \
--model_output_dir cluster_model \
--use_gpu 0 \
--is_local 0 \
--is_sparse \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
echo "2 pservers and 2 trainers started."
\ No newline at end of file
......@@ -30,7 +30,7 @@ def infer():
for x in os.listdir(args.test_data_dir)
]
criteo_dataset = CriteoDataset()
test_reader = paddle.batch(
test_reader = fluid.io.batch(
criteo_dataset.test(test_files), batch_size=args.batch_size)
startup_program = fluid.framework.Program()
......
......@@ -11,8 +11,6 @@
├── infer_sample_neg.py # 预测脚本 sample负例
├── net.py # 网络结构
├── text2paddle.py # 文本数据转paddle数据
├── cluster_train.py # 多机训练
├── cluster_train.sh # 多机训练脚本
├── utils # 通用函数
├── convert_format.py # 转换数据格式
├── vocab.txt # 小样本字典
......@@ -168,7 +166,7 @@ CUDA_VISIBLE_DEVICES=0 python train_sample_neg.py --loss ce --use_cuda 1
可在[net.py](./net.py) `network` 函数中调整网络结构,当前的网络结构如下:
```python
emb = fluid.layers.embedding(
emb = fluid.embedding(
input=src,
size=[vocab_size, hid_size],
param_attr=fluid.ParamAttr(
......@@ -278,12 +276,3 @@ model:model_r@20/epoch_10 recall@20:0.681 time_cost(s):12.2
## 多机训练
厂内用户可以参考[wiki](http://wiki.baidu.com/pages/viewpage.action?pageId=628300529)利用paddlecloud 配置多机环境
可参考cluster_train.py 配置其他多机环境
运行命令本地模拟多机场景, 暂不支持windows
```
sh cluster_train.sh
```
注意本地模拟需要关闭代理
import os
import sys
import time
import six
import numpy as np
import math
import argparse
import paddle.fluid as fluid
import paddle
import time
import utils
import net
SEED = 102
def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument(
'--train_dir',
type=str,
default='train_data',
help='train file address')
parser.add_argument(
'--vocab_path',
type=str,
default='vocab.txt',
help='vocab file address')
parser.add_argument('--is_local', type=int, default=1, help='whether local')
parser.add_argument('--hid_size', type=int, default=100, help='hid size')
parser.add_argument(
'--model_dir', type=str, default='model_recall20', help='model dir')
parser.add_argument(
'--batch_size', type=int, default=5, help='num of batch size')
parser.add_argument('--pass_num', type=int, default=10, help='num of epoch')
parser.add_argument(
'--print_batch', type=int, default=10, help='num of print batch')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether use gpu')
parser.add_argument(
'--base_lr', type=float, default=0.01, help='learning rate')
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
args = parser.parse_args()
return args
def get_cards(args):
return args.num_devices
def train():
""" do training """
args = parse_args()
hid_size = args.hid_size
train_dir = args.train_dir
vocab_path = args.vocab_path
use_cuda = True if args.use_cuda else False
print("use_cuda:", use_cuda)
batch_size = args.batch_size
vocab_size, train_reader = utils.prepare_data(
train_dir, vocab_path, batch_size=batch_size * get_cards(args),\
buffer_size=1000, word_freq_threshold=0, is_train=True)
# Train program
src_wordseq, dst_wordseq, avg_cost, acc = net.all_vocab_network(
vocab_size=vocab_size, hid_size=hid_size)
# Optimization to minimize lost
sgd_optimizer = fluid.optimizer.SGD(learning_rate=args.base_lr)
sgd_optimizer.minimize(avg_cost)
def train_loop(main_program):
""" train network """
pass_num = args.pass_num
model_dir = args.model_dir
fetch_list = [avg_cost.name]
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
total_time = 0.0
for pass_idx in six.moves.xrange(pass_num):
epoch_idx = pass_idx + 1
print("epoch_%d start" % epoch_idx)
t0 = time.time()
i = 0
newest_ppl = 0
for data in train_reader():
i += 1
lod_src_wordseq = utils.to_lodtensor([dat[0] for dat in data],
place)
lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data],
place)
ret_avg_cost = exe.run(main_program,
feed={
"src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq
},
fetch_list=fetch_list)
avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl)
if i % args.print_batch == 0:
print("step:%d ppl:%.3f" % (i, newest_ppl))
t1 = time.time()
total_time += t1 - t0
print("epoch:%d num_steps:%d time_cost(s):%f" %
(epoch_idx, i, total_time / epoch_idx))
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
feed_var_names = ["src_wordseq", "dst_wordseq"]
fetch_vars = [avg_cost, acc]
if args.trainer_id == 0:
fluid.io.save_inference_model(save_dir, feed_var_names,
fetch_vars, exe)
print("model saved in %s" % save_dir)
print("finish training")
if args.is_local:
print("run local training")
train_loop(fluid.default_main_program())
else:
print("run distribute training")
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
print("run psever")
pserver_prog = t.get_pserver_program(args.current_endpoint)
pserver_startup = t.get_startup_program(args.current_endpoint,
pserver_prog)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)
exe.run(pserver_prog)
elif args.role == "trainer":
print("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
train()
#!/bin/bash
#export GLOG_v=30
#export GLOG_logtostderr=1
# start pserver0
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
#CUDA_VISIBLE_DEVICES=1 python cluster_train.py \
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--print_batch 10 \
--use_cuda 0 \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
#CUDA_VISIBLE_DEVICES=2 python cluster_train.py \
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--print_batch 10 \
--use_cuda 0 \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
......@@ -103,7 +103,7 @@ def prepare_data(file_dir,
if is_train and 'ce_mode' not in os.environ:
vocab_size = get_vocab_size(vocab_path)
reader = sort_batch(
paddle.reader.shuffle(
fluid.io.shuffle(
train(
file_dir, buffer_size, data_type=DataType.SEQ),
buf_size=buffer_size),
......
......@@ -102,8 +102,8 @@ def parse_args():
def start_infer(args, model_path):
dataset = reader.SyntheticDataset(args.sparse_feature_dim, args.query_slots,
args.title_slots)
test_reader = paddle.batch(
paddle.reader.shuffle(
test_reader = fluid.io.batch(
fluid.io.shuffle(
dataset.valid(), buf_size=args.batch_size * 100),
batch_size=args.batch_size)
place = fluid.CPUPlace()
......
......@@ -112,8 +112,8 @@ def start_train(args):
dataset = reader.SyntheticDataset(args.sparse_feature_dim, args.query_slots,
args.title_slots)
train_reader = paddle.batch(
paddle.reader.shuffle(
train_reader = fluid.io.batch(
fluid.io.shuffle(
dataset.train(), buf_size=args.batch_size * 100),
batch_size=args.batch_size)
place = fluid.CPUPlace()
......
......@@ -23,23 +23,28 @@ _K = None
_args = None
_model_path = None
def run_infer(args, model_path, test_data_path):
test_data_generator = utils.CriteoDataset()
with fluid.scope_guard(fluid.Scope()):
test_reader = paddle.batch(test_data_generator.test(test_data_path, False), batch_size=args.test_batch_size)
test_reader = fluid.io.batch(
test_data_generator.test(test_data_path, False),
batch_size=args.test_batch_size)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(model_path, exe)
infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
model_path, exe)
for data in test_reader():
user_input = np.array([dat[0] for dat in data])
item_input = np.array([dat[1] for dat in data])
pred_val = exe.run(infer_program,
pred_val = exe.run(
infer_program,
feed={"user_input": user_input,
"item_input": item_input},
fetch_list=fetch_vars,
......@@ -47,6 +52,7 @@ def run_infer(args, model_path, test_data_path):
return pred_val[0].reshape(1, -1).tolist()[0]
def evaluate_model(args, testRatings, testNegatives, K, model_path):
"""
Evaluate the performance (Hit_Ratio, NDCG) of top-K recommendation
......@@ -60,18 +66,19 @@ def evaluate_model(args, testRatings, testNegatives, K, model_path):
global _args
_args = args
_model_path= model_path
_model_path = model_path
_testRatings = testRatings
_testNegatives = testNegatives
_K = K
hits, ndcgs = [],[]
hits, ndcgs = [], []
for idx in range(len(_testRatings)):
(hr,ndcg) = eval_one_rating(idx)
(hr, ndcg) = eval_one_rating(idx)
hits.append(hr)
ndcgs.append(ndcg)
return (hits, ndcgs)
def eval_one_rating(idx):
rating = _testRatings[idx]
items = _testNegatives[idx]
......@@ -80,9 +87,9 @@ def eval_one_rating(idx):
items.append(gtItem)
# Get prediction scores
map_item_score = {}
users = np.full(len(items), u, dtype = 'int32')
users = users.reshape(-1,1)
items_array = np.array(items).reshape(-1,1)
users = np.full(len(items), u, dtype='int32')
users = users.reshape(-1, 1)
items_array = np.array(items).reshape(-1, 1)
temp = np.hstack((users, items_array))
np.savetxt("Data/test.txt", temp, fmt='%d', delimiter=',')
predictions = run_infer(_args, _model_path, _args.test_data_path)
......@@ -99,15 +106,17 @@ def eval_one_rating(idx):
return (hr, ndcg)
def getHitRatio(ranklist, gtItem):
for item in ranklist:
if item == gtItem:
return 1
return 0
def getNDCG(ranklist, gtItem):
for i in range(len(ranklist)):
item = ranklist[i]
if item == gtItem:
return math.log(2) / math.log(i+2)
return math.log(2) / math.log(i + 2)
return 0
......@@ -43,9 +43,6 @@ cpu 单机多卡训练
CPU_NUM=10 python train.py --train_dir train_data --use_cuda 0 --parallel 1 --batch_size 50 --model_dir model_output --num_devices 10
```
本地模拟多机训练, 不支持windows.
``` bash
sh cluster_train.sh
```
## Inference
......
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import argparse
import logging
import paddle.fluid as fluid
import paddle
import utils
import numpy as np
from nets import SequenceSemanticRetrieval
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("sequence semantic retrieval")
parser.add_argument(
"--train_dir", type=str, default='train_data', help="Training file")
parser.add_argument(
"--base_lr", type=float, default=0.01, help="learning rate")
parser.add_argument(
'--vocab_path', type=str, default='vocab.txt', help='vocab file')
parser.add_argument(
"--epochs", type=int, default=10, help="Number of epochs")
parser.add_argument(
'--parallel', type=int, default=0, help='whether parallel')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether use gpu')
parser.add_argument(
'--print_batch', type=int, default=10, help='num of print batch')
parser.add_argument(
'--model_dir', type=str, default='model_output', help='model dir')
parser.add_argument(
"--hidden_size", type=int, default=128, help="hidden size")
parser.add_argument(
"--batch_size", type=int, default=50, help="number of batch")
parser.add_argument(
"--embedding_dim", type=int, default=128, help="embedding dim")
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
parser.add_argument(
'--step_num', type=int, default=1000, help='Number of steps')
parser.add_argument(
'--enable_ce',
action='store_true',
help='If set, run the task with continuous evaluation logs.')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
return parser.parse_args()
def get_cards(args):
return args.num_devices
def train_loop(main_program, avg_cost, acc, train_input_data, place, args,
train_reader):
data_list = [var.name for var in train_input_data]
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
loader = fluid.io.DataLoader.from_generator(
feed_list=train_input_data, capacity=10000, iterable=True)
loader.set_sample_list_generator(train_reader, places=place)
train_exe = exe
total_time = 0.0
ce_info = []
for pass_id in range(args.epochs):
epoch_idx = pass_id + 1
print("epoch_%d start" % epoch_idx)
t0 = time.time()
i = 0
for batch_id, data in enumerate(loader()):
i += 1
loss_val, correct_val = train_exe.run(
feed=data, fetch_list=[avg_cost.name, acc.name])
ce_info.append(float(np.mean(correct_val)) / args.batch_size)
if i % args.print_batch == 0:
logger.info(
"Train --> pass: {} batch_id: {} avg_cost: {}, acc: {}".
format(pass_id, batch_id,
np.mean(loss_val),
float(np.mean(correct_val)) / args.batch_size))
if args.enable_ce and i > args.step_num:
break
t1 = time.time()
total_time += t1 - t0
print("epoch:%d num_steps:%d time_cost(s):%f" %
(epoch_idx, i, total_time / epoch_idx))
save_dir = "%s/epoch_%d" % (args.model_dir, epoch_idx)
fluid.save(fluid.default_main_program(), save_dir)
print("model saved in %s" % save_dir)
# only for ce
if args.enable_ce:
ce_acc = 0
try:
ce_acc = ce_info[-2]
except:
print("ce info error")
epoch_idx = args.epochs
device = get_device(args)
if args.use_cuda:
gpu_num = device[1]
print("kpis\teach_pass_duration_gpu%s\t%s" %
(gpu_num, total_time / epoch_idx))
print("kpis\ttrain_acc_gpu%s\t%s" % (gpu_num, ce_acc))
else:
cpu_num = device[1]
threads_num = device[2]
print("kpis\teach_pass_duration_cpu%s_thread%s\t%s" %
(cpu_num, threads_num, total_time / epoch_idx))
print("kpis\ttrain_acc_cpu%s_thread%s\t%s" %
(cpu_num, threads_num, ce_acc))
def train(args):
if args.enable_ce:
SEED = 102
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
use_cuda = True if args.use_cuda else False
parallel = True if args.parallel else False
print("use_cuda:", use_cuda, "parallel:", parallel)
train_reader, vocab_size = utils.construct_train_data(
args.train_dir, args.vocab_path, args.batch_size * get_cards(args))
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
ssr = SequenceSemanticRetrieval(vocab_size, args.embedding_dim,
args.hidden_size)
# Train program
train_input_data, cos_pos, avg_cost, acc = ssr.train()
# Optimization to minimize lost
optimizer = fluid.optimizer.Adagrad(learning_rate=args.base_lr)
optimizer.minimize(avg_cost)
print("run distribute training")
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
print("run psever")
pserver_prog = t.get_pserver_program(args.current_endpoint)
pserver_startup = t.get_startup_program(args.current_endpoint,
pserver_prog)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)
exe.run(pserver_prog)
elif args.role == "trainer":
print("run trainer")
train_loop(t.get_trainer_program(), avg_cost, acc, train_input_data,
place, args, train_reader)
def get_device(args):
if args.use_cuda:
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1)
gpu_num = len(gpus.split(','))
return "gpu", gpu_num
else:
threads_num = os.environ.get('NUM_THREADS', 1)
cpu_num = os.environ.get('CPU_NUM', 1)
return "cpu", int(cpu_num), int(threads_num)
def main():
args = parse_args()
train(args)
if __name__ == "__main__":
main()
#!/bin/bash
#export GLOG_v=30
#export GLOG_logtostderr=1
# start pserver0
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
#CUDA_VISIBLE_DEVICES=1 python cluster_train.py \
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--print_batch 10 \
--use_cuda 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
#CUDA_VISIBLE_DEVICES=2 python cluster_train.py \
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--vocab_path vocab.txt \
--batch_size 5 \
--print_batch 10 \
--use_cuda 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
......@@ -37,16 +37,16 @@ def parse_args():
def model(vocab_size, emb_size, hidden_size):
user_data = fluid.layers.data(
name="user", shape=[1], dtype="int64", lod_level=1)
all_item_data = fluid.layers.data(
name="all_item", shape=[vocab_size, 1], dtype="int64")
user_data = fluid.data(
name="user", shape=[None, 1], dtype="int64", lod_level=1)
all_item_data = fluid.data(
name="all_item", shape=[None, vocab_size, 1], dtype="int64")
user_emb = fluid.layers.embedding(
user_emb = fluid.embedding(
input=user_data, size=[vocab_size, emb_size], param_attr="emb.item")
all_item_emb = fluid.layers.embedding(
all_item_emb = fluid.embedding(
input=all_item_data, size=[vocab_size, emb_size], param_attr="emb.item")
all_item_emb_re = fluid.layers.reshape(x=all_item_emb, shape=[-1, emb_size])
all_item_emb_re = all_item_emb
user_encoder = net.GrnnEncoder(hidden_size=hidden_size)
user_enc = user_encoder.forward(user_emb)
......@@ -63,7 +63,7 @@ def model(vocab_size, emb_size, hidden_size):
bias_attr="item.b")
cos_item = fluid.layers.cos_sim(X=all_item_hid, Y=user_re)
all_pre_ = fluid.layers.reshape(x=cos_item, shape=[-1, vocab_size])
pos_label = fluid.layers.data(name="pos_label", shape=[1], dtype="int64")
pos_label = fluid.data(name="pos_label", shape=[None, 1], dtype="int64")
acc = fluid.layers.accuracy(input=all_pre_, label=pos_label, k=20)
return acc
......
......@@ -18,7 +18,7 @@ def construct_train_data(file_dir, vocab_path, batch_size):
files = [file_dir + '/' + f for f in os.listdir(file_dir)]
y_data = reader.YoochooseDataset(vocab_size)
train_reader = fluid.io.batch(
paddle.reader.shuffle(
fluid.io.shuffle(
y_data.train(files), buf_size=batch_size * 100),
batch_size=batch_size)
return train_reader, vocab_size
......
......@@ -9,8 +9,6 @@
├── infer.py # 预测脚本
├── net.py # 网络结构
├── text2paddle.py # 文本数据转paddle数据
├── cluster_train.py # 多机训练
├── cluster_train.sh # 多机训练脚本
├── utils # 通用函数
├── vocab_text.txt # 小样本文本字典
├── vocab_tag.txt # 小样本类别字典
......@@ -89,9 +87,3 @@ python infer.py
```
python infer.py --model_dir big_model --vocab_tag_path big_vocab_tag.txt --test_dir test_big_data/
```
## 本地模拟多机
运行命令
```
sh cluster_train.py
```
import os
import sys
import time
import six
import numpy as np
import math
import argparse
import paddle
import paddle.fluid as fluid
import time
import utils
import net
SEED = 102
def parse_args():
parser = argparse.ArgumentParser("TagSpace benchmark.")
parser.add_argument(
'--neg_size', type=int, default=3, help='neg/pos ratio')
parser.add_argument(
'--train_dir', type=str, default='train_data', help='train file address')
parser.add_argument(
'--vocab_text_path', type=str, default='vocab_text.txt', help='vocab_text file address')
parser.add_argument(
'--vocab_tag_path', type=str, default='vocab_tag.txt', help='vocab_text file address')
parser.add_argument(
'--is_local', type=int, default=1, help='whether local')
parser.add_argument(
'--model_dir', type=str, default='model_', help='model dir')
parser.add_argument(
'--batch_size', type=int, default=5, help='num of batch size')
parser.add_argument(
'--print_batch', type=int, default=10, help='num of print batch')
parser.add_argument(
'--pass_num', type=int, default=10, help='num of epoch')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether use gpu')
parser.add_argument(
'--base_lr', type=float, default=0.01, help='learning rate')
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints', type=str, default='127.0.0.1:6000', help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint', type=str, default='127.0.0.1:6000', help='The current_endpoint')
parser.add_argument(
'--trainer_id', type=int, default=0, help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers', type=int, default=1, help='The num of trianers, (default: 1)')
args = parser.parse_args()
return args
def get_cards(args):
return args.num_devices
def train():
""" do training """
args = parse_args()
train_dir = args.train_dir
vocab_text_path = args.vocab_text_path
vocab_tag_path = args.vocab_tag_path
use_cuda = True if args.use_cuda else False
batch_size = args.batch_size
neg_size = args.neg_size
vocab_text_size, vocab_tag_size, train_reader = utils.prepare_data(
file_dir=train_dir, vocab_text_path=vocab_text_path,
vocab_tag_path=vocab_tag_path, neg_size=neg_size,
batch_size=batch_size * get_cards(args),
buffer_size=batch_size*100, is_train=True)
""" train network """
# Train program
avg_cost, correct, cos_pos = net.network(vocab_text_size, vocab_tag_size, neg_size=neg_size)
# Optimization to minimize lost
sgd_optimizer = fluid.optimizer.SGD(learning_rate=args.base_lr)
sgd_optimizer.minimize(avg_cost)
def train_loop(main_program):
# Initialize executor
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
pass_num = args.pass_num
model_dir = args.model_dir
fetch_list = [avg_cost.name]
exe.run(fluid.default_startup_program())
total_time = 0.0
for pass_idx in range(pass_num):
epoch_idx = pass_idx + 1
print("epoch_%d start" % epoch_idx)
t0 = time.time()
for batch_id, data in enumerate(train_reader()):
lod_text_seq = utils.to_lodtensor([dat[0] for dat in data], place)
lod_pos_tag = utils.to_lodtensor([dat[1] for dat in data], place)
lod_neg_tag = utils.to_lodtensor([dat[2] for dat in data], place)
loss_val, correct_val = exe.run(
feed={
"text": lod_text_seq,
"pos_tag": lod_pos_tag,
"neg_tag": lod_neg_tag},
fetch_list=[avg_cost.name, correct.name])
if batch_id % args.print_batch == 0:
print("TRAIN --> pass: {} batch_num: {} avg_cost: {}, acc: {}"
.format(pass_idx, (batch_id+10) * batch_size, np.mean(loss_val),
float(np.sum(correct_val)) / batch_size))
t1 = time.time()
total_time += t1 - t0
print("epoch:%d num_steps:%d time_cost(s):%f" %
(epoch_idx, batch_id, total_time / epoch_idx))
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
feed_var_names = ["text", "pos_tag"]
fetch_vars = [cos_pos]
fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe)
print("finish training")
if args.is_local:
print("run local training")
train_loop(fluid.default_main_program())
else:
print("run distribute training")
t = fluid.DistributeTranspiler()
t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
print("run psever")
pserver_prog = t.get_pserver_program(args.current_endpoint)
pserver_startup = t.get_startup_program(args.current_endpoint,
pserver_prog)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)
exe.run(pserver_prog)
elif args.role == "trainer":
print("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
train()
#!/bin/bash
#export GLOG_v=30
#export GLOG_logtostderr=1
# start pserver0
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--batch_size 5 \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--batch_size 5 \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
#CUDA_VISIBLE_DEVICES=1 python cluster_train.py \
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--batch_size 5 \
--print_batch 10 \
--use_cuda 0 \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
#CUDA_VISIBLE_DEVICES=2 python cluster_train.py \
python cluster_train.py \
--train_dir train_data \
--model_dir cluster_model \
--batch_size 5 \
--print_batch 10 \
--use_cuda 0 \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
......@@ -65,7 +65,7 @@ def prepare_data(file_dir,
vocab_text_size = get_vocab_size(vocab_text_path)
vocab_tag_size = get_vocab_size(vocab_tag_path)
reader = sort_batch(
paddle.reader.shuffle(
fluid.io.shuffle(
train(
file_dir,
vocab_tag_size,
......
......@@ -403,23 +403,23 @@ acc = fluid.layers.accuracy(input=softmax_prob, label=labels_reshape)
在demo网络中,我们设置为从某一层的所有节点开始进行检索。paddle组网对输入定义的实现如下:
```python
def input_data(self):
input_emb = fluid.layers.data(
input_emb = fluid.data(
name="input_emb",
shape=[self.input_embed_size],
shape=[None, self.input_embed_size],
dtype="float32",
)
# first_layer 与 first_layer_mask 对应着infer起始的节点
first_layer = fluid.layers.data(
first_layer = fluid.data(
name="first_layer_node",
shape=[1],
shape=[None, 1],
dtype="int64",
lod_level=1, #支持变长
)
first_layer_mask = fluid.layers.data(
first_layer_mask = fluid.data(
name="first_layer_node_mask",
shape=[1],
shape=[None, 1],
dtype="int64",
lod_level=1,
)
......
......@@ -35,6 +35,7 @@ class TDMDataset(dg.MultiSlotStringDataGenerator):
"""
Read test_data line by line & yield batch
"""
def local_iter():
"""Read file line by line"""
for fname in infer_file_list:
......@@ -46,13 +47,14 @@ class TDMDataset(dg.MultiSlotStringDataGenerator):
yield [input_emb]
import paddle
batch_iter = paddle.batch(local_iter, batch)
batch_iter = fluid.io.batch(local_iter, batch)
return batch_iter
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def iterator():
"""
This function needs to be implemented by the user, based on data format
......
......@@ -41,26 +41,23 @@ class TdmInferNet(object):
self.input_trans_net = InputTransNet(args)
def input_data(self):
input_emb = fluid.layers.data(
input_emb = fluid.data(
name="input_emb",
shape=[self.input_embed_size],
dtype="float32",
)
shape=[None, self.input_embed_size],
dtype="float32", )
# first_layer 与 first_layer_mask 对应着infer起始层的节点
first_layer = fluid.layers.data(
first_layer = fluid.data(
name="first_layer_node",
shape=[1],
shape=[None, 1],
dtype="int64",
lod_level=1,
)
lod_level=1, )
first_layer_mask = fluid.layers.data(
first_layer_mask = fluid.data(
name="first_layer_node_mask",
shape=[1],
shape=[None, 1],
dtype="int64",
lod_level=1,
)
lod_level=1, )
inputs = [input_emb] + [first_layer] + [first_layer_mask]
return inputs
......@@ -125,28 +122,27 @@ class TdmInferNet(object):
size=[self.node_nums, self.node_embed_size],
param_attr=fluid.ParamAttr(name="TDM_Tree_Emb"))
input_fc_out = self.input_trans_net.layer_fc_infer(
input_trans_emb, layer_idx)
input_fc_out = self.input_trans_net.layer_fc_infer(input_trans_emb,
layer_idx)
# 过每一层的分类器
layer_classifier_res = self.layer_classifier.classifier_layer_infer(input_fc_out,
node_emb,
layer_idx)
layer_classifier_res = self.layer_classifier.classifier_layer_infer(
input_fc_out, node_emb, layer_idx)
# 过最终的判别分类器
tdm_fc = fluid.layers.fc(input=layer_classifier_res,
tdm_fc = fluid.layers.fc(
input=layer_classifier_res,
size=self.label_nums,
act=None,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name="tdm.cls_fc.weight"),
param_attr=fluid.ParamAttr(name="tdm.cls_fc.weight"),
bias_attr=fluid.ParamAttr(name="tdm.cls_fc.bias"))
prob = fluid.layers.softmax(tdm_fc)
positive_prob = fluid.layers.slice(
prob, axes=[2], starts=[1], ends=[2])
prob_re = fluid.layers.reshape(
positive_prob, [-1, current_layer_node_num])
prob_re = fluid.layers.reshape(positive_prob,
[-1, current_layer_node_num])
# 过滤掉padding产生的无效节点(node_id=0)
node_zero_mask = fluid.layers.cast(current_layer_node, 'bool')
......@@ -161,11 +157,10 @@ class TdmInferNet(object):
# index_sample op根据下标索引tensor对应位置的值
# 若paddle版本>2.0,调用方式为paddle.index_sample
top_node = fluid.contrib.layers.index_sample(
current_layer_node, topk_i)
top_node = fluid.contrib.layers.index_sample(current_layer_node,
topk_i)
prob_re_mask = prob_re * current_layer_child_mask # 过滤掉非叶子节点
topk_value = fluid.contrib.layers.index_sample(
prob_re_mask, topk_i)
topk_value = fluid.contrib.layers.index_sample(prob_re_mask, topk_i)
node_score.append(topk_value)
node_list.append(top_node)
......@@ -190,7 +185,8 @@ class TdmInferNet(object):
res_node = fluid.layers.reshape(res_layer_node, [-1, self.topK, 1])
# 利用Tree_info信息,将node_id转换为item_id
tree_info = fluid.default_main_program().global_block().var("TDM_Tree_Info")
tree_info = fluid.default_main_program().global_block().var(
"TDM_Tree_Info")
res_node_emb = fluid.layers.gather_nd(tree_info, res_node)
res_item = fluid.layers.slice(
......
#!/bin/bash
export MKL_NUM_THREADS=1
export OMP_NUM_THREADS=1
cudaid=${text_matching_on_quora:=0} # use 0-th card as default
export CUDA_VISIBLE_DEVICES=$cudaid
FLAGS_benchmark=true python train_and_evaluate.py --model_name=cdssmNet --config=cdssm_base --enable_ce --epoch_num=5 | python _ce.py
cudaid=${text_matching_on_quora_m:=0,1,2,3} # use 0,1,2,3 card as default
export CUDA_VISIBLE_DEVICES=$cudaid
FLAGS_benchmark=true python train_and_evaluate.py --model_name=cdssmNet --config=cdssm_base --enable_ce --epoch_num=5 | python _ce.py
# Text matching on Quora qestion-answer pair dataset
## contents
* [Introduction](#introduction)
* [a brief review of the Quora Question Pair (QQP) Task](#a-brief-review-of-the-quora-question-pair-qqp-task)
* [Our Work](#our-work)
* [Environment Preparation](#environment-preparation)
* [Install Fluid release 1.0](#install-fluid-release-10)
* [cpu version](#cpu-version)
* [gpu version](#gpu-version)
* [Have I installed Fluid successfully?](#have-i-installed-fluid-successfully)
* [Prepare Data](#prepare-data)
* [Train and evaluate](#train-and-evaluate)
* [Models](#models)
* [Results](#results)
## Introduction
### a brief review of the Quora Question Pair (QQP) Task
The [Quora Question Pair](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs) dataset contains 400,000 question pairs from [Quora](https://www.quora.com/), where people ask and answer questions related to specific areas. Each sample in the dataset consists of two questions (both English) and a label that represents whether the questions are duplicate. The dataset is well annotated by human.
Below are two samples from the dataset. The last column indicates whether the two questions are duplicate (1) or not (0).
|id | qid1 | qid2| question1| question2| is_duplicate
|:---:|:---:|:---:|:---:|:---:|:---:|
|0 |1 |2 |What is the step by step guide to invest in share market in india? |What is the step by step guide to invest in share market? |0|
|1 |3 |4 |What is the story of Kohinoor (Koh-i-Noor) Diamond? | What would happen if the Indian government stole the Kohinoor (Koh-i-Noor) diamond back? |0|
A [kaggle competition](https://www.kaggle.com/c/quora-question-pairs#description) was held based on this dataset in 2017. The kagglers were given a training dataset (with labels), and requested to make predictions on a test dataset (without labels). The predictions were evaluated by the log-likelihood loss on the test data.
The kaggle competition has inspired much effective work. However, most of these models are rule-based and difficult to be transferred to new tasks. Researchers are seeking for more general models that work well on this task and other natual language processing (NLP) tasks.
[Wang _et al._](https://arxiv.org/abs/1702.03814) proposed a bilateral multi-perspective matching (BIMPM) model based on the Quora Question Pair dataset. They splitted the original dataset to [3 parts](https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view?usp=sharing): _train.tsv_ (384,348 samples), _dev.tsv_ (10,000 samples) and _test.tsv_ (10,000 samples). The class distribution of _train.tsv_ is unbalanced (37% positive and 63% negative), while those of _dev.tsv_ and _test.tsv_ are balanced(50% positive and 50% negetive). We used the same splitting method in our experiments.
### Our Work
Based on the Quora Question Pair Dataset, we implemented some classic models in the area of neural language understanding (NLU). The accuracy of prediction results are evaluated on the _test.tsv_ from [Wang _et al._](https://arxiv.org/abs/1702.03814).
## Environment Preparation
### Install Fluid release 1.0
Please follow the [official document in English](http://www.paddlepaddle.org/documentation/docs/en/1.0/build_and_install/pip_install_en.html) or [official document in Chinese](http://www.paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/install/Start.html) to install the Fluid deep learning framework.
#### Have I installed Fluid successfully?
Run the following script from your command line:
```shell
python -c "import paddle"
```
If Fluid is installed successfully you should see no error message. Feel free to open issues under the [PaddlePaddle repository](https://github.com/PaddlePaddle/Paddle/issues) for support.
## Prepare Data
Please download the Quora dataset from [Google drive](https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view?usp=sharing) and unzip to $HOME/.cache/paddle/dataset.
Then run _data/prepare_quora_data.sh_ to download the pre-trained _word2vec_ embedding file -- _glove.840B.300d.zip_:
```shell
sh data/prepare_quora_data.sh
```
At this point the dataset directory ($HOME/.cache/paddle/dataset) structure should be:
```shell
$HOME/.cache/paddle/dataset
|- Quora_question_pair_partition
|- train.tsv
|- test.tsv
|- dev.tsv
|- readme.txt
|- wordvec.txt
|- glove.840B.300d.txt
```
## Train and evaluate
We provide multiple models and configurations. Details are shown in `models` and `configs` directories. For a quick start, please run the _cdssmNet_ model with the corresponding configuration:
```shell
python train_and_evaluate.py \
--model_name=cdssmNet \
--config=cdssm_base
```
Logs will be output to the console. If everything works well, the logging information will have the same formats as the content in _cdssm_base.log_.
All configurations used in our experiments are as follows:
|Model|Config|command
|:----:|:----:|:----:|
|cdssmNet|cdssm_base|python train_and_evaluate.py --model_name=cdssmNet --config=cdssm_base
|DecAttNet|decatt_glove|python train_and_evaluate.py --model_name=DecAttNet --config=decatt_glove
|InferSentNet|infer_sent_v1|python train_and_evaluate.py --model_name=InferSentNet --config=infer_sent_v1
|InferSentNet|infer_sent_v2|python train_and_evaluate.py --model_name=InferSentNet --config=infer_sent_v2
|SSENet|sse_base|python train_and_evaluate.py --model_name=SSENet --config=sse_base
## Models
We implemeted 4 models for now: the convolutional deep-structured semantic model (CDSSM, CNN-based), the InferSent model (RNN-based), the shortcut-stacked encoder (SSE, RNN-based), and the decomposed attention model (DecAtt, attention-based).
|Model|features|Context Encoder|Match Layer|Classification Layer
|:----:|:----:|:----:|:----:|:----:|
|CDSSM|word|1 layer conv1d|concatenation|MLP
|DecAtt|word|Attention|concatenation|MLP
|InferSent|word|1 layer Bi-LSTM|concatenation/element-wise product/<br>absolute element-wise difference|MLP
|SSE|word|3 layer Bi-LSTM|concatenation/element-wise product/<br>absolute element-wise difference|MLP
### CDSSM
```
@inproceedings{shen2014learning,
title={Learning semantic representations using convolutional neural networks for web search},
author={Shen, Yelong and He, Xiaodong and Gao, Jianfeng and Deng, Li and Mesnil, Gr{\'e}goire},
booktitle={Proceedings of the 23rd International Conference on World Wide Web},
pages={373--374},
year={2014},
organization={ACM}
}
```
### InferSent
```
@article{conneau2017supervised,
title={Supervised learning of universal sentence representations from natural language inference data},
author={Conneau, Alexis and Kiela, Douwe and Schwenk, Holger and Barrault, Loic and Bordes, Antoine},
journal={arXiv preprint arXiv:1705.02364},
year={2017}
}
```
### SSE
```
@article{nie2017shortcut,
title={Shortcut-stacked sentence encoders for multi-domain inference},
author={Nie, Yixin and Bansal, Mohit},
journal={arXiv preprint arXiv:1708.02312},
year={2017}
}
```
### DecAtt
```
@article{tomar2017neural,
title={Neural paraphrase identification of questions with noisy pretraining},
author={Tomar, Gaurav Singh and Duque, Thyago and T{\"a}ckstr{\"o}m, Oscar and Uszkoreit, Jakob and Das, Dipanjan},
journal={arXiv preprint arXiv:1704.04565},
year={2017}
}
```
## Results
|Model|Config|dev accuracy| test accuracy
|:----:|:----:|:----:|:----:|
|cdssmNet|cdssm_base|83.56%|82.83%|
|DecAttNet|decatt_glove|86.31%|86.22%|
|InferSentNet|infer_sent_v1|87.15%|86.62%|
|InferSentNet|infer_sent_v2|88.55%|88.43%|
|SSENet|sse_base|88.35%|88.25%|
In our experiment, we found that LSTM-based models outperformed convolution-based models. The DecAtt model has fewer parameters than LSTM-based models, but is sensitive to hyper-parameters.
<p align="center">
<img src="imgs/models_test_acc.png" width = "500" alt="test_acc"/>
</p>
# this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi
from kpi import DurationKpi
each_pass_duration_card1_kpi = DurationKpi(
'each_pass_duration_card1', 0.08, 0, actived=True)
train_avg_cost_card1_kpi = CostKpi('train_avg_cost_card1', 0.08, 0)
train_avg_acc_card1_kpi = CostKpi('train_avg_acc_card1', 0.02, 0)
each_pass_duration_card4_kpi = DurationKpi(
'each_pass_duration_card4', 0.08, 0, actived=True)
train_avg_cost_card4_kpi = CostKpi('train_avg_cost_card4', 0.08, 0)
train_avg_acc_card4_kpi = CostKpi('train_avg_acc_card4', 0.02, 0)
tracking_kpis = [
each_pass_duration_card1_kpi,
train_avg_cost_card1_kpi,
train_avg_acc_card1_kpi,
each_pass_duration_card4_kpi,
train_avg_cost_card4_kpi,
train_avg_acc_card4_kpi,
]
def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion:
each line in the log should be key, value, for example:
"
train_cost\t1.0
test_cost\t1.0
train_cost\t1.0
train_cost\t1.0
train_acc\t1.2
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
log_to_ce(log)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cdssm import cdssm_base
from .dec_att import decatt_glove
from .sse import sse_base
from .infer_sent import infer_sent_v1
from .infer_sent import infer_sent_v2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
class config(object):
def __init__(self):
self.batch_size = 128
self.epoch_num = 50
self.optimizer_type = 'adam' # sgd, adagrad
# pretrained word embedding
self.use_pretrained_word_embedding = True
# when employing pretrained word embedding,
# out of vocabulary words' embedding is initialized with uniform or normal numbers
self.OOV_fill = 'uniform'
self.embedding_norm = False
# or else, use padding and masks for sequence data
self.use_lod_tensor = True
# lr = lr * lr_decay after each epoch
self.lr_decay = 1
self.learning_rate = 0.001
self.save_dirname = 'model_dir'
self.train_samples_num = 384348
self.duplicate_data = False
self.metric_type = ['accuracy']
def list_config(self):
print("config", self.__dict__)
def has_member(self, var_name):
return var_name in self.__dict__
if __name__ == "__main__":
basic = config()
basic.list_config()
basic.ahh = 2
basic.list_config()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import basic_config
def cdssm_base():
"""
set configs
"""
config = basic_config.config()
config.learning_rate = 0.001
config.save_dirname = "model_dir"
config.use_pretrained_word_embedding = True
config.dict_dim = 40000 # approx_vocab_size
# net config
config.emb_dim = 300
config.kernel_size = 5
config.kernel_count = 300
config.fc_dim = 128
config.mlp_hid_dim = [128, 128]
config.droprate_conv = 0.1
config.droprate_fc = 0.1
config.class_dim = 2
return config
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import basic_config
def decatt_glove():
"""
use config 'decAtt_glove' in the paper 'Neural Paraphrase Identification of Questions with Noisy Pretraining'
"""
config = basic_config.config()
config.learning_rate = 0.05
config.save_dirname = "model_dir"
config.use_pretrained_word_embedding = True
config.dict_dim = 40000 # approx_vocab_size
config.metric_type = ['accuracy', 'accuracy_with_threshold']
config.optimizer_type = 'sgd'
config.lr_decay = 1
config.use_lod_tensor = False
config.embedding_norm = False
config.OOV_fill = 'uniform'
config.duplicate_data = False
# net config
config.emb_dim = 300
config.proj_emb_dim = 200 #TODO: has project?
config.num_units = [400, 200]
config.word_embedding_trainable = True
config.droprate = 0.1
config.share_wight_btw_seq = True
config.class_dim = 2
return config
def decatt_word():
"""
use config 'decAtt_glove' in the paper 'Neural Paraphrase Identification of Questions with Noisy Pretraining'
"""
config = basic_config.config()
config.learning_rate = 0.05
config.save_dirname = "model_dir"
config.use_pretrained_word_embedding = False
config.dict_dim = 40000 # approx_vocab_size
config.metric_type = ['accuracy', 'accuracy_with_threshold']
config.optimizer_type = 'sgd'
config.lr_decay = 1
config.use_lod_tensor = False
config.embedding_norm = False
config.OOV_fill = 'uniform'
config.duplicate_data = False
# net config
config.emb_dim = 300
config.proj_emb_dim = 200 #TODO: has project?
config.num_units = [400, 200]
config.word_embedding_trainable = True
config.droprate = 0.1
config.share_wight_btw_seq = True
config.class_dim = 2
return config
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Please download the Quora dataset firstly from https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view?usp=sharing
# to the ROOT_DIR: $HOME/.cache/paddle/dataset
DATA_DIR=$HOME/.cache/paddle/dataset
wget --directory-prefix=$DATA_DIR http://nlp.stanford.edu/data/glove.840B.300d.zip
unzip $DATA_DIR/glove.840B.300d.zip
# The finally dataset dir should be like
# $HOME/.cache/paddle/dataset
# |- Quora_question_pair_partition
# |- train.tsv
# |- test.tsv
# |- dev.tsv
# |- readme.txt
# |- wordvec.txt
# |- glove.840B.300d.txt
Image files for this model: text_matching_on_quora
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册