User interface design for Beam Search and Decoder
Created by: pkuyym
I think it's better to design the user interface first, before considering the API implementation. The following design may be adjusted once implementation details are taken into account.
Decoder
The decoder is an important component for generating sequences in many tasks such as machine translation. The generation process differs between training and inference. In the training phase, the decoder selects the single token with the highest probability at each step, whereas in the inference phase it keeps multiple candidate sequences with the highest likelihood estimates. We can therefore classify decoders into a training decoder and an inference decoder. The interface of each decoder should be customizable and should make it effortless to add further strategies, such as scheduled sampling for the training decoder and beam search for the inference decoder.
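To make the distinction concrete, here is a minimal, framework-free sketch (plain Python; the function names are hypothetical and not part of the proposed API) of the selection strategy each decoder applies at a single step:

import heapq

def train_step_select(probs):
    # Training decoder: take the single token with the highest probability.
    return max(range(len(probs)), key=lambda i: probs[i])

def infer_step_select(probs, beam_width):
    # Inference decoder: keep several high-probability candidates so that a
    # better overall sequence can still be recovered in later steps.
    return heapq.nlargest(beam_width, range(len(probs)), key=lambda i: probs[i])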
Beam Search
Beam search is a heuristic search algorithm commonly used to improve the estimate of the highest-likelihood output sequence. It is a general algorithm and is usually incorporated into the inference decoder to reduce the search space significantly. PaddlePaddle Fluid already provides the basic components of beam search, and a simple validation has been done on a small Seq2Seq model. However, using the raw components directly makes the model code chaotic and tediously long. We should hide the common logic and expose a simple interface without harming generality.
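For reference, the algorithm itself can be sketched independently of Fluid as follows; next_token_probs is a hypothetical callback that returns a token distribution conditioned on a partial sequence:

import heapq
import math

def beam_search(next_token_probs, start_token, eos_token, beam_width, max_len):
    beams = [([start_token], 0.0)]  # (partial sequence, log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            for token, prob in enumerate(next_token_probs(seq)):
                candidates.append((seq + [token], score + math.log(prob + 1e-12)))
        # Keep only the beam_width best expansions.
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[1])
        # Sequences that emitted the end token stop growing.
        finished += [b for b in beams if b[0][-1] == eos_token]
        beams = [b for b in beams if b[0][-1] != eos_token]
        if not beams:
            break
    return heapq.nlargest(beam_width, finished + beams, key=lambda c: c[1])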
User interface
The generation process is step by step: in each step, the model state is updated according to the decoder results of the previous step, and the decoder computes a probability distribution from the current model state to generate new tokens. A training decoder and an inference decoder that share one state cell use exactly the same state computation logic, so we can decouple the state computation logic from the decoder generation logic.
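Conceptually, the decoupling amounts to a generic step loop in which the state cell owns the state transition and the decoder owns the output generation. A rough sketch of this division of labor (plain Python; step_fn and generate_fn are hypothetical callables):

def decode(step_fn, generate_fn, init_state, step_inputs):
    # step_fn: state computation logic (the state cell's responsibility).
    # generate_fn: generation logic (the decoder's responsibility).
    state, outputs = init_state, []
    for x in step_inputs:
        state = step_fn(state, x)
        outputs.append(generate_fn(state))
    return outputs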
State Cell
A state cell cares only about the state computation logic, which can be as simple as a vanilla RNN or as complex as an LSTM with attention mechanisms. The state cell only describes the computation; the actual execution is invoked by the decoder. We can define a state cell as follows:
- Vanilla RNN
state_cell = fluid.layers.StateCell(step_inputs={'x': None},
                                    init_states={'h': init_h},
                                    size=cell_size)
with state_cell.block() as cell:
    step_x = cell.step_input('x')
    prev_state = cell.state('h')
    cur_state = fluid.layers.fc(input=[step_x, prev_state], size=cell.size)
    cell.update_state('h', cur_state)  # update the state
- LSTM RNN
state_cell = fluid.layers.StateCell(
    step_inputs={'x': None},
    init_states={'h': init_h, 'c': init_c},
    size=cell_size)
with state_cell.block() as cell:
    step_x = cell.step_input('x')
    prev_h = cell.state('h')
    prev_c = cell.state('c')
    cur_h, cur_c = fluid.layers.lstm_step(step_x, prev_h, prev_c, cell.size)
    cell.update_state('h', cur_h)  # update the hidden state
    cell.update_state('c', cur_c)  # update the cell value
- LSTM RNN with attention
state_cell = fluid.layers.StateCell(
    step_inputs={'x': None, 'context': encoder_output},
    init_states={'h': init_h, 'c': init_c},
    size=cell_size)
with state_cell.block() as cell:
    step_x = cell.step_input('x')
    step_context = cell.step_input('context')
    prev_h = cell.state('h')
    context = fluid.layers.simple_attention(step_context, prev_h)
    lstm_input = fluid.layers.concat(input=[context, step_x], axis=1)
    prev_c = cell.state('c')
    cur_h, cur_c = fluid.layers.lstm_step(lstm_input, prev_h, prev_c, cell.size)
    cell.update_state('h', cur_h)  # update the hidden state
    cell.update_state('c', cur_c)  # update the cell value
Incorporated with decoder
state_cell = fluid.layers.StateCell(
    step_inputs={'x': None},
    init_states={'h': init_h},
    size=cell_size)
with state_cell.block() as cell:
    step_x = cell.step_input('x')
    prev_state = cell.state('h')
    cur_state = fluid.layers.fc(input=[step_x, prev_state], size=cell.size)
    cell.update_state('h', cur_state)  # update the state
if is_training:
    training_decoder = fluid.layers.training_decoder(state_cell=state_cell)
    with training_decoder.block() as decoder:
        step_x = decoder.step_input(x)
        decoder.state_cell.update_state({'x': step_x})
        cur_state = decoder.state_cell.state('h')
        probs = fluid.layers.fc(input=cur_state,
                                size=output_dim,
                                act='softmax')
        decoder.output(probs)
    decoder_output = training_decoder()
else:
    beam_search_decoder = fluid.layers.beam_search_decoder(
        state_cell=state_cell,
        maximum_len=maximum_len,
        beam_width=beam_width,
        eos_token=eos_token,
        embedding_func=embedding_func)
    with beam_search_decoder.block() as decoder:
        prev_step_ids = decoder.step_input(init_ids)
        prev_step_scores = decoder.step_input(init_scores)
        prev_step_ids_embedding = decoder.embedding(prev_step_ids)
        decoder.state_cell.update_state({'x': prev_step_ids_embedding})
        cur_state = decoder.state_cell.state('h')
        cur_state_expanded = fluid.layers.sequence_expand(cur_state, prev_step_ids)
        cur_step_scores = fluid.layers.fc(input=cur_state_expanded,
                                          size=output_dim,
                                          act='softmax')
        topk_scores, topk_indices = fluid.layers.topk(cur_step_scores,
                                                      k=decoder.beam_width)
        selected_step_ids, selected_step_scores = fluid.layers.beam_search(
            prev_step_ids,
            topk_indices,
            topk_scores,
            decoder.beam_width,
            end_id=decoder.eos_token,
            level=0)
    decoder_path_ids, decoder_path_scores = beam_search_decoder()
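As a sanity check on how the proposed outputs would be consumed, a hedged sketch follows; the exact layout of decoder_output and of the returned decoded paths is still to be decided by the implementation, and target_ids / feed_dict are placeholders:

if is_training:
    # Assume decoder_output holds the per-step softmax over the vocabulary.
    cost = fluid.layers.cross_entropy(input=decoder_output, label=target_ids)
    avg_cost = fluid.layers.mean(x=cost)
    fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_cost)
else:
    # Assume the decoded paths come back as LoDTensors whose LoD encodes the
    # beam/sentence boundaries, so fetch them without converting to numpy.
    exe = fluid.Executor(fluid.CPUPlace())
    ids, scores = exe.run(fluid.default_main_program(),
                          feed=feed_dict,
                          fetch_list=[decoder_path_ids, decoder_path_scores],
                          return_numpy=False)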