提交 ee6bd53b 编写于 作者: F frankwhzhang

add hyper_parameters

上级 f49acc00
...@@ -37,6 +37,10 @@ class Model(object): ...@@ -37,6 +37,10 @@ class Model(object):
self._fetch_interval = 20 self._fetch_interval = 20
self._namespace = "train.model" self._namespace = "train.model"
self._platform = envs.get_platform() self._platform = envs.get_platform()
self._init_hyper_parameters()
def _init_hyper_parameters(self):
pass
def _init_slots(self): def _init_slots(self):
sparse_slots = envs.get_global_env("sparse_slots", None, sparse_slots = envs.get_global_env("sparse_slots", None,
......
...@@ -23,6 +23,8 @@ from paddlerec.core.model import Model as ModelBase ...@@ -23,6 +23,8 @@ from paddlerec.core.model import Model as ModelBase
class Model(ModelBase): class Model(ModelBase):
def __init__(self, config): def __init__(self, config):
ModelBase.__init__(self, config) ModelBase.__init__(self, config)
def _init_hyper_parameters(self):
self.item_len = envs.get_global_env("hyper_parameters.self.item_len", self.item_len = envs.get_global_env("hyper_parameters.self.item_len",
None, self._namespace) None, self._namespace)
self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size", self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size",
...@@ -65,7 +67,7 @@ class Model(ModelBase): ...@@ -65,7 +67,7 @@ class Model(ModelBase):
self._data_var = inputs self._data_var = inputs
self._data_loader = fluid.io.DataLoader.from_generator( self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var, feed_list=self._data_var,
capacity=10000, capacity=64,
use_double_buffer=False, use_double_buffer=False,
iterable=False) iterable=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册