From 8de0aa46aae6e901e0101dec4e68e9d5c84ce4fa Mon Sep 17 00:00:00 2001
From: chengmo
Date: Wed, 6 May 2020 16:53:46 +0800
Subject: [PATCH] fix tdm trainer

---
 fleet_rec/core/trainers/tdm_single_trainer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/fleet_rec/core/trainers/tdm_single_trainer.py b/fleet_rec/core/trainers/tdm_single_trainer.py
index a894a029..11874645 100644
--- a/fleet_rec/core/trainers/tdm_single_trainer.py
+++ b/fleet_rec/core/trainers/tdm_single_trainer.py
@@ -28,6 +28,8 @@ import numpy as np
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger("fluid")
 logger.setLevel(logging.INFO)
+special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer",
+                 "TDM_Tree_Info", "TDM_Tree_Emb"]
 
 
 class TDMSingleTrainer(SingleTrainer):
@@ -95,13 +97,13 @@ class TDMSingleTrainer(SingleTrainer):
         if load_tree:
             # Set the plaintext tree structure and data into the network's Variables.
             # NumpyInitialize is not used here because the tree-related data is
             # large enough to pose a performance risk.
-            for param_name in Numpy_model:
+            for param_name in special_param:
                 param_t = fluid.global_scope().find_var(param_name).get_tensor()
                 param_array = self.tdm_prepare(param_name)
                 if param_name == 'TDM_Tree_Emb':
-                    param_t.set(param_array.astype('float32'), place)
+                    param_t.set(param_array.astype('float32'), self._place)
                 else:
-                    param_t.set(param_array.astype('int32'), place)
+                    param_t.set(param_array.astype('int32'), self._place)
 
         if save_init_model:
             logger.info("Begin Save Init model.")
--
GitLab
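
Note (not part of the patch): the fix replaces the undefined name `Numpy_model` with the new module-level `special_param` list, and the undefined local `place` with the trainer's `self._place`. A minimal standalone sketch of what the corrected loop does is below; it assumes the network has already been built so the TDM variables exist in fluid's global scope, and `prepare_fn` is a hypothetical stand-in for the trainer's `self.tdm_prepare`.

    # Sketch only: copy prepared tree arrays into the network's tensors.
    # Assumes PaddlePaddle's legacy fluid API, as used by the patched file.
    import paddle.fluid as fluid

    special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer",
                     "TDM_Tree_Info", "TDM_Tree_Emb"]

    def set_tree_tensors(prepare_fn, place):
        for param_name in special_param:
            var = fluid.global_scope().find_var(param_name)
            assert var is not None, "build the network before loading the tree"
            param_t = var.get_tensor()
            param_array = prepare_fn(param_name)
            # Only the embedding table is float; the structural arrays are int.
            dtype = 'float32' if param_name == 'TDM_Tree_Emb' else 'int32'
            param_t.set(param_array.astype(dtype), place)

In the trainer this would be invoked with the instance's own place, e.g. `set_tree_tensors(self.tdm_prepare, self._place)`, mirroring the patched code's use of `self._place` rather than an undefined local `place`.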