提交 14c1e75e 编写于 作者: Y yangyaming

Make sure lod of state kept consistent.

上级 abfe896e
...@@ -195,7 +195,6 @@ class StateCell(object): ...@@ -195,7 +195,6 @@ class StateCell(object):
'Please make sure %s in input ' 'Please make sure %s in input '
'place holder.' % (input_name, input_name)) 'place holder.' % (input_name, input_name))
self._inputs[input_name] = input_value self._inputs[input_name] = input_value
self._state_updater(self) self._state_updater(self)
def update_states(self): def update_states(self):
......
...@@ -64,7 +64,8 @@ def decoder_state_cell(context): ...@@ -64,7 +64,8 @@ def decoder_state_cell(context):
def updater(state_cell): def updater(state_cell):
current_word = state_cell.get_input('x') current_word = state_cell.get_input('x')
prev_h = state_cell.get_state('h') prev_h = state_cell.get_state('h')
h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh') # make sure lod of h heritted from prev_h
h = pd.fc(input=[prev_h, current_word], size=decoder_size, act='tanh')
state_cell.set_state('h', h) state_cell.set_state('h', h)
return state_cell return state_cell
...@@ -101,9 +102,9 @@ def decoder_decode(state_cell): ...@@ -101,9 +102,9 @@ def decoder_decode(state_cell):
name="init_scores", shape=[1], dtype="float32", lod_level=2) name="init_scores", shape=[1], dtype="float32", lod_level=2)
def embedding(input): def embedding(input):
pd.embedding( return pd.embedding(
input=input, input=input,
size=[dict_dim, word_dim], size=[dict_size, word_dim],
dtype='float32', dtype='float32',
is_sparse=IS_SPARSE, is_sparse=IS_SPARSE,
param_attr=fluid.ParamAttr('vemb')) param_attr=fluid.ParamAttr('vemb'))
...@@ -237,5 +238,5 @@ def decode_main(): ...@@ -237,5 +238,5 @@ def decode_main():
if __name__ == '__main__': if __name__ == '__main__':
train_main() #train_main()
#decode_main() decode_main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册