diff --git a/fleet_rec/core/model.py b/fleet_rec/core/model.py index 528be0bf66312eab9e0f9b0d96d43ae4a75ac672..e7afbad57a53bea5898cadcf10d0d77e9dce852a 100644 --- a/fleet_rec/core/model.py +++ b/fleet_rec/core/model.py @@ -44,11 +44,12 @@ class Model(object): raise ValueError("configured optimizer can only supported SGD/Adam/Adagrad") if name == "SGD": - optimizer_i = fluid.optimizer.Adam(lr, lazy_mode=True) + reg = envs.get_global_env("hyper_parameters.reg", 0.0001, self._namespace) + optimizer_i = fluid.optimizer.SGD(lr, regularization=fluid.regularizer.L2DecayRegularizer(reg)) elif name == "ADAM": optimizer_i = fluid.optimizer.Adam(lr, lazy_mode=True) elif name == "ADAGRAD": - optimizer_i = fluid.optimizer.Adam(lr, lazy_mode=True) + optimizer_i = fluid.optimizer.Adagrad(lr) else: raise ValueError("configured optimizer can only supported SGD/Adam/Adagrad") @@ -57,7 +58,7 @@ class Model(object): def optimizer(self): learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace) optimizer = envs.get_global_env("hyper_parameters.optimizer", None, self._namespace) - + print(">>>>>>>>>>>.learnig rate: %s" %learning_rate) return self._build_optimizer(optimizer, learning_rate) @abc.abstractmethod diff --git a/models/rank/dcn/__init__.py b/models/rank/dcn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54 --- /dev/null +++ b/models/rank/dcn/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/models/rank/dcn/config.yaml b/models/rank/dcn/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d84adb748ae460512a5c31fa898adfda1a88da56 --- /dev/null +++ b/models/rank/dcn/config.yaml @@ -0,0 +1,53 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +train: + trainer: + # for cluster training + strategy: "async" + + epochs: 10 + workspace: "fleetrec.models.rank.dcn" + + reader: + batch_size: 2 + class: "{workspace}/criteo_reader.py" + train_data_path: "{workspace}/data/train" + feat_dict_name: "{workspace}/data/vocab" + + model: + models: "{workspace}/model.py" + hyper_parameters: + cross_num: 2 + dnn_hidden_units: [128, 128] + l2_reg_cross: 0.00005 + dnn_use_bn: False + clip_by_norm: 100.0 + cat_feat_num: "{workspace}/data/cat_feature_num.txt" + is_sparse: False + is_test: False + num_field: 39 + learning_rate: 0.0001 + act: "relu" + optimizer: adam + + save: + increment: + dirname: "increment" + epoch_interval: 2 + save_last: True + inference: + dirname: "inference" + epoch_interval: 4 + save_last: True diff --git a/models/rank/dcn/criteo_reader.py b/models/rank/dcn/criteo_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..4b81e1fb66b6fac2ead7583ad93451e4822077d2 --- /dev/null +++ b/models/rank/dcn/criteo_reader.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function +import math +import sys + +from fleetrec.core.reader import Reader +from fleetrec.core.utils import envs +try: + import cPickle as pickle +except ImportError: + import pickle +from collections import Counter +import os + +class TrainReader(Reader): + def init(self): + self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + self.cont_max_ = [ + 5775, 257675, 65535, 969, 23159456, 431037, 56311, 6047, 29019, 11, + 231, 4008, 7393 + ] + self.cont_diff_ = [ + self.cont_max_[i] - self.cont_min_[i] + for i in range(len(self.cont_min_)) + ] + self.cont_idx_ = list(range(1, 14)) + self.cat_idx_ = list(range(14, 40)) + + dense_feat_names = ['I' + str(i) for i in range(1, 14)] + sparse_feat_names = ['C' + str(i) for i in range(1, 27)] + target = ['label'] + + self.label_feat_names = target + dense_feat_names + sparse_feat_names + + self.cat_feat_idx_dict_list = [{} for _ in range(26)] + + # TODO: set vocabulary dictionary + vocab_dir = envs.get_global_env("feat_dict_name", None, "train.reader") + for i in range(26): + lookup_idx = 1 # remain 0 for default value + for line in open( + os.path.join(vocab_dir, 'C' + str(i + 1) + '.txt')): + self.cat_feat_idx_dict_list[i][line.strip()] = lookup_idx + lookup_idx += 1 + + def _process_line(self, line): + features = line.rstrip('\n').split('\t') + label_feat_list = [[] for _ in range(40)] + for idx in self.cont_idx_: + if features[idx] == '': + label_feat_list[idx].append(0) + else: + # 0-1 minmax norm + # label_feat_list[idx].append((float(features[idx]) - self.cont_min_[idx - 1]) / + # self.cont_diff_[idx - 1]) + # log transform + label_feat_list[idx].append( + math.log(4 + float(features[idx])) + if idx == 2 else math.log(1 + float(features[idx]))) + for idx in self.cat_idx_: + if features[idx] == '' or features[ + idx] not in self.cat_feat_idx_dict_list[idx - 14]: + label_feat_list[idx].append(0) + else: + label_feat_list[idx].append(self.cat_feat_idx_dict_list[ + idx - 14][features[idx]]) + label_feat_list[0].append(int(features[0])) + return label_feat_list + + def generate_sample(self, line): + """ + Read the data line by line and process it as a dictionary + """ + def data_iter(): + label_feat_list = self._process_line(line) + yield list(zip(self.label_feat_names, label_feat_list)) + + return data_iter \ No newline at end of file diff --git a/models/rank/dcn/data/download.py b/models/rank/dcn/data/download.py new file mode 100644 index 0000000000000000000000000000000000000000..b862988aef1802d8bfeb2554b7dce812e941f9e9 --- /dev/null +++ b/models/rank/dcn/data/download.py @@ -0,0 +1,24 @@ +import os +import sys +import io + +LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) +TOOLS_PATH = os.path.join(LOCAL_PATH, "..", "..", "tools") +sys.path.append(TOOLS_PATH) + +from fleetrec.tools.tools import download_file_and_uncompress + +if __name__ == '__main__': + trainfile = 'train.txt' + url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz" + + print("download and extract starting...") + download_file_and_uncompress(url) + print("download and extract finished") + + count = 0 + for _ in io.open(trainfile, 'r', encoding='utf-8'): + count += 1 + + print("total records: %d" % count) + print("done") diff --git a/models/rank/dcn/data/preprocess.py b/models/rank/dcn/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..b356607729eedd73854a77449ffda3cc3bb8050f --- /dev/null +++ b/models/rank/dcn/data/preprocess.py @@ -0,0 +1,131 @@ +from __future__ import print_function, absolute_import, division + +import os +import sys +from collections import Counter +import numpy as np +""" +preprocess Criteo train data, generate extra statistic files for model input. +""" +# input filename +FILENAME = 'train.200000.txt' + +# global vars +CAT_FEATURE_NUM = 'cat_feature_num.txt' +INT_FEATURE_MINMAX = 'int_feature_minmax.txt' +VOCAB_DIR = 'vocab' +TRAIN_DIR = 'train' +TEST_VALID_DIR = 'test_valid' +SPLIT_RATIO = 0.9 +FREQ_THR = 10 + +INT_COLUMN_NAMES = ['I' + str(i) for i in range(1, 14)] +CAT_COLUMN_NAMES = ['C' + str(i) for i in range(1, 27)] + + +def check_statfiles(): + """ + check if statistic files of Criteo exists + :return: + """ + statsfiles = [CAT_FEATURE_NUM, INT_FEATURE_MINMAX] + [ + os.path.join(VOCAB_DIR, cat_fn + '.txt') for cat_fn in CAT_COLUMN_NAMES + ] + if all([os.path.exists(fn) for fn in statsfiles]): + return True + return False + + +def create_statfiles(): + """ + create statistic files of Criteo, including: + min/max of interger features + counts of categorical features + vocabs of each categorical features + :return: + """ + int_minmax_list = [[sys.maxsize, -sys.maxsize] + for _ in range(13)] # count integer feature min max + cat_ct_list = [Counter() for _ in range(26)] # count categorical features + for idx, line in enumerate(open(FILENAME)): + spls = line.rstrip('\n').split('\t') + assert len(spls) == 40 + + for i in range(13): + if not spls[1 + i]: continue + int_val = int(spls[1 + i]) + int_minmax_list[i][0] = min(int_minmax_list[i][0], int_val) + int_minmax_list[i][1] = max(int_minmax_list[i][1], int_val) + + for i in range(26): + cat_ct_list[i].update([spls[14 + i]]) + + # save min max of integer features + with open(INT_FEATURE_MINMAX, 'w') as f: + for name, minmax in zip(INT_COLUMN_NAMES, int_minmax_list): + print("{} {} {}".format(name, minmax[0], minmax[1]), file=f) + + # remove '' from all cat_set[i] and filter low freq categorical value + cat_set_list = [set() for i in range(len(cat_ct_list))] + for i, ct in enumerate(cat_ct_list): + if '' in ct: del ct[''] + for key in list(ct.keys()): + if ct[key] >= FREQ_THR: + cat_set_list[i].add(key) + + del cat_ct_list + + # create vocab dir + if not os.path.exists(VOCAB_DIR): + os.makedirs(VOCAB_DIR) + + # write vocab file of categorical features + with open(CAT_FEATURE_NUM, 'w') as cat_feat_count_file: + for name, s in zip(CAT_COLUMN_NAMES, cat_set_list): + print('{} {}'.format(name, len(s)), file=cat_feat_count_file) + + vocabfile = os.path.join(VOCAB_DIR, name + '.txt') + + with open(vocabfile, 'w') as f: + for vocab_val in s: + print(vocab_val, file=f) + + +def split_data(): + """ + split train.txt into train and test_valid files. + :return: + """ + if not os.path.exists(TRAIN_DIR): + os.makedirs(TRAIN_DIR) + if not os.path.exists(TEST_VALID_DIR): + os.makedirs(TEST_VALID_DIR) + + fin = open('train.200000.txt', 'r') + data_dir = TRAIN_DIR + fout = open(os.path.join(data_dir, 'part-0'), 'w') + split_idx = int(45840617 * SPLIT_RATIO) + for line_idx, line in enumerate(fin): + if line_idx == split_idx: + fout.close() + data_dir = TEST_VALID_DIR + cur_part_idx = int(line_idx / 200000) + fout = open( + os.path.join(data_dir, 'part-' + str(cur_part_idx)), 'w') + if line_idx % 200000 == 0 and line_idx != 0: + fout.close() + cur_part_idx = int(line_idx / 200000) + fout = open( + os.path.join(data_dir, 'part-' + str(cur_part_idx)), 'w') + fout.write(line) + fout.close() + fin.close() + + +if __name__ == '__main__': + if not check_statfiles(): + print('create statstic files of Criteo...') + create_statfiles() + print('split train.200000.txt...') + split_data() + print('done') diff --git a/models/rank/dcn/model.py b/models/rank/dcn/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c0395c27318cda335dd8271e28df1f8b0a9193b9 --- /dev/null +++ b/models/rank/dcn/model.py @@ -0,0 +1,146 @@ +import paddle.fluid as fluid +import math + +from fleetrec.core.utils import envs +from fleetrec.core.model import Model as ModelBase +from collections import OrderedDict + +class Model(ModelBase): + def __init__(self, config): + ModelBase.__init__(self, config) + + def init_network(self): + self.cross_num = envs.get_global_env("hyper_parameters.cross_num", None, self._namespace) + self.dnn_hidden_units = envs.get_global_env("hyper_parameters.dnn_hidden_units", None, self._namespace) + self.l2_reg_cross = envs.get_global_env("hyper_parameters.l2_reg_cross", None, self._namespace) + self.dnn_use_bn = envs.get_global_env("hyper_parameters.dnn_use_bn", None, self._namespace) + self.clip_by_norm = envs.get_global_env("hyper_parameters.clip_by_norm", None, self._namespace) + cat_feat_num = envs.get_global_env("hyper_parameters.cat_feat_num", None, self._namespace) + cat_feat_dims_dict = OrderedDict() + for line in open(cat_feat_num): + spls = line.strip().split() + assert len(spls) == 2 + cat_feat_dims_dict[spls[0]] = int(spls[1]) + self.cat_feat_dims_dict = cat_feat_dims_dict if cat_feat_dims_dict else OrderedDict( + ) + self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse", None, self._namespace) + + self.dense_feat_names = ['I' + str(i) for i in range(1, 14)] + self.sparse_feat_names = ['C' + str(i) for i in range(1, 27)] + + # {feat_name: dims} + self.feat_dims_dict = OrderedDict( + [(feat_name, 1) for feat_name in self.dense_feat_names]) + self.feat_dims_dict.update(self.cat_feat_dims_dict) + + self.net_input = None + self.loss = None + + def _create_embedding_input(self, data_dict): + # sparse embedding + sparse_emb_dict = OrderedDict((name, fluid.embedding( + input=fluid.layers.cast( + data_dict[name], dtype='int64'), + size=[ + self.feat_dims_dict[name] + 1, + 6 * int(pow(self.feat_dims_dict[name], 0.25)) + ], + is_sparse=self.is_sparse)) for name in self.sparse_feat_names) + + # combine dense and sparse_emb + dense_input_list = [ + data_dict[name] for name in data_dict if name.startswith('I') + ] + sparse_emb_list = list(sparse_emb_dict.values()) + + sparse_input = fluid.layers.concat(sparse_emb_list, axis=-1) + sparse_input = fluid.layers.flatten(sparse_input) + + dense_input = fluid.layers.concat(dense_input_list, axis=-1) + dense_input = fluid.layers.flatten(dense_input) + dense_input = fluid.layers.cast(dense_input, 'float32') + + net_input = fluid.layers.concat([dense_input, sparse_input], axis=-1) + + return net_input + + def _deep_net(self, input, hidden_units, use_bn=False, is_test=False): + for units in hidden_units: + input = fluid.layers.fc(input=input, size=units) + if use_bn: + input = fluid.layers.batch_norm(input, is_test=is_test) + input = fluid.layers.relu(input) + return input + + def _cross_layer(self, x0, x, prefix): + input_dim = x0.shape[-1] + w = fluid.layers.create_parameter( + [input_dim], dtype='float32', name=prefix + "_w") + b = fluid.layers.create_parameter( + [input_dim], dtype='float32', name=prefix + "_b") + xw = fluid.layers.reduce_sum(x * w, dim=1, keep_dim=True) # (N, 1) + return x0 * xw + b + x, w + + def _cross_net(self, input, num_corss_layers): + x = x0 = input + l2_reg_cross_list = [] + for i in range(num_corss_layers): + x, w = self._cross_layer(x0, x, "cross_layer_{}".format(i)) + l2_reg_cross_list.append(self._l2_loss(w)) + l2_reg_cross_loss = fluid.layers.reduce_sum( + fluid.layers.concat( + l2_reg_cross_list, axis=-1)) + return x, l2_reg_cross_loss + + def _l2_loss(self, w): + return fluid.layers.reduce_sum(fluid.layers.square(w)) + + def train_net(self): + self.init_network() + self.target_input = fluid.data( + name='label', shape=[None, 1], dtype='float32') + data_dict = OrderedDict() + for feat_name in self.feat_dims_dict: + data_dict[feat_name] = fluid.data( + name=feat_name, shape=[None, 1], dtype='float32') + + self.net_input = self._create_embedding_input(data_dict) + + deep_out = self._deep_net(self.net_input, self.dnn_hidden_units, self.dnn_use_bn, False) + + cross_out, l2_reg_cross_loss = self._cross_net(self.net_input, + self.cross_num) + + last_out = fluid.layers.concat([deep_out, cross_out], axis=-1) + logit = fluid.layers.fc(last_out, 1) + + self.prob = fluid.layers.sigmoid(logit) + self._data_var = [self.target_input] + [ + data_dict[dense_name] for dense_name in self.dense_feat_names + ] + [data_dict[sparse_name] for sparse_name in self.sparse_feat_names] + + # auc + prob_2d = fluid.layers.concat([1 - self.prob, self.prob], 1) + label_int = fluid.layers.cast(self.target_input, 'int64') + auc_var, batch_auc_var, self.auc_states = fluid.layers.auc( + input=prob_2d, label=label_int, slide_steps=0) + self._metrics["AUC"] = auc_var + self._metrics["BATCH_AUC"] = batch_auc_var + + + # logloss + logloss = fluid.layers.log_loss(self.prob, self.target_input) + self.avg_logloss = fluid.layers.reduce_mean(logloss) + + # reg_coeff * l2_reg_cross + l2_reg_cross_loss = self.l2_reg_cross * l2_reg_cross_loss + self.loss = self.avg_logloss + l2_reg_cross_loss + self._cost = self.loss + + def optimizer(self): + learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace) + optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) + return optimizer + + def infer_net(self, parameter_list): + self.deepfm_net() diff --git a/models/rank/deepfm/__init__.py b/models/rank/deepfm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54 --- /dev/null +++ b/models/rank/deepfm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/models/rank/deepfm/config.yaml b/models/rank/deepfm/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd06406db481fdfb8e00aefecda7ca42a3c89353 --- /dev/null +++ b/models/rank/deepfm/config.yaml @@ -0,0 +1,49 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +train: + trainer: + # for cluster training + strategy: "async" + + epochs: 10 + workspace: "fleetrec.models.rank.deepfm" + + reader: + batch_size: 2 + class: "{workspace}/criteo_reader.py" + train_data_path: "{workspace}/data/train_data" + feat_dict_name: "{workspace}/data/aid_data/feat_dict_10.pkl2" + + model: + models: "{workspace}/model.py" + hyper_parameters: + sparse_feature_number: 1086460 + sparse_feature_dim: 9 + num_field: 39 + fc_sizes: [400, 400, 400] + learning_rate: 0.0001 + reg: 0.001 + act: "relu" + optimizer: SGD + + save: + increment: + dirname: "increment" + epoch_interval: 2 + save_last: True + inference: + dirname: "inference" + epoch_interval: 4 + save_last: True diff --git a/models/rank/deepfm/criteo_reader.py b/models/rank/deepfm/criteo_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c2587a507cd1744383c725e5a237950abffd42 --- /dev/null +++ b/models/rank/deepfm/criteo_reader.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +from fleetrec.core.reader import Reader +from fleetrec.core.utils import envs +try: + import cPickle as pickle +except ImportError: + import pickle + +class TrainReader(Reader): + def init(self): + self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + self.cont_max_ = [ + 5775, 257675, 65535, 969, 23159456, 431037, 56311, 6047, 29019, 46, + 231, 4008, 7393 + ] + self.cont_diff_ = [ + self.cont_max_[i] - self.cont_min_[i] + for i in range(len(self.cont_min_)) + ] + self.continuous_range_ = range(1, 14) + self.categorical_range_ = range(14, 40) + # load preprocessed feature dict + self.feat_dict_name = envs.get_global_env("feat_dict_name", None, "train.reader") + self.feat_dict_ = pickle.load(open(self.feat_dict_name, 'rb')) + + def _process_line(self, line): + features = line.rstrip('\n').split('\t') + feat_idx = [] + feat_value = [] + for idx in self.continuous_range_: + if features[idx] == '': + feat_idx.append(0) + feat_value.append(0.0) + else: + feat_idx.append(self.feat_dict_[idx]) + feat_value.append( + (float(features[idx]) - self.cont_min_[idx - 1]) / + self.cont_diff_[idx - 1]) + for idx in self.categorical_range_: + if features[idx] == '' or features[idx] not in self.feat_dict_: + feat_idx.append(0) + feat_value.append(0.0) + else: + feat_idx.append(self.feat_dict_[features[idx]]) + feat_value.append(1.0) + label = [int(features[0])] + return feat_idx, feat_value, label + + def generate_sample(self, line): + """ + Read the data line by line and process it as a dictionary + """ + def data_iter(): + feat_idx, feat_value, label = self._process_line(line) + yield [('feat_idx', feat_idx), ('feat_value', feat_value), ('label', label)] + + return data_iter \ No newline at end of file diff --git a/models/rank/deepfm/data/download_preprocess.py b/models/rank/deepfm/data/download_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..1f114bcfb7ec7ce5c93e676bb0467026eedf6f33 --- /dev/null +++ b/models/rank/deepfm/data/download_preprocess.py @@ -0,0 +1,25 @@ +import os +import shutil +import sys + +LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) +TOOLS_PATH = os.path.join(LOCAL_PATH, "..", "..", "tools") +sys.path.append(TOOLS_PATH) + +from fleetrec.tools.tools import download_file_and_uncompress, download_file + +if __name__ == '__main__': + url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz" + url2 = "https://paddlerec.bj.bcebos.com/deepfm%2Ffeat_dict_10.pkl2" + + print("download and extract starting...") + download_file_and_uncompress(url) + download_file(url2, "./aid_data/feat_dict_10.pkl2", True) + print("download and extract finished") + + print("preprocessing...") + os.system("python preprocess.py") + print("preprocess done") + + shutil.rmtree("raw_data") + print("done") diff --git a/models/rank/deepfm/data/preprocess.py b/models/rank/deepfm/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa4a5feae17bde64463d2f05beb3d053284dcda --- /dev/null +++ b/models/rank/deepfm/data/preprocess.py @@ -0,0 +1,101 @@ +import os +import numpy +from collections import Counter +import shutil +import pickle + + +def get_raw_data(): + if not os.path.isdir('raw_data'): + os.mkdir('raw_data') + + fin = open('train.txt', 'r') + fout = open('raw_data/part-0', 'w') + for line_idx, line in enumerate(fin): + if line_idx % 200000 == 0 and line_idx != 0: + fout.close() + cur_part_idx = int(line_idx / 200000) + fout = open('raw_data/part-' + str(cur_part_idx), 'w') + fout.write(line) + fout.close() + fin.close() + + +def split_data(): + split_rate_ = 0.9 + dir_train_file_idx_ = 'aid_data/train_file_idx.txt' + filelist_ = [ + 'raw_data/part-%d' % x for x in range(len(os.listdir('raw_data'))) + ] + + if not os.path.exists(dir_train_file_idx_): + train_file_idx = list( + numpy.random.choice( + len(filelist_), int(len(filelist_) * split_rate_), False)) + with open(dir_train_file_idx_, 'w') as fout: + fout.write(str(train_file_idx)) + else: + with open(dir_train_file_idx_, 'r') as fin: + train_file_idx = eval(fin.read()) + + for idx in range(len(filelist_)): + if idx in train_file_idx: + shutil.move(filelist_[idx], 'train_data') + else: + shutil.move(filelist_[idx], 'test_data') + + +def get_feat_dict(): + freq_ = 10 + dir_feat_dict_ = 'aid_data/feat_dict_' + str(freq_) + '.pkl2' + continuous_range_ = range(1, 14) + categorical_range_ = range(14, 40) + + if not os.path.exists(dir_feat_dict_): + # print('generate a feature dict') + # Count the number of occurrences of discrete features + feat_cnt = Counter() + with open('train.txt', 'r') as fin: + for line_idx, line in enumerate(fin): + if line_idx % 100000 == 0: + print('generating feature dict', line_idx / 45000000) + features = line.rstrip('\n').split('\t') + for idx in categorical_range_: + if features[idx] == '': continue + feat_cnt.update([features[idx]]) + + # Only retain discrete features with high frequency + dis_feat_set = set() + for feat, ot in feat_cnt.items(): + if ot >= freq_: + dis_feat_set.add(feat) + + # Create a dictionary for continuous and discrete features + feat_dict = {} + tc = 1 + # Continuous features + for idx in continuous_range_: + feat_dict[idx] = tc + tc += 1 + for feat in dis_feat_set: + feat_dict[feat] = tc + tc += 1 + # Save dictionary + with open(dir_feat_dict_, 'wb') as fout: + pickle.dump(feat_dict, fout, protocol=2) + print('args.num_feat ', len(feat_dict) + 1) + + +if __name__ == '__main__': + if not os.path.isdir('train_data'): + os.mkdir('train_data') + if not os.path.isdir('test_data'): + os.mkdir('test_data') + if not os.path.isdir('aid_data'): + os.mkdir('aid_data') + + get_raw_data() + split_data() + get_feat_dict() + + print('Done!') diff --git a/models/rank/deepfm/model.py b/models/rank/deepfm/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5e80e56de81cce0c92379a05797ca81027ac23 --- /dev/null +++ b/models/rank/deepfm/model.py @@ -0,0 +1,147 @@ +import paddle.fluid as fluid +import math + +from fleetrec.core.utils import envs +from fleetrec.core.model import Model as ModelBase + + +class Model(ModelBase): + def __init__(self, config): + ModelBase.__init__(self, config) + + def deepfm_net(self): + init_value_ = 0.1 + is_distributed = True if envs.get_trainer() == "CtrTrainer" else False + sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace) + sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace) + + # ------------------------- network input -------------------------- + + num_field = envs.get_global_env("hyper_parameters.num_field", None, self._namespace) + raw_feat_idx = fluid.data(name='feat_idx', shape=[None, num_field], dtype='int64') # None * num_field(defalut:39) + raw_feat_value = fluid.data(name='feat_value', shape=[None, num_field], dtype='float32') # None * num_field + self.label = fluid.data(name='label', shape=[None, 1], dtype='float32') # None * 1 + feat_idx = fluid.layers.reshape(raw_feat_idx,[-1, 1]) # (None * num_field) * 1 + feat_value = fluid.layers.reshape(raw_feat_value, [-1, num_field, 1]) # None * num_field * 1 + + # ------------------------- set _data_var -------------------------- + + self._data_var.append(raw_feat_idx) + self._data_var.append(raw_feat_value) + self._data_var.append(self.label) + if self._platform != "LINUX": + self._data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False) + + #------------------------- first order term -------------------------- + + reg = envs.get_global_env("hyper_parameters.reg", 1e-4, self._namespace) + first_weights_re = fluid.embedding( + input=feat_idx, + is_sparse=True, + is_distributed=is_distributed, + dtype='float32', + size=[sparse_feature_number + 1, 1], + padding_idx=0, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_), + regularizer=fluid.regularizer.L1DecayRegularizer(reg))) + first_weights = fluid.layers.reshape( + first_weights_re, shape=[-1, num_field, 1]) # None * num_field * 1 + y_first_order = fluid.layers.reduce_sum((first_weights * feat_value), 1) + + #------------------------- second order term -------------------------- + + feat_embeddings_re = fluid.embedding( + input=feat_idx, + is_sparse=True, + is_distributed=is_distributed, + dtype='float32', + size=[sparse_feature_number + 1, sparse_feature_dim], + padding_idx=0, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_ / math.sqrt(float(sparse_feature_dim))))) + feat_embeddings = fluid.layers.reshape( + feat_embeddings_re, + shape=[-1, num_field, + sparse_feature_dim]) # None * num_field * embedding_size + feat_embeddings = feat_embeddings * feat_value # None * num_field * embedding_size + + # sum_square part + summed_features_emb = fluid.layers.reduce_sum(feat_embeddings, + 1) # None * embedding_size + summed_features_emb_square = fluid.layers.square( + summed_features_emb) # None * embedding_size + + # square_sum part + squared_features_emb = fluid.layers.square( + feat_embeddings) # None * num_field * embedding_size + squared_sum_features_emb = fluid.layers.reduce_sum( + squared_features_emb, 1) # None * embedding_size + + y_second_order = 0.5 * fluid.layers.reduce_sum( + summed_features_emb_square - squared_sum_features_emb, 1, + keep_dim=True) # None * 1 + + + #------------------------- DNN -------------------------- + + layer_sizes = envs.get_global_env("hyper_parameters.fc_sizes", None, self._namespace) + act = envs.get_global_env("hyper_parameters.act", None, self._namespace) + y_dnn = fluid.layers.reshape(feat_embeddings, + [-1, num_field * sparse_feature_dim]) + for s in layer_sizes: + y_dnn = fluid.layers.fc( + input=y_dnn, + size=s, + act=act, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_ / math.sqrt(float(10)))), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_))) + y_dnn = fluid.layers.fc( + input=y_dnn, + size=1, + act=None, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_))) + + #------------------------- DeepFM -------------------------- + + self.predict = fluid.layers.sigmoid(y_first_order + y_second_order + y_dnn) + + def train_net(self): + self.deepfm_net() + + #------------------------- Cost(logloss) -------------------------- + + cost = fluid.layers.log_loss(input=self.predict, label=self.label) + avg_cost = fluid.layers.reduce_sum(cost) + + self._cost = avg_cost + + #------------------------- Metric(Auc) -------------------------- + + predict_2d = fluid.layers.concat([1 - self.predict, self.predict], 1) + label_int = fluid.layers.cast(self.label, 'int64') + auc_var, batch_auc_var, _ = fluid.layers.auc(input=predict_2d, + label=label_int, + slide_steps=0) + self._metrics["AUC"] = auc_var + self._metrics["BATCH_AUC"] = batch_auc_var + + def optimizer(self): + learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace) + optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) + return optimizer + + def infer_net(self, parameter_list): + self.deepfm_net() \ No newline at end of file diff --git a/models/rank/xdeepfm/__init__.py b/models/rank/xdeepfm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54 --- /dev/null +++ b/models/rank/xdeepfm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/models/rank/xdeepfm/config.yaml b/models/rank/xdeepfm/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a93b61fbce0d835b478c86f0e572bc5c88ab6138 --- /dev/null +++ b/models/rank/xdeepfm/config.yaml @@ -0,0 +1,50 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +train: + trainer: + # for cluster training + strategy: "async" + + epochs: 10 + workspace: "fleetrec.models.rank.xdeepfm" + + reader: + batch_size: 2 + class: "{workspace}/criteo_reader.py" + train_data_path: "{workspace}/data/train_data" + + model: + models: "{workspace}/model.py" + hyper_parameters: + layer_sizes_dnn: [10, 10, 10] + layer_sizes_cin: [10, 10] + sparse_feature_number: 1086460 + sparse_feature_dim: 9 + num_field: 39 + fc_sizes: [400, 400, 400] + learning_rate: 0.0001 + reg: 0.0001 + act: "relu" + optimizer: SGD + + save: + increment: + dirname: "increment" + epoch_interval: 2 + save_last: True + inference: + dirname: "inference" + epoch_interval: 4 + save_last: True diff --git a/models/rank/xdeepfm/criteo_reader.py b/models/rank/xdeepfm/criteo_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5e4041625228f5ebaa0fefe9f2ada566a5cecb --- /dev/null +++ b/models/rank/xdeepfm/criteo_reader.py @@ -0,0 +1,43 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +from fleetrec.core.reader import Reader +from fleetrec.core.utils import envs +try: + import cPickle as pickle +except ImportError: + import pickle + +class TrainReader(Reader): + def init(self): + pass + + def _process_line(self, line): + features = line.strip('\n').split('\t') + feat_idx = [] + feat_value = [] + for idx in range(1, 40): + feat_idx.append(int(features[idx])) + feat_value.append(1.0) + label = [int(features[0])] + return feat_idx, feat_value, label + + def generate_sample(self, line): + def data_iter(): + feat_idx, feat_value, label = self._process_line(line) + yield [('feat_idx', feat_idx), ('feat_value', feat_value), ('label', + label)] + + return data_iter \ No newline at end of file diff --git a/models/rank/xdeepfm/data/download.py b/models/rank/xdeepfm/data/download.py new file mode 100644 index 0000000000000000000000000000000000000000..d0483ea3f0d5ddfeb1ad5123bd91cf2d5b6e1331 --- /dev/null +++ b/models/rank/xdeepfm/data/download.py @@ -0,0 +1,28 @@ +import os +import shutil +import sys + +LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) +TOOLS_PATH = os.path.join(LOCAL_PATH, "..", "..", "tools") +sys.path.append(TOOLS_PATH) + +from fleetrec.tools.tools import download_file_and_uncompress, download_file + +if __name__ == '__main__': + url_train = "https://paddlerec.bj.bcebos.com/xdeepfm%2Ftr" + url_test = "https://paddlerec.bj.bcebos.com/xdeepfm%2Fev" + + train_dir = "train_data" + test_dir = "test_data" + + if not os.path.exists(train_dir): + os.mkdir(train_dir) + if not os.path.exists(test_dir): + os.mkdir(test_dir) + + print("download and extract starting...") + download_file(url_train, "./train_data/tr", True) + download_file(url_test, "./test_data/ev", True) + print("download and extract finished") + + print("done") diff --git a/models/rank/xdeepfm/model.py b/models/rank/xdeepfm/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6619c78bc718674f0efd12cf841efbe85f3cd729 --- /dev/null +++ b/models/rank/xdeepfm/model.py @@ -0,0 +1,164 @@ +import paddle.fluid as fluid +import math + +from fleetrec.core.utils import envs +from fleetrec.core.model import Model as ModelBase + + +class Model(ModelBase): + def __init__(self, config): + ModelBase.__init__(self, config) + + def xdeepfm_net(self): + init_value_ = 0.1 + initer = fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=init_value_) + + is_distributed = True if envs.get_trainer() == "CtrTrainer" else False + sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace) + sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace) + + # ------------------------- network input -------------------------- + + num_field = envs.get_global_env("hyper_parameters.num_field", None, self._namespace) + raw_feat_idx = fluid.data(name='feat_idx', shape=[None, num_field], dtype='int64') + raw_feat_value = fluid.data(name='feat_value', shape=[None, num_field], dtype='float32') + self.label = fluid.data(name='label', shape=[None, 1], dtype='float32') # None * 1 + feat_idx = fluid.layers.reshape(raw_feat_idx, [-1, 1]) # (None * num_field) * 1 + feat_value = fluid.layers.reshape(raw_feat_value, [-1, num_field, 1]) # None * num_field * 1 + + feat_embeddings = fluid.embedding( + input=feat_idx, + is_sparse=True, + dtype='float32', + size=[sparse_feature_number + 1, sparse_feature_dim], + padding_idx=0, + param_attr=fluid.ParamAttr(initializer=initer)) + feat_embeddings = fluid.layers.reshape( + feat_embeddings, + [-1, num_field, sparse_feature_dim]) # None * num_field * embedding_size + feat_embeddings = feat_embeddings * feat_value # None * num_field * embedding_size + + # ------------------------- set _data_var -------------------------- + + self._data_var.append(raw_feat_idx) + self._data_var.append(raw_feat_value) + self._data_var.append(self.label) + if self._platform != "LINUX": + self._data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False) + + # -------------------- linear -------------------- + + weights_linear = fluid.embedding( + input=feat_idx, + is_sparse=True, + dtype='float32', + size=[sparse_feature_number + 1, 1], + padding_idx=0, + param_attr=fluid.ParamAttr(initializer=initer)) + weights_linear = fluid.layers.reshape( + weights_linear, [-1, num_field, 1]) # None * num_field * 1 + b_linear = fluid.layers.create_parameter( + shape=[1], + dtype='float32', + default_initializer=fluid.initializer.ConstantInitializer(value=0)) + y_linear = fluid.layers.reduce_sum( + (weights_linear * feat_value), 1) + b_linear + + # -------------------- CIN -------------------- + + layer_sizes_cin = envs.get_global_env("hyper_parameters.layer_sizes_cin", None, self._namespace) + Xs = [feat_embeddings] + last_s = num_field + for s in layer_sizes_cin: + # calculate Z^(k+1) with X^k and X^0 + X_0 = fluid.layers.reshape( + fluid.layers.transpose(Xs[0], [0, 2, 1]), + [-1, sparse_feature_dim, num_field, + 1]) # None, embedding_size, num_field, 1 + X_k = fluid.layers.reshape( + fluid.layers.transpose(Xs[-1], [0, 2, 1]), + [-1, sparse_feature_dim, 1, last_s]) # None, embedding_size, 1, last_s + Z_k_1 = fluid.layers.matmul( + X_0, X_k) # None, embedding_size, num_field, last_s + + # compresses Z^(k+1) to X^(k+1) + Z_k_1 = fluid.layers.reshape(Z_k_1, [ + -1, sparse_feature_dim, last_s * num_field + ]) # None, embedding_size, last_s*num_field + Z_k_1 = fluid.layers.transpose( + Z_k_1, [0, 2, 1]) # None, s*num_field, embedding_size + Z_k_1 = fluid.layers.reshape( + Z_k_1, [-1, last_s * num_field, 1, sparse_feature_dim] + ) # None, last_s*num_field, 1, embedding_size (None, channal_in, h, w) + X_k_1 = fluid.layers.conv2d( + Z_k_1, + num_filters=s, + filter_size=(1, 1), + act=None, + bias_attr=False, + param_attr=fluid.ParamAttr( + initializer=initer)) # None, s, 1, embedding_size + X_k_1 = fluid.layers.reshape( + X_k_1, [-1, s, sparse_feature_dim]) # None, s, embedding_size + + Xs.append(X_k_1) + last_s = s + + # sum pooling + y_cin = fluid.layers.concat(Xs[1:], + 1) # None, (num_field++), embedding_size + y_cin = fluid.layers.reduce_sum(y_cin, -1) # None, (num_field++) + y_cin = fluid.layers.fc(input=y_cin, + size=1, + act=None, + param_attr=fluid.ParamAttr(initializer=initer), + bias_attr=None) + y_cin = fluid.layers.reduce_sum(y_cin, dim=-1, keep_dim=True) + + # -------------------- DNN -------------------- + + layer_sizes_dnn = envs.get_global_env("hyper_parameters.layer_sizes_dnn", None, self._namespace) + act = envs.get_global_env("hyper_parameters.act", None, self._namespace) + y_dnn = fluid.layers.reshape(feat_embeddings, + [-1, num_field * sparse_feature_dim]) + for s in layer_sizes_dnn: + y_dnn = fluid.layers.fc(input=y_dnn, + size=s, + act=act, + param_attr=fluid.ParamAttr(initializer=initer), + bias_attr=None) + y_dnn = fluid.layers.fc(input=y_dnn, + size=1, + act=None, + param_attr=fluid.ParamAttr(initializer=initer), + bias_attr=None) + + # ------------------- xDeepFM ------------------ + + self.predict = fluid.layers.sigmoid(y_linear + y_cin + y_dnn) + + def train_net(self): + self.xdeepfm_net() + + cost = fluid.layers.log_loss(input=self.predict, label=self.label, epsilon=0.0000001) + batch_cost = fluid.layers.reduce_mean(cost) + self._cost = batch_cost + + # for auc + predict_2d = fluid.layers.concat([1 - self.predict, self.predict], 1) + label_int = fluid.layers.cast(self.label, 'int64') + auc_var, batch_auc_var, _ = fluid.layers.auc(input=predict_2d, + label=label_int, + slide_steps=0) + self._metrics["AUC"] = auc_var + self._metrics["BATCH_AUC"] = batch_auc_var + + def optimizer(self): + learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace) + optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) + return optimizer + + def infer_net(self, parameter_list): + self.xdeepfm_net() \ No newline at end of file