提交 9f64b843 编写于 作者: C chengmo

update & ifx

上级 12c654fe
...@@ -19,15 +19,22 @@ import yaml ...@@ -19,15 +19,22 @@ import yaml
from fleetrec.core.utils import envs from fleetrec.core.utils import envs
trainer_abs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "trainers") trainer_abs = os.path.join(os.path.dirname(
os.path.abspath(__file__)), "trainers")
trainers = {} trainers = {}
def trainer_registry(): def trainer_registry():
trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py") trainers["SingleTrainer"] = os.path.join(
trainers["ClusterTrainer"] = os.path.join(trainer_abs, "cluster_trainer.py") trainer_abs, "single_trainer.py")
trainers["CtrCodingTrainer"] = os.path.join(trainer_abs, "ctr_coding_trainer.py") trainers["ClusterTrainer"] = os.path.join(
trainers["CtrModulTrainer"] = os.path.join(trainer_abs, "ctr_modul_trainer.py") trainer_abs, "cluster_trainer.py")
trainers["CtrCodingTrainer"] = os.path.join(
trainer_abs, "ctr_coding_trainer.py")
trainers["CtrModulTrainer"] = os.path.join(
trainer_abs, "ctr_modul_trainer.py")
trainers["TDMSingleTrainer"] = os.path.join(
trainer_abs, "tdm_single_trainer.py")
trainer_registry() trainer_registry()
...@@ -46,7 +53,8 @@ class TrainerFactory(object): ...@@ -46,7 +53,8 @@ class TrainerFactory(object):
if trainer_abs is None: if trainer_abs is None:
if not os.path.isfile(train_mode): if not os.path.isfile(train_mode):
raise IOError("trainer {} can not be recognized".format(train_mode)) raise IOError(
"trainer {} can not be recognized".format(train_mode))
trainer_abs = train_mode trainer_abs = train_mode
train_mode = "UserDefineTrainer" train_mode = "UserDefineTrainer"
......
...@@ -30,7 +30,7 @@ logger = logging.getLogger("fluid") ...@@ -30,7 +30,7 @@ logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class TdmSingleTrainer(SingleTrainer): class TDMSingleTrainer(SingleTrainer):
def processor_register(self): def processor_register(self):
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init) self.regist_context_processor('init_pass', self.init)
......
...@@ -202,7 +202,7 @@ if __name__ == "__main__": ...@@ -202,7 +202,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run') parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("-m", "--model", type=str) parser.add_argument("-m", "--model", type=str)
parser.add_argument("-e", "--engine", type=str, parser.add_argument("-e", "--engine", type=str,
choices=["single", "local_cluster", "cluster"]) choices=["single", "local_cluster", "cluster", "tdm_single"])
parser.add_argument("-d", "--device", type=str, parser.add_argument("-d", "--device", type=str,
choices=["cpu", "gpu"], default="cpu") choices=["cpu", "gpu"], default="cpu")
......
...@@ -17,20 +17,20 @@ train: ...@@ -17,20 +17,20 @@ train:
# for cluster training # for cluster training
strategy: "async" strategy: "async"
epochs: 10 epochs: 4
workspace: "fleetrec.models.recall.tdm" workspace: "fleetrec.models.recall.tdm"
reader: reader:
batch_size: 32 batch_size: 32
class: "{workspace}/tdm_reader.py" class: "{workspace}/tdm_reader.py"
train_data_path: "{workspace}/data/train_data" train_data_path: "{workspace}/data/train"
test_data_path: "{workspace}/data/test_data" test_data_path: "{workspace}/data/test"
model: model:
models: "{workspace}/model.py" models: "{workspace}/model.py"
hyper_parameters: hyper_parameters:
node_emb_size: 64 node_emb_size: 64
input_emb_size: 64 input_emb_size: 768
neg_sampling_list: [1, 2, 3, 4] neg_sampling_list: [1, 2, 3, 4]
output_positive: True output_positive: True
topK: 1 topK: 1
...@@ -52,10 +52,10 @@ train: ...@@ -52,10 +52,10 @@ train:
persistables_model_path: "" persistables_model_path: ""
load_tree: True load_tree: True
tree_layer_path: "" tree_layer_path: "{workspace}/tree/layer_list.txt"
tree_travel_path: "" tree_travel_path: "{workspace}/tree/travel_list.npy"
tree_info_path: "" tree_info_path: "{workspace}/tree/tree_info.npy"
tree_emb_path: "" tree_emb_path: "{workspace}/tree/tree_emb.npy"
save_init_model: True save_init_model: True
init_model_path: "" init_model_path: ""
......
...@@ -45,7 +45,7 @@ class Model(ModelBase): ...@@ -45,7 +45,7 @@ class Model(ModelBase):
self.node_emb_size = envs.get_global_env( self.node_emb_size = envs.get_global_env(
"hyper_parameters.node_emb_size", 64, self._namespace) "hyper_parameters.node_emb_size", 64, self._namespace)
self.input_emb_size = envs.get_global_env( self.input_emb_size = envs.get_global_env(
"hyper_parameters.input_emb_size", 64, self._namespace) "hyper_parameters.input_emb_size", 768, self._namespace)
self.act = envs.get_global_env( self.act = envs.get_global_env(
"hyper_parameters.act", "tanh", self._namespace) "hyper_parameters.act", "tanh", self._namespace)
self.neg_sampling_list = envs.get_global_env( self.neg_sampling_list = envs.get_global_env(
...@@ -61,6 +61,7 @@ class Model(ModelBase): ...@@ -61,6 +61,7 @@ class Model(ModelBase):
def train_net(self): def train_net(self):
self.train_input() self.train_input()
self.tdm_net() self.tdm_net()
self.create_info()
self.avg_loss() self.avg_loss()
self.metrics() self.metrics()
...@@ -174,11 +175,26 @@ class Model(ModelBase): ...@@ -174,11 +175,26 @@ class Model(ModelBase):
mask_index.stop_gradient = True mask_index.stop_gradient = True
self.mask_cost = fluid.layers.gather_nd(cost, mask_index) self.mask_cost = fluid.layers.gather_nd(cost, mask_index)
softmax_prob = fluid.layers.unsqueeze(input=softmax_prob, axes=[1])
self.mask_prob = fluid.layers.gather_nd(softmax_prob, mask_index) self.mask_prob = fluid.layers.gather_nd(softmax_prob, mask_index)
self.mask_label = fluid.layers.gather_nd(labels_reshape, mask_index) self.mask_label = fluid.layers.gather_nd(labels_reshape, mask_index)
self._predict = self.mask_prob self._predict = self.mask_prob
def create_info(self):
fluid.default_startup_program().global_block().create_var(
name="TDM_Tree_Info",
dtype=fluid.core.VarDesc.VarType.INT32,
shape=[self.node_nums, 3 + self.child_nums],
persistable=True,
initializer=fluid.initializer.ConstantInitializer(0))
fluid.default_main_program().global_block().create_var(
name="TDM_Tree_Info",
dtype=fluid.core.VarDesc.VarType.INT32,
shape=[self.node_nums, 3 + self.child_nums],
persistable=True)
def avg_loss(self): def avg_loss(self):
avg_cost = fluid.layers.reduce_mean(self.mask_cost) avg_cost = fluid.layers.reduce_mean(self.mask_cost)
self._cost = avg_cost self._cost = avg_cost
......
...@@ -18,16 +18,17 @@ ...@@ -18,16 +18,17 @@
from __future__ import print_function from __future__ import print_function
from fleetrec.core.reader import Reader from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class TrainReader(reader): class TrainReader(Reader):
def init(self):
pass
def reader(self, line): def generate_sample(self, line):
""" """
Read the data line by line and process it as a dictionary Read the data line by line and process it as a dictionary
""" """
def iterator(): def reader():
""" """
This function needs to be implemented by the user, based on data format This function needs to be implemented by the user, based on data format
""" """
...@@ -38,4 +39,4 @@ class TrainReader(reader): ...@@ -38,4 +39,4 @@ class TrainReader(reader):
feature_name = ["input_emb", "item_label"] feature_name = ["input_emb", "item_label"]
yield zip(feature_name, [input_emb] + [item_label]) yield zip(feature_name, [input_emb] + [item_label])
return Reader return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册