提交 48df2159 编写于 作者: M malin10

fix hyper_parameters

上级 a7741b38
......@@ -32,8 +32,8 @@ hyper_parameters:
class: sgd
learning_rate: 0.01
strategy: async
TRIGRAM_D: 1000
NEG: 4
trigram_d: 1000
neg_num: 4
fc_sizes: [300, 300, 128]
fc_acts: ['tanh', 'tanh', 'tanh']
......
......@@ -23,8 +23,8 @@ class Model(ModelBase):
ModelBase.__init__(self, config)
def _init_hyper_parameters(self):
self.TRIGRAM_D = envs.get_global_env("hyper_parameters.TRIGRAM_D")
self.Neg = envs.get_global_env("hyper_parameters.NEG")
self.trigram_d = envs.get_global_env("hyper_parameters.trigram_d")
self.neg_num = envs.get_global_env("hyper_parameters.neg_num")
self.hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes")
self.hidden_acts = envs.get_global_env("hyper_parameters.fc_acts")
self.learning_rate = envs.get_global_env(
......@@ -33,12 +33,12 @@ class Model(ModelBase):
def input_data(self, is_infer=False, **kwargs):
query = fluid.data(
name="query",
shape=[-1, self.TRIGRAM_D],
shape=[-1, self.trigram_d],
dtype='float32',
lod_level=0)
doc_pos = fluid.data(
name="doc_pos",
shape=[-1, self.TRIGRAM_D],
shape=[-1, self.trigram_d],
dtype='float32',
lod_level=0)
......@@ -48,9 +48,9 @@ class Model(ModelBase):
doc_negs = [
fluid.data(
name="doc_neg_" + str(i),
shape=[-1, self.TRIGRAM_D],
shape=[-1, self.trigram_d],
dtype="float32",
lod_level=0) for i in range(self.Neg)
lod_level=0) for i in range(self.neg_num)
]
return [query, doc_pos] + doc_negs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册