提交 48df2159 编写于 作者: M malin10

fix hyper_parameters

上级 a7741b38
...@@ -32,8 +32,8 @@ hyper_parameters: ...@@ -32,8 +32,8 @@ hyper_parameters:
class: sgd class: sgd
learning_rate: 0.01 learning_rate: 0.01
strategy: async strategy: async
TRIGRAM_D: 1000 trigram_d: 1000
NEG: 4 neg_num: 4
fc_sizes: [300, 300, 128] fc_sizes: [300, 300, 128]
fc_acts: ['tanh', 'tanh', 'tanh'] fc_acts: ['tanh', 'tanh', 'tanh']
......
...@@ -23,8 +23,8 @@ class Model(ModelBase): ...@@ -23,8 +23,8 @@ class Model(ModelBase):
ModelBase.__init__(self, config) ModelBase.__init__(self, config)
def _init_hyper_parameters(self): def _init_hyper_parameters(self):
self.TRIGRAM_D = envs.get_global_env("hyper_parameters.TRIGRAM_D") self.trigram_d = envs.get_global_env("hyper_parameters.trigram_d")
self.Neg = envs.get_global_env("hyper_parameters.NEG") self.neg_num = envs.get_global_env("hyper_parameters.neg_num")
self.hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes") self.hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes")
self.hidden_acts = envs.get_global_env("hyper_parameters.fc_acts") self.hidden_acts = envs.get_global_env("hyper_parameters.fc_acts")
self.learning_rate = envs.get_global_env( self.learning_rate = envs.get_global_env(
...@@ -33,12 +33,12 @@ class Model(ModelBase): ...@@ -33,12 +33,12 @@ class Model(ModelBase):
def input_data(self, is_infer=False, **kwargs): def input_data(self, is_infer=False, **kwargs):
query = fluid.data( query = fluid.data(
name="query", name="query",
shape=[-1, self.TRIGRAM_D], shape=[-1, self.trigram_d],
dtype='float32', dtype='float32',
lod_level=0) lod_level=0)
doc_pos = fluid.data( doc_pos = fluid.data(
name="doc_pos", name="doc_pos",
shape=[-1, self.TRIGRAM_D], shape=[-1, self.trigram_d],
dtype='float32', dtype='float32',
lod_level=0) lod_level=0)
...@@ -48,9 +48,9 @@ class Model(ModelBase): ...@@ -48,9 +48,9 @@ class Model(ModelBase):
doc_negs = [ doc_negs = [
fluid.data( fluid.data(
name="doc_neg_" + str(i), name="doc_neg_" + str(i),
shape=[-1, self.TRIGRAM_D], shape=[-1, self.trigram_d],
dtype="float32", dtype="float32",
lod_level=0) for i in range(self.Neg) lod_level=0) for i in range(self.neg_num)
] ]
return [query, doc_pos] + doc_negs return [query, doc_pos] + doc_negs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册