diff --git a/fleet_rec/core/trainers/tdm_single_trainer.py b/fleet_rec/core/trainers/tdm_single_trainer.py index a894a02904b9d12f4aa889fe0ddb0d834d847d16..118746458dc27459c3aa312d337004489f7da1a4 100644 --- a/fleet_rec/core/trainers/tdm_single_trainer.py +++ b/fleet_rec/core/trainers/tdm_single_trainer.py @@ -28,6 +28,8 @@ import numpy as np logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger("fluid") logger.setLevel(logging.INFO) +special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer", + "TDM_Tree_Info", "TDM_Tree_Emb"] class TDMSingleTrainer(SingleTrainer): @@ -95,13 +97,13 @@ class TDMSingleTrainer(SingleTrainer): if load_tree: # 将明文树结构及数据,set到组网中的Variale中 # 不使用NumpyInitialize方法是考虑到树结构相关数据size过大,有性能风险 - for param_name in Numpy_model: + for param_name in special_param: param_t = fluid.global_scope().find_var(param_name).get_tensor() param_array = self.tdm_prepare(param_name) if param_name == 'TDM_Tree_Emb': - param_t.set(param_array.astype('float32'), place) + param_t.set(param_array.astype('float32'), self._place) else: - param_t.set(param_array.astype('int32'), place) + param_t.set(param_array.astype('int32'), self._place) if save_init_model: logger.info("Begin Save Init model.")