提交 9f64b843 编写于 作者: C chengmo

update & fix

上级 12c654fe
......@@ -19,15 +19,22 @@ import yaml
from fleetrec.core.utils import envs
# Absolute path of the "trainers" directory that sits next to this module.
trainer_abs = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "trainers")

# Maps a trainer name to the path of the file that implements it.
trainers = {}


def trainer_registry():
    """Fill the module-level ``trainers`` map with the built-in trainers.

    Every entry maps a trainer name to the path of its implementation
    file inside ``trainer_abs``. Each name is registered exactly once;
    calling the function again just rewrites the same entries.
    """
    builtin_trainers = (
        ("SingleTrainer", "single_trainer.py"),
        ("ClusterTrainer", "cluster_trainer.py"),
        ("CtrCodingTrainer", "ctr_coding_trainer.py"),
        ("CtrModulTrainer", "ctr_modul_trainer.py"),
        ("TDMSingleTrainer", "tdm_single_trainer.py"),
    )
    for trainer_name, module_file in builtin_trainers:
        trainers[trainer_name] = os.path.join(trainer_abs, module_file)


# Populate the registry at import time so lookups work immediately.
trainer_registry()
......@@ -46,7 +53,8 @@ class TrainerFactory(object):
if trainer_abs is None:
if not os.path.isfile(train_mode):
raise IOError("trainer {} can not be recognized".format(train_mode))
raise IOError(
"trainer {} can not be recognized".format(train_mode))
trainer_abs = train_mode
train_mode = "UserDefineTrainer"
......
......@@ -30,7 +30,7 @@ logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
class TdmSingleTrainer(SingleTrainer):
class TDMSingleTrainer(SingleTrainer):
def processor_register(self):
self.regist_context_processor('uninit', self.instance)
self.regist_context_processor('init_pass', self.init)
......
......@@ -202,7 +202,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run')
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-e", "--engine", type=str,
choices=["single", "local_cluster", "cluster"])
choices=["single", "local_cluster", "cluster", "tdm_single"])
parser.add_argument("-d", "--device", type=str,
choices=["cpu", "gpu"], default="cpu")
......
......@@ -17,20 +17,20 @@ train:
# for cluster training
strategy: "async"
epochs: 10
epochs: 4
workspace: "fleetrec.models.recall.tdm"
reader:
batch_size: 32
class: "{workspace}/tdm_reader.py"
train_data_path: "{workspace}/data/train_data"
test_data_path: "{workspace}/data/test_data"
train_data_path: "{workspace}/data/train"
test_data_path: "{workspace}/data/test"
model:
models: "{workspace}/model.py"
hyper_parameters:
node_emb_size: 64
input_emb_size: 64
input_emb_size: 768
neg_sampling_list: [1, 2, 3, 4]
output_positive: True
topK: 1
......@@ -52,10 +52,10 @@ train:
persistables_model_path: ""
load_tree: True
tree_layer_path: ""
tree_travel_path: ""
tree_info_path: ""
tree_emb_path: ""
tree_layer_path: "{workspace}/tree/layer_list.txt"
tree_travel_path: "{workspace}/tree/travel_list.npy"
tree_info_path: "{workspace}/tree/tree_info.npy"
tree_emb_path: "{workspace}/tree/tree_emb.npy"
save_init_model: True
init_model_path: ""
......
......@@ -45,7 +45,7 @@ class Model(ModelBase):
self.node_emb_size = envs.get_global_env(
"hyper_parameters.node_emb_size", 64, self._namespace)
self.input_emb_size = envs.get_global_env(
"hyper_parameters.input_emb_size", 64, self._namespace)
"hyper_parameters.input_emb_size", 768, self._namespace)
self.act = envs.get_global_env(
"hyper_parameters.act", "tanh", self._namespace)
self.neg_sampling_list = envs.get_global_env(
......@@ -61,6 +61,7 @@ class Model(ModelBase):
def train_net(self):
self.train_input()
self.tdm_net()
self.create_info()
self.avg_loss()
self.metrics()
......@@ -174,11 +175,26 @@ class Model(ModelBase):
mask_index.stop_gradient = True
self.mask_cost = fluid.layers.gather_nd(cost, mask_index)
softmax_prob = fluid.layers.unsqueeze(input=softmax_prob, axes=[1])
self.mask_prob = fluid.layers.gather_nd(softmax_prob, mask_index)
self.mask_label = fluid.layers.gather_nd(labels_reshape, mask_index)
self._predict = self.mask_prob
def create_info(self):
    """Declare the persistable ``TDM_Tree_Info`` variable in both programs.

    The variable is created in the global block of the default startup
    program with a constant-zero initializer, and again, with the same
    name, dtype and shape but no initializer, in the global block of the
    default main program. Its shape is
    ``[self.node_nums, 3 + self.child_nums]`` with INT32 elements.
    """
    var_name = "TDM_Tree_Info"
    int32 = fluid.core.VarDesc.VarType.INT32
    rows = self.node_nums
    cols = 3 + self.child_nums

    # Startup program: the variable is created together with its
    # zero-filling initializer.
    startup_block = fluid.default_startup_program().global_block()
    startup_block.create_var(
        name=var_name,
        dtype=int32,
        shape=[rows, cols],
        persistable=True,
        initializer=fluid.initializer.ConstantInitializer(0))

    # Main program: same declaration, without an initializer.
    main_block = fluid.default_main_program().global_block()
    main_block.create_var(
        name=var_name,
        dtype=int32,
        shape=[rows, cols],
        persistable=True)
def avg_loss(self):
avg_cost = fluid.layers.reduce_mean(self.mask_cost)
self._cost = avg_cost
......
......@@ -18,16 +18,17 @@
from __future__ import print_function
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class TrainReader(reader):
class TrainReader(Reader):
def init(self):
pass
def reader(self, line):
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def iterator():
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
......@@ -38,4 +39,4 @@ class TrainReader(reader):
feature_name = ["input_emb", "item_label"]
yield zip(feature_name, [input_emb] + [item_label])
return Reader
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册