提交 8de0aa46 编写于 作者: C chengmo

fix tdm trainer

上级 cb55d3b7
...@@ -28,6 +28,8 @@ import numpy as np ...@@ -28,6 +28,8 @@ import numpy as np
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid") logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer",
"TDM_Tree_Info", "TDM_Tree_Emb"]
class TDMSingleTrainer(SingleTrainer): class TDMSingleTrainer(SingleTrainer):
...@@ -95,13 +97,13 @@ class TDMSingleTrainer(SingleTrainer): ...@@ -95,13 +97,13 @@ class TDMSingleTrainer(SingleTrainer):
if load_tree: if load_tree:
# 将明文树结构及数据,set到组网中的Variale中 # 将明文树结构及数据,set到组网中的Variale中
# 不使用NumpyInitialize方法是考虑到树结构相关数据size过大,有性能风险 # 不使用NumpyInitialize方法是考虑到树结构相关数据size过大,有性能风险
for param_name in Numpy_model: for param_name in special_param:
param_t = fluid.global_scope().find_var(param_name).get_tensor() param_t = fluid.global_scope().find_var(param_name).get_tensor()
param_array = self.tdm_prepare(param_name) param_array = self.tdm_prepare(param_name)
if param_name == 'TDM_Tree_Emb': if param_name == 'TDM_Tree_Emb':
param_t.set(param_array.astype('float32'), place) param_t.set(param_array.astype('float32'), self._place)
else: else:
param_t.set(param_array.astype('int32'), place) param_t.set(param_array.astype('int32'), self._place)
if save_init_model: if save_init_model:
logger.info("Begin Save Init model.") logger.info("Begin Save Init model.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册