From 14c1e75ec9d0c8f3d7c4c66b3cea1a0b149e99c1 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Sun, 18 Mar 2018 17:26:48 +0800 Subject: [PATCH] Make sure lod of state kept consistent. --- fluid/rnn_beam_search/beam_search_api.py | 1 - fluid/rnn_beam_search/simple_seq2seq.py | 11 ++++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fluid/rnn_beam_search/beam_search_api.py b/fluid/rnn_beam_search/beam_search_api.py index a62ee605..78489d36 100644 --- a/fluid/rnn_beam_search/beam_search_api.py +++ b/fluid/rnn_beam_search/beam_search_api.py @@ -195,7 +195,6 @@ class StateCell(object): 'Please make sure %s in input ' 'place holder.' % (input_name, input_name)) self._inputs[input_name] = input_value - self._state_updater(self) def update_states(self): diff --git a/fluid/rnn_beam_search/simple_seq2seq.py b/fluid/rnn_beam_search/simple_seq2seq.py index 37e380ea..eaf88789 100644 --- a/fluid/rnn_beam_search/simple_seq2seq.py +++ b/fluid/rnn_beam_search/simple_seq2seq.py @@ -64,7 +64,8 @@ def decoder_state_cell(context): def updater(state_cell): current_word = state_cell.get_input('x') prev_h = state_cell.get_state('h') - h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh') + # make sure lod of h heritted from prev_h + h = pd.fc(input=[prev_h, current_word], size=decoder_size, act='tanh') state_cell.set_state('h', h) return state_cell @@ -101,9 +102,9 @@ def decoder_decode(state_cell): name="init_scores", shape=[1], dtype="float32", lod_level=2) def embedding(input): - pd.embedding( + return pd.embedding( input=input, - size=[dict_dim, word_dim], + size=[dict_size, word_dim], dtype='float32', is_sparse=IS_SPARSE, param_attr=fluid.ParamAttr('vemb')) @@ -237,5 +238,5 @@ def decode_main(): if __name__ == '__main__': - train_main() - #decode_main() + #train_main() + decode_main() -- GitLab