diff --git a/fleet_rec/core/factory.py b/fleet_rec/core/factory.py index bbcc05f8e8388e25ced775c184361504136b5aee..817ccbba657cebfdb89b2d38160950cfcbd33a60 100644 --- a/fleet_rec/core/factory.py +++ b/fleet_rec/core/factory.py @@ -19,15 +19,22 @@ import yaml from fleetrec.core.utils import envs -trainer_abs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "trainers") +trainer_abs = os.path.join(os.path.dirname( + os.path.abspath(__file__)), "trainers") trainers = {} def trainer_registry(): - trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py") - trainers["ClusterTrainer"] = os.path.join(trainer_abs, "cluster_trainer.py") - trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, "ctr_coding_trainer.py") - trainers["CtrModulTrainer"] = os.path.join(trainer_abs, "ctr_modul_trainer.py") + trainers["SingleTrainer"] = os.path.join( + trainer_abs, "single_trainer.py") + trainers["ClusterTrainer"] = os.path.join( + trainer_abs, "cluster_trainer.py") + trainers["CtrCodingTrainer"] = os.path.join( + trainer_abs, "ctr_coding_trainer.py") + trainers["CtrModulTrainer"] = os.path.join( + trainer_abs, "ctr_modul_trainer.py") + trainers["TDMSingleTrainer"] = os.path.join( + trainer_abs, "tdm_single_trainer.py") trainer_registry() @@ -46,7 +53,8 @@ class TrainerFactory(object): if trainer_abs is None: if not os.path.isfile(train_mode): - raise IOError("trainer {} can not be recognized".format(train_mode)) + raise IOError( + "trainer {} can not be recognized".format(train_mode)) trainer_abs = train_mode train_mode = "UserDefineTrainer" diff --git a/fleet_rec/core/trainers/tdm_trainer.py b/fleet_rec/core/trainers/tdm_single_trainer.py similarity index 99% rename from fleet_rec/core/trainers/tdm_trainer.py rename to fleet_rec/core/trainers/tdm_single_trainer.py index d8d2974a2045d0d382a99b47c09c6b5764f002fc..a894a02904b9d12f4aa889fe0ddb0d834d847d16 100644 --- a/fleet_rec/core/trainers/tdm_trainer.py +++ b/fleet_rec/core/trainers/tdm_single_trainer.py @@ -30,7 +30,7 @@ logger = logging.getLogger("fluid") logger.setLevel(logging.INFO) -class TdmSingleTrainer(SingleTrainer): +class TDMSingleTrainer(SingleTrainer): def processor_register(self): self.regist_context_processor('uninit', self.instance) self.regist_context_processor('init_pass', self.init) diff --git a/fleet_rec/run.py b/fleet_rec/run.py index 8000b85f3a2469899d3d360ab776846de7b78bf4..24351f2fc474ccc3c7aa497ad93114db2ca1184b 100644 --- a/fleet_rec/run.py +++ b/fleet_rec/run.py @@ -202,7 +202,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description='fleet-rec run') parser.add_argument("-m", "--model", type=str) parser.add_argument("-e", "--engine", type=str, - choices=["single", "local_cluster", "cluster"]) + choices=["single", "local_cluster", "cluster", "tdm_single"]) parser.add_argument("-d", "--device", type=str, choices=["cpu", "gpu"], default="cpu") diff --git a/models/recall/tdm/config.yaml b/models/recall/tdm/config.yaml index e679dd0f5bb4417f41af0737b3cd84342c5c36a6..954da49ab398da216ddff2a4e8c612ae266e8cff 100644 --- a/models/recall/tdm/config.yaml +++ b/models/recall/tdm/config.yaml @@ -17,20 +17,20 @@ train: # for cluster training strategy: "async" - epochs: 10 + epochs: 4 workspace: "fleetrec.models.recall.tdm" reader: batch_size: 32 class: "{workspace}/tdm_reader.py" - train_data_path: "{workspace}/data/train_data" - test_data_path: "{workspace}/data/test_data" + train_data_path: "{workspace}/data/train" + test_data_path: "{workspace}/data/test" model: models: "{workspace}/model.py" hyper_parameters: node_emb_size: 64 - input_emb_size: 64 + input_emb_size: 768 neg_sampling_list: [1, 2, 3, 4] output_positive: True topK: 1 @@ -52,10 +52,10 @@ train: persistables_model_path: "" load_tree: True - tree_layer_path: "" - tree_travel_path: "" - tree_info_path: "" - tree_emb_path: "" + tree_layer_path: "{workspace}/tree/layer_list.txt" + tree_travel_path: "{workspace}/tree/travel_list.npy" + tree_info_path: "{workspace}/tree/tree_info.npy" + tree_emb_path: "{workspace}/tree/tree_emb.npy" save_init_model: True init_model_path: "" diff --git a/models/recall/tdm/model.py b/models/recall/tdm/model.py index 6da19cb66dc812720624f9eb4599d73fb273a1a6..89eedf90b3a4ced601aa94e3c6d28a2fa0618cbf 100644 --- a/models/recall/tdm/model.py +++ b/models/recall/tdm/model.py @@ -45,7 +45,7 @@ class Model(ModelBase): self.node_emb_size = envs.get_global_env( "hyper_parameters.node_emb_size", 64, self._namespace) self.input_emb_size = envs.get_global_env( - "hyper_parameters.input_emb_size", 64, self._namespace) + "hyper_parameters.input_emb_size", 768, self._namespace) self.act = envs.get_global_env( "hyper_parameters.act", "tanh", self._namespace) self.neg_sampling_list = envs.get_global_env( @@ -61,6 +61,7 @@ class Model(ModelBase): def train_net(self): self.train_input() self.tdm_net() + self.create_info() self.avg_loss() self.metrics() @@ -174,11 +175,26 @@ class Model(ModelBase): mask_index.stop_gradient = True self.mask_cost = fluid.layers.gather_nd(cost, mask_index) + + softmax_prob = fluid.layers.unsqueeze(input=softmax_prob, axes=[1]) self.mask_prob = fluid.layers.gather_nd(softmax_prob, mask_index) self.mask_label = fluid.layers.gather_nd(labels_reshape, mask_index) self._predict = self.mask_prob + def create_info(self): + fluid.default_startup_program().global_block().create_var( + name="TDM_Tree_Info", + dtype=fluid.core.VarDesc.VarType.INT32, + shape=[self.node_nums, 3 + self.child_nums], + persistable=True, + initializer=fluid.initializer.ConstantInitializer(0)) + fluid.default_main_program().global_block().create_var( + name="TDM_Tree_Info", + dtype=fluid.core.VarDesc.VarType.INT32, + shape=[self.node_nums, 3 + self.child_nums], + persistable=True) + def avg_loss(self): avg_cost = fluid.layers.reduce_mean(self.mask_cost) self._cost = avg_cost diff --git a/models/recall/tdm/tdm_reader.py b/models/recall/tdm/tdm_reader.py index a2a853f48453a509629ca9c24f1d6551f32c74e1..32d33aeb40cd94138395acc03556e93f634a86d5 100644 --- a/models/recall/tdm/tdm_reader.py +++ b/models/recall/tdm/tdm_reader.py @@ -18,16 +18,17 @@ from __future__ import print_function from fleetrec.core.reader import Reader -from fleetrec.core.utils import envs -class TrainReader(reader): +class TrainReader(Reader): + def init(self): + pass - def reader(self, line): + def generate_sample(self, line): """ Read the data line by line and process it as a dictionary """ - def iterator(): + def reader(): """ This function needs to be implemented by the user, based on data format """ @@ -38,4 +39,4 @@ class TrainReader(reader): feature_name = ["input_emb", "item_label"] yield zip(feature_name, [input_emb] + [item_label]) - return Reader + return reader