diff --git a/fluid/rnn_beam_search/beam_search_api.py b/fluid/rnn_beam_search/beam_search_api.py index a62ee605f937c200eebaffdcf132151a1574887a..78489d365c68a71afb8577cfb12953e431cf3b87 100644 --- a/fluid/rnn_beam_search/beam_search_api.py +++ b/fluid/rnn_beam_search/beam_search_api.py @@ -195,7 +195,6 @@ class StateCell(object): 'Please make sure %s in input ' 'place holder.' % (input_name, input_name)) self._inputs[input_name] = input_value - self._state_updater(self) def update_states(self): diff --git a/fluid/rnn_beam_search/simple_seq2seq.py b/fluid/rnn_beam_search/simple_seq2seq.py index 37e380ea0f08cea30e80d6d9d82d66259b6d8853..eaf887890e8d4b050702c265902053c07a5f68fd 100644 --- a/fluid/rnn_beam_search/simple_seq2seq.py +++ b/fluid/rnn_beam_search/simple_seq2seq.py @@ -64,7 +64,8 @@ def decoder_state_cell(context): def updater(state_cell): current_word = state_cell.get_input('x') prev_h = state_cell.get_state('h') - h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh') + # make sure lod of h heritted from prev_h + h = pd.fc(input=[prev_h, current_word], size=decoder_size, act='tanh') state_cell.set_state('h', h) return state_cell @@ -101,9 +102,9 @@ def decoder_decode(state_cell): name="init_scores", shape=[1], dtype="float32", lod_level=2) def embedding(input): - pd.embedding( + return pd.embedding( input=input, - size=[dict_dim, word_dim], + size=[dict_size, word_dim], dtype='float32', is_sparse=IS_SPARSE, param_attr=fluid.ParamAttr('vemb')) @@ -237,5 +238,5 @@ def decode_main(): if __name__ == '__main__': - train_main() - #decode_main() + #train_main() + decode_main()