diff --git a/fleet_rec/core/trainers/tdm_cluster_trainer.py b/fleet_rec/core/trainers/tdm_cluster_trainer.py index 336a0581215748dde2122968e95fb451c7247c77..ba7bd72b0271b7c44e3805dabac5f4a95ff1082d 100644 --- a/fleet_rec/core/trainers/tdm_cluster_trainer.py +++ b/fleet_rec/core/trainers/tdm_cluster_trainer.py @@ -50,13 +50,13 @@ class TDMClusterTrainer(TranspileTrainer): namespace = "train.startup" load_tree = envs.get_global_env( - "cluster.load_tree", True, namespace) + "tree.load_tree", True, namespace) self.tree_layer_path = envs.get_global_env( - "cluster.tree_layer_path", "", namespace) + "tree.tree_layer_path", "", namespace) self.tree_travel_path = envs.get_global_env( - "cluster.tree_travel_path", "", namespace) + "tree.tree_travel_path", "", namespace) self.tree_info_path = envs.get_global_env( - "cluster.tree_info_path", "", namespace) + "tree.tree_info_path", "", namespace) save_init_model = envs.get_global_env( "cluster.save_init_model", False, namespace) diff --git a/models/recall/tdm/config.yaml b/models/recall/tdm/config.yaml index edcd3a932123d14c518e1a12f3dd750ff91b9ff9..0e575804d373dbadb2462b881b3fe8fb5cde1643 100644 --- a/models/recall/tdm/config.yaml +++ b/models/recall/tdm/config.yaml @@ -54,7 +54,6 @@ train: tree_travel_path: "{workspace}/tree/travel_list.npy" tree_info_path: "{workspace}/tree/tree_info.npy" tree_emb_path: "{workspace}/tree/tree_emb.npy" - single: load_persistables: False persistables_model_path: ""