diff --git a/fluid/rnn_beam_search/attention_seq2seq.py b/fluid/rnn_beam_search/attention_seq2seq.py
index 8eb547befb82011047a59a0fa318e517246a2756..9eb6bc3f24f4f4272e74d6b5d8ba52ad6204a917 100644
--- a/fluid/rnn_beam_search/attention_seq2seq.py
+++ b/fluid/rnn_beam_search/attention_seq2seq.py
@@ -181,7 +181,8 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
         context = fluid.layers.sequence_pool(input=scaled, pool_type='sum')
         return context
 
-    def updater(state_cell):
+    @state_cell.state_updater
+    def state_updater(state_cell):
         current_word = state_cell.get_input('x')
         encoder_vec = state_cell.get_input('encoder_vec')
         encoder_proj = state_cell.get_input('encoder_proj')
@@ -194,8 +195,6 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
         state_cell.set_state('h', h)
         state_cell.set_state('c', c)
 
-    state_cell.register_updater(updater)
-
     if not is_generating:
         trg_word_idx = fluid.layers.data(
             name='target_sequence', shape=[1], dtype='int64', lod_level=1)
@@ -233,7 +232,62 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
 
         return avg_cost, feeding_list
     else:
-        pass
+        init_ids = fluid.layers.data(
+            name="init_ids", shape=[1], dtype="int64", lod_level=2)
+        init_scores = fluid.layers.data(
+            name="init_scores", shape=[1], dtype="float32", lod_level=2)
+
+        src_embedding = fluid.layers.embedding(
+            input=src_word_idx,
+            size=[source_dict_dim, embedding_dim],
+            dtype='float32')
+
+        decoder = BeamSearchDecoder(state_cell, max_len=max_length)
+
+        with decoder.block():
+            prev_ids = decoder.read_array(init=init_ids, is_ids=True)
+            prev_scores = decoder.read_array(init=init_scores, is_scores=True)
+            # TODO: share these weights with the trained target embedding.
+            prev_ids_embedding = fluid.layers.embedding(
+                input=prev_ids,
+                size=[target_dict_dim, embedding_dim],
+                dtype='float32')
+            prev_h = decoder.state_cell.get_state('h')
+            prev_c = decoder.state_cell.get_state('c')
+            prev_h_expanded = fluid.layers.sequence_expand(prev_h, prev_scores)
+            prev_c_expanded = fluid.layers.sequence_expand(prev_c, prev_scores)
+            decoder.state_cell.set_state('h', prev_h_expanded)
+            decoder.state_cell.set_state('c', prev_c_expanded)
+
+            # TODO: wire the real encoder_vec/encoder_proj into the
+            # generation path instead of the None placeholders below.
+            decoder.state_cell.compute_state(inputs={
+                'x': prev_ids_embedding,
+                'encoder_vec': None,
+                'encoder_proj': None
+            })
+
+            current_state = decoder.state_cell.get_state('h')
+            scores = fluid.layers.fc(input=current_state,
+                                     size=target_dict_dim,
+                                     act='softmax')
+            topk_scores, topk_indices = fluid.layers.topk(scores, k=beam_size)
+            selected_ids, selected_scores = fluid.layers.beam_search(
+                prev_ids,
+                topk_indices,
+                topk_scores,
+                beam_size,
+                end_id=10,
+                level=0)
+            decoder.state_cell.update_states()
+            decoder.update_array(prev_ids, selected_ids)
+            decoder.update_array(prev_scores, selected_scores)
+
+        translation_ids, translation_scores = decoder()
+
+        feeding_list = ["source_sequence", "init_ids", "init_scores"]
+
+        return translation_ids, translation_scores, feeding_list
 
 
 def to_lodtensor(data, place):
@@ -345,7 +399,51 @@ def train():
 
 
 def infer():
-    pass
+    translation_ids, translation_scores, feeding_list = seq_to_seq_net(
+        args.embedding_dim,
+        args.encoder_size,
+        args.decoder_size,
+        args.dict_size,
+        args.dict_size,
+        True,
+        beam_size=args.beam_size,
+        max_length=args.max_length)
+
+    fluid.memory_optimize(fluid.default_main_program(), print_log=False)
+
+    test_batch_generator = paddle.v2.batch(
+        paddle.v2.reader.shuffle(
+            paddle.v2.dataset.wmt14.test(args.dict_size), buf_size=1000),
+        batch_size=args.batch_size)
+
+    place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
+    exe = Executor(place)
+    exe.run(framework.default_startup_program())
+
+    for batch_id, data in enumerate(test_batch_generator()):
+        src_seq, _ = to_lodtensor(map(lambda x: x[0], data), place)
+
+        # One start token (id 0) with score 1.0 per source sentence,
+        # wrapped in 2-level LoDTensors as beam_search expects.
+        batch_size = len(data)
+        lod = [i for i in range(batch_size + 1)]
+        init_ids = core.LoDTensor()
+        init_ids.set(np.zeros((batch_size, 1), dtype='int64'), place)
+        init_ids.set_lod([lod, lod])
+        init_scores = core.LoDTensor()
+        init_scores.set(np.ones((batch_size, 1), dtype='float32'), place)
+        init_scores.set_lod([lod, lod])
+
+        fetch_outs = exe.run(framework.default_main_program(),
+                             feed={
+                                 feeding_list[0]: src_seq,
+                                 feeding_list[1]: init_ids,
+                                 feeding_list[2]: init_scores
+                             },
+                             fetch_list=[translation_ids, translation_scores])
+
+        print('batch_id=%d, translated token ids:\n%s' %
+              (batch_id, np.array(fetch_outs[0])))
 
 
 if __name__ == '__main__':
diff --git a/fluid/rnn_beam_search/beam_search_api.py b/fluid/rnn_beam_search/beam_search_api.py
index 04386ce5b809fe2d7a2778323f3ddf25869087d8..a5630a2f9d9968f731569a9b673c047dd8d3f497 100644
--- a/fluid/rnn_beam_search/beam_search_api.py
+++ b/fluid/rnn_beam_search/beam_search_api.py
@@ -108,6 +108,7 @@ class StateCell(object):
         self._in_decoder = False
         self._states_holder = {}
         self._switched_decoder = False
+        self._state_updater = None
 
     def enter_decoder(self, decoder_obj):
         if self._in_decoder == True or self._cur_decoder_obj is not None:
@@ -172,8 +173,16 @@ class StateCell(object):
     def set_state(self, state_name, state_value):
         self._cur_states[state_name] = state_value
 
-    def register_updater(self, state_updater):
-        self._state_updater = state_updater
+    def state_updater(self, updater):
+        self._state_updater = updater
+
+        def _decorator(state_cell):
+            if state_cell is not self:
+                raise TypeError('Updater should only accept a StateCell object '
+                                'as argument.')
+            updater(state_cell)
+
+        return _decorator
 
     def compute_state(self, inputs):
         if self._in_decoder and not self._switched_decoder:
diff --git a/fluid/rnn_beam_search/simple_seq2seq.py b/fluid/rnn_beam_search/simple_seq2seq.py
index baaafc8189dffb1e483aef2026e8aa4036d18b8a..31d870277dbda5bc4896443d6b140335ce25abea 100644
--- a/fluid/rnn_beam_search/simple_seq2seq.py
+++ b/fluid/rnn_beam_search/simple_seq2seq.py
@@ -55,18 +55,17 @@ def encoder():
     return encoder_out
 
 
-def updater(state_cell):
-    current_word = state_cell.get_input('x')
-    prev_h = state_cell.get_state('h')
-    h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh')
-    state_cell.set_state('h', h)
-
-
 def decoder_train(context):
     h = InitState(init=context)
     state_cell = StateCell(
         cell_size=decoder_size, inputs={'x': None}, states={'h': h})
-    state_cell.register_updater(updater)
+
+    @state_cell.state_updater
+    def updater(state_cell):
+        current_word = state_cell.get_input('x')
+        prev_h = state_cell.get_state('h')
+        h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh')
+        state_cell.set_state('h', h)
 
     # decoder
     trg_language_word = pd.data(
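Note on the `StateCell` API change above: `register_updater(updater)` is replaced by a `@state_cell.state_updater` decorator, which stores the raw callable for `compute_state` and hands back a wrapper that rejects any cell other than the one it was registered on. A minimal, self-contained sketch of that pattern (`MiniStateCell` is a hypothetical stand-in for illustration, not the repo's `StateCell`):

```python
class MiniStateCell(object):
    """Hypothetical stand-in for StateCell, for illustration only."""

    def __init__(self, states):
        self._cur_states = dict(states)
        self._state_updater = None

    def get_state(self, name):
        return self._cur_states[name]

    def set_state(self, name, value):
        self._cur_states[name] = value

    def state_updater(self, updater):
        # Keep the raw callable for internal use by compute_state ...
        self._state_updater = updater

        # ... and return a guard that only accepts this very cell,
        # mirroring _decorator in the diff above.
        def _decorator(state_cell):
            if state_cell is not self:
                raise TypeError('Updater should only accept a StateCell '
                                'object as argument.')
            updater(state_cell)

        return _decorator

    def compute_state(self):
        self._state_updater(self)


cell = MiniStateCell({'h': 0})

@cell.state_updater
def updater(state_cell):
    state_cell.set_state('h', state_cell.get_state('h') + 1)

cell.compute_state()
assert cell.get_state('h') == 1
```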
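Note on the generation feed: `init_ids` and `init_scores` are declared with `lod_level=2`, so the tensors fed into them must carry a two-level offset LoD (beams per sentence, then tokens per beam). A small sketch of building them, assuming the classic `core.LoDTensor` `set`/`set_lod` API and a batch of 3 source sentences:

```python
import numpy as np
import paddle.fluid.core as core

place = core.CPUPlace()
batch_size = 3

# Offset-style LoD: one entry per sentence -> [0, 1, 2, 3].
lod = [i for i in range(batch_size + 1)]

# One start token (id 0) with score 1.0 per source sentence.
init_ids = core.LoDTensor()
init_ids.set(np.zeros((batch_size, 1), dtype='int64'), place)
init_ids.set_lod([lod, lod])

init_scores = core.LoDTensor()
init_scores.set(np.ones((batch_size, 1), dtype='float32'), place)
init_scores.set_lod([lod, lod])

print(init_ids.lod())  # [[0, 1, 2, 3], [0, 1, 2, 3]]
```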