提交 14c1e75e 编写于 作者: Y yangyaming

Make sure lod of state kept consistent.

上级 abfe896e
......@@ -195,7 +195,6 @@ class StateCell(object):
'Please make sure %s in input '
'place holder.' % (input_name, input_name))
self._inputs[input_name] = input_value
self._state_updater(self)
def update_states(self):
......
......@@ -64,7 +64,8 @@ def decoder_state_cell(context):
def updater(state_cell):
current_word = state_cell.get_input('x')
prev_h = state_cell.get_state('h')
h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh')
# make sure lod of h heritted from prev_h
h = pd.fc(input=[prev_h, current_word], size=decoder_size, act='tanh')
state_cell.set_state('h', h)
return state_cell
......@@ -101,9 +102,9 @@ def decoder_decode(state_cell):
name="init_scores", shape=[1], dtype="float32", lod_level=2)
def embedding(input):
pd.embedding(
return pd.embedding(
input=input,
size=[dict_dim, word_dim],
size=[dict_size, word_dim],
dtype='float32',
is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr('vemb'))
......@@ -237,5 +238,5 @@ def decode_main():
if __name__ == '__main__':
train_main()
#decode_main()
#train_main()
decode_main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册