提交 b2eb5149 编写于 作者: G guosheng

Add beam search for Transformer.

上级 2a0991e1
...@@ -16,7 +16,9 @@ import logging ...@@ -16,7 +16,9 @@ import logging
import os import os
import six import six
import sys import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time import time
import contextlib
import numpy as np import numpy as np
import paddle import paddle
...@@ -27,10 +29,11 @@ from utils.check import check_gpu, check_version ...@@ -27,10 +29,11 @@ from utils.check import check_gpu, check_version
# include task-specific libs # include task-specific libs
import reader import reader
from model import Transformer, position_encoding_init from transformer import InferTransformer, position_encoding_init
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False): def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
output_eos=False):
""" """
Post-process the decoded sequence. Post-process the decoded sequence.
""" """
...@@ -47,10 +50,13 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False): ...@@ -47,10 +50,13 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
def do_predict(args): def do_predict(args):
if args.use_cuda: device_ids = list(range(args.num_devices))
place = fluid.CUDAPlace(0)
else: @contextlib.contextmanager
place = fluid.CPUPlace() def null_guard():
yield
guard = fluid.dygraph.guard() if args.eager_run else null_guard()
# define the data generator # define the data generator
processor = reader.DataProcessor(fpattern=args.predict_file, processor = reader.DataProcessor(fpattern=args.predict_file,
...@@ -69,68 +75,61 @@ def do_predict(args): ...@@ -69,68 +75,61 @@ def do_predict(args):
unk_mark=args.special_token[2], unk_mark=args.special_token[2],
max_length=args.max_length, max_length=args.max_length,
n_head=args.n_head) n_head=args.n_head)
batch_generator = processor.data_generator(phase="predict", place=place) batch_generator = processor.data_generator(phase="predict")
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary() args.unk_idx = processor.get_vocab_summary()
trg_idx2word = reader.DataProcessor.load_dict( trg_idx2word = reader.DataProcessor.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True) dict_path=args.trg_vocab_fpath, reverse=True)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ with guard:
args.unk_idx = processor.get_vocab_summary()
with fluid.dygraph.guard(place):
# define data loader # define data loader
test_loader = fluid.io.DataLoader.from_generator(capacity=10) test_loader = batch_generator
test_loader.set_batch_generator(batch_generator, places=place)
# define model # define model
transformer = Transformer( transformer = InferTransformer(args.src_vocab_size,
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1, args.trg_vocab_size,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model, args.max_length + 1,
args.d_inner_hid, args.prepostprocess_dropout, args.n_layer,
args.attention_dropout, args.relu_dropout, args.preprocess_cmd, args.n_head,
args.postprocess_cmd, args.weight_sharing, args.bos_idx, args.d_key,
args.eos_idx) args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.bos_idx,
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
# load the trained model # load the trained model
assert args.init_from_params, ( assert args.init_from_params, (
"Please set init_from_params to load the infer model.") "Please set init_from_params to load the infer model.")
model_dict, _ = fluid.load_dygraph( transformer.load(os.path.join(args.init_from_params, "transformer"))
os.path.join(args.init_from_params, "transformer"))
# to avoid a longer length than training, reset the size of position
# encoding to max_length
model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
args.max_length + 1, args.d_model)
model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
args.max_length + 1, args.d_model)
transformer.load_dict(model_dict)
# set evaluate mode
transformer.eval()
f = open(args.output_file, "wb") f = open(args.output_file, "wb")
for input_data in test_loader(): for input_data in test_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word, (src_word, src_pos, src_slf_attn_bias, trg_word,
trg_src_attn_bias) = input_data trg_src_attn_bias) = input_data
finished_seq, finished_scores = transformer.beam_search( finished_seq = transformer.test(inputs=(src_word, src_pos,
src_word, src_slf_attn_bias,
src_pos, trg_src_attn_bias),
src_slf_attn_bias, device='gpu',
trg_word, device_ids=device_ids)[0]
trg_src_attn_bias, finished_seq = np.transpose(finished_seq, [0, 2, 1])
bos_id=args.bos_idx,
eos_id=args.eos_idx,
beam_size=args.beam_size,
max_len=args.max_out_len)
finished_seq = finished_seq.numpy()
finished_scores = finished_scores.numpy()
for ins in finished_seq: for ins in finished_seq:
for beam_idx, beam in enumerate(ins): for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break if beam_idx >= args.n_best: break
id_list = post_process_seq(beam, args.bos_idx, args.eos_idx) id_list = post_process_seq(beam, args.bos_idx,
args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list] word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n" sequence = b" ".join(word_list) + b"\n"
f.write(sequence) f.write(sequence)
break
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -114,7 +114,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): ...@@ -114,7 +114,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
return data_inputs return data_inputs
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head, place): def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
""" """
Put all padded data needed by beam search decoder into a list. Put all padded data needed by beam search decoder into a list.
""" """
...@@ -517,7 +517,7 @@ class DataProcessor(object): ...@@ -517,7 +517,7 @@ class DataProcessor(object):
return __impl__ return __impl__
def data_generator(self, phase, place=None): def data_generator(self, phase):
# Any token included in dict can be used to pad, since the paddings' loss # Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients. # will be masked out by weights and make no effect on parameter gradients.
src_pad_idx = trg_pad_idx = self._eos_idx src_pad_idx = trg_pad_idx = self._eos_idx
...@@ -540,7 +540,7 @@ class DataProcessor(object): ...@@ -540,7 +540,7 @@ class DataProcessor(object):
def __for_predict__(): def __for_predict__():
for data in data_reader(): for data in data_reader():
data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx, data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx,
n_head, place) n_head)
yield data_inputs yield data_inputs
return __for_train__ if phase == "train" else __for_predict__ return __for_train__ if phase == "train" else __for_predict__
......
import collections
import contextlib
import inspect
import six
import sys
from functools import partial, reduce
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers.utils as utils
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.dygraph import to_variable, Embedding, Linear
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid import layers
from paddle.fluid.dygraph import Layer
class RNNUnit(Layer):
def get_initial_states(self,
batch_ref,
shape=None,
dtype=None,
init_value=0,
batch_dim_idx=0):
"""
Generate initialized states according to provided shape, data type and
value.
Parameters:
batch_ref: A (possibly nested structure of) tensor variable[s].
The first dimension of the tensor will be used as batch size to
initialize states.
shape: A (possiblely nested structure of) shape[s], where a shape is
represented as a list/tuple of integer). -1(for batch size) will
beautomatically inserted if shape is not started with it. If None,
property `state_shape` will be used. The default value is None.
dtype: A (possiblely nested structure of) data type[s]. The structure
must be same as that of `shape`, except when all tensors' in states
has the same data type, a single data type can be used. If None and
property `cell.state_shape` is not available, float32 will be used
as the data type. The default value is None.
init_value: A float value used to initialize states.
Returns:
Variable: tensor variable[s] packed in the same structure provided \
by shape, representing the initialized states.
"""
# TODO: use inputs and batch_size
batch_ref = flatten(batch_ref)[0]
def _is_shape_sequence(seq):
if sys.version_info < (3, ):
integer_types = (
int,
long,
)
else:
integer_types = (int, )
"""For shape, list/tuple of integer is the finest-grained objection"""
if (isinstance(seq, list) or isinstance(seq, tuple)):
if reduce(
lambda flag, x: isinstance(x, integer_types) and flag,
seq, True):
return False
# TODO: Add check for the illegal
if isinstance(seq, dict):
return True
return (isinstance(seq, collections.Sequence)
and not isinstance(seq, six.string_types))
class Shape(object):
def __init__(self, shape):
self.shape = shape if shape[0] == -1 else ([-1] + list(shape))
# nested structure of shapes
states_shapes = self.state_shape if shape is None else shape
is_sequence_ori = utils.is_sequence
utils.is_sequence = _is_shape_sequence
states_shapes = map_structure(lambda shape: Shape(shape),
states_shapes)
utils.is_sequence = is_sequence_ori
# nested structure of dtypes
try:
states_dtypes = self.state_dtype if dtype is None else dtype
except NotImplementedError: # use fp32 as default
states_dtypes = "float32"
if len(flatten(states_dtypes)) == 1:
dtype = flatten(states_dtypes)[0]
states_dtypes = map_structure(lambda shape: dtype, states_shapes)
init_states = map_structure(
lambda shape, dtype: fluid.layers.fill_constant_batch_size_like(
input=batch_ref,
shape=shape.shape,
dtype=dtype,
value=init_value,
input_dim_idx=batch_dim_idx), states_shapes, states_dtypes)
return init_states
@property
def state_shape(self):
"""
Abstract method (property).
Used to initialize states.
A (possiblely nested structure of) shape[s], where a shape is represented
as a list/tuple of integers (-1 for batch size would be automatically
inserted into a shape if shape is not started with it).
Not necessary to be implemented if states are not initialized by
`get_initial_states` or the `shape` argument is provided when using
`get_initial_states`.
"""
raise NotImplementedError(
"Please add implementaion for `state_shape` in the used cell.")
@property
def state_dtype(self):
"""
Abstract method (property).
Used to initialize states.
A (possiblely nested structure of) data types[s]. The structure must be
same as that of `shape`, except when all tensors' in states has the same
data type, a signle data type can be used.
Not necessary to be implemented if states are not initialized
by `get_initial_states` or the `dtype` argument is provided when using
`get_initial_states`.
"""
raise NotImplementedError(
"Please add implementaion for `state_dtype` in the used cell.")
class BasicLSTMUnit(RNNUnit):
"""
****
BasicLSTMUnit class, Using basic operator to build LSTM
The algorithm can be described as the code below.
.. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t)
- $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The :math:`\odot` is the element-wise product of the vectors.
- :math:`tanh` is the activation functions.
- :math:`\\tilde{c_t}` is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Args:
name_scope(string) : The name scope used to identify parameter and bias name
hidden_size (integer): The hidden size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of LSTM unit.
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized as zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cells (actNode).
Default: 'fluid.layers.tanh'
forget_bias(float|1.0): forget bias used when computing forget gate
dtype(string): data type used in this unit
"""
def __init__(self,
hidden_size,
input_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32'):
super(BasicLSTMUnit, self).__init__(dtype)
self._hidden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._forget_bias = layers.fill_constant([1],
dtype=dtype,
value=forget_bias)
self._forget_bias.stop_gradient = False
self._dtype = dtype
self._input_size = input_size
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[
self._input_size + self._hidden_size, 4 * self._hidden_size
],
dtype=self._dtype)
self._bias = self.create_parameter(attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, state):
pre_hidden, pre_cell = state
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
gate_input = layers.elementwise_add(gate_input, self._bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
new_cell = layers.elementwise_add(
layers.elementwise_mul(
pre_cell,
layers.sigmoid(layers.elementwise_add(f, self._forget_bias))),
layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
new_hidden = layers.tanh(new_cell) * layers.sigmoid(o)
return new_hidden, [new_hidden, new_cell]
@property
def state_shape(self):
return [[self._hidden_size], [self._hidden_size]]
class RNN(fluid.dygraph.Layer):
def __init__(self, cell, is_reverse=False, time_major=False):
super(RNN, self).__init__()
self.cell = cell
if not hasattr(self.cell, "call"):
self.cell.call = self.cell.forward
self.is_reverse = is_reverse
self.time_major = time_major
self.batch_index, self.time_step_index = (1, 0) if time_major else (0,
1)
def forward(self,
inputs,
initial_states=None,
sequence_length=None,
**kwargs):
if fluid.in_dygraph_mode():
class ArrayWrapper(object):
def __init__(self, x):
self.array = [x]
def append(self, x):
self.array.append(x)
return self
def _maybe_copy(state, new_state, step_mask):
# TODO: use where_op
new_state = fluid.layers.elementwise_mul(
new_state, step_mask,
axis=0) - fluid.layers.elementwise_mul(state,
(step_mask - 1),
axis=0)
return new_state
flat_inputs = flatten(inputs)
batch_size, time_steps = (
flat_inputs[0].shape[self.batch_index],
flat_inputs[0].shape[self.time_step_index])
if initial_states is None:
initial_states = self.cell.get_initial_states(
batch_ref=inputs, batch_dim_idx=self.batch_index)
if not self.time_major:
inputs = map_structure(
lambda x: fluid.layers.transpose(x, [1, 0] + list(
range(2, len(x.shape)))), inputs)
if sequence_length:
mask = fluid.layers.sequence_mask(
sequence_length,
maxlen=time_steps,
dtype=flatten(initial_states)[0].dtype)
mask = fluid.layers.transpose(mask, [1, 0])
if self.is_reverse:
inputs = map_structure(
lambda x: fluid.layers.reverse(x, axis=[0]), inputs)
mask = fluid.layers.reverse(
mask, axis=[0]) if sequence_length else None
states = initial_states
outputs = []
for i in range(time_steps):
step_inputs = map_structure(lambda x: x[i], inputs)
step_outputs, new_states = self.cell(step_inputs, states,
**kwargs)
if sequence_length:
new_states = map_structure(
partial(_maybe_copy, step_mask=mask[i]), states,
new_states)
states = new_states
outputs = map_structure(
lambda x: ArrayWrapper(x),
step_outputs) if i == 0 else map_structure(
lambda x, x_array: x_array.append(x), step_outputs,
outputs)
final_outputs = map_structure(
lambda x: fluid.layers.stack(x.array,
axis=self.time_step_index),
outputs)
if self.is_reverse:
final_outputs = map_structure(
lambda x: fluid.layers.reverse(x,
axis=self.time_step_index),
final_outputs)
final_states = new_states
else:
final_outputs, final_states = fluid.layers.rnn(
self.cell,
inputs,
initial_states=initial_states,
sequence_length=sequence_length,
time_major=self.time_major,
is_reverse=self.is_reverse,
**kwargs)
return final_outputs, final_states
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
place = fluid.CPUPlace()
executor = fluid.Executor(place)
class EncoderCell(RNNUnit):
def __init__(self, num_layers, input_size, hidden_size, dropout_prob=0.):
super(EncoderCell, self).__init__()
self.num_layers = num_layers
self.dropout_prob = dropout_prob
self.lstm_cells = list()
for i in range(self.num_layers):
self.lstm_cells.append(
self.add_sublayer(
"layer_%d" % i,
BasicLSTMUnit(input_size if i == 0 else hidden_size,
hidden_size)))
def forward(self, step_input, states):
new_states = []
for i in range(self.num_layers):
out, new_state = self.lstm_cells[i](step_input, states[i])
step_input = layers.dropout(
out, self.dropout_prob) if self.dropout_prob > 0 else out
new_states.append(new_state)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.lstm_cells]
class MultiHeadAttention(Layer):
"""
Multi-Head Attention
"""
# def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
# pass
# def forward(self, queries, keys, values, attn_bias, cache=None):
# pass
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
self.d_model = d_model
self.dropout_rate = dropout_rate
self.q_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.k_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.v_fc = Linear(input_dim=d_model,
output_dim=d_value * n_head,
bias_attr=False)
self.proj_fc = Linear(input_dim=d_value * n_head,
output_dim=d_model,
bias_attr=False)
def forward(self, queries, keys, values, attn_bias, cache=None):
# compute q ,k ,v
keys = queries if keys is None else keys
values = keys if values is None else values
q = self.q_fc(queries)
k = self.k_fc(keys)
v = self.v_fc(values)
# split head
q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
q = layers.transpose(x=q, perm=[0, 2, 1, 3])
k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
k = layers.transpose(x=k, perm=[0, 2, 1, 3])
v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
v = layers.transpose(x=v, perm=[0, 2, 1, 3])
if cache is not None:
cache_k, cache_v = cache["k"], cache["v"]
k = layers.concat([cache_k, k], axis=2)
v = layers.concat([cache_v, v], axis=2)
cache["k"], cache["v"] = k, v
# scale dot product attention
product = layers.matmul(x=q,
y=k,
transpose_y=True,
alpha=self.d_model**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if self.dropout_rate:
weights = layers.dropout(weights,
dropout_prob=self.dropout_rate,
is_test=False)
out = layers.matmul(weights, v)
# combine heads
out = layers.transpose(out, perm=[0, 2, 1, 3])
out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.proj_fc(out)
return out
class DynamicDecode(Layer):
def __init__(self,
decoder,
max_step_num=None,
output_time_major=False,
impute_finished=False,
is_test=False,
return_length=False):
super(DynamicDecode, self).__init__()
self.decoder = decoder
self.max_step_num = max_step_num
self.output_time_major = output_time_major
self.impute_finished = impute_finished
self.is_test = is_test
self.return_length = return_length
def forward(self, inits=None, **kwargs):
if fluid.in_dygraph_mode():
class ArrayWrapper(object):
def __init__(self, x):
self.array = [x]
def append(self, x):
self.array.append(x)
return self
def __getitem__(self, item):
return self.array.__getitem__(item)
def _maybe_copy(state, new_state, step_mask):
# TODO: use where_op
state_dtype = state.dtype
if convert_dtype(state_dtype) in ["bool"]:
state = layers.cast(state, dtype="float32")
new_state = layers.cast(new_state, dtype="float32")
if step_mask.dtype != state.dtype:
step_mask = layers.cast(step_mask, dtype=state.dtype)
# otherwise, renamed bool gradients of would be summed up leading
# to sum(bool) error.
step_mask.stop_gradient = True
new_state = layers.elementwise_mul(
state, step_mask, axis=0) - layers.elementwise_mul(
new_state, (step_mask - 1), axis=0)
if convert_dtype(state_dtype) in ["bool"]:
new_state = layers.cast(new_state, dtype=state_dtype)
return new_state
initial_inputs, initial_states, initial_finished = self.decoder.initialize(
inits)
inputs, states, finished = (initial_inputs, initial_states,
initial_finished)
cond = layers.logical_not((layers.reduce_all(initial_finished)))
sequence_lengths = layers.cast(layers.zeros_like(initial_finished),
"int64")
outputs = None
step_idx = 0
step_idx_tensor = layers.fill_constant(shape=[1],
dtype="int64",
value=step_idx)
while cond.numpy():
(step_outputs, next_states, next_inputs,
next_finished) = self.decoder.step(step_idx_tensor, inputs,
states, **kwargs)
next_finished = layers.logical_or(next_finished, finished)
next_sequence_lengths = layers.elementwise_add(
sequence_lengths,
layers.cast(layers.logical_not(finished),
sequence_lengths.dtype))
if self.impute_finished: # rectify the states for the finished.
next_states = map_structure(
lambda x, y: _maybe_copy(x, y, finished), states,
next_states)
outputs = map_structure(
lambda x: ArrayWrapper(x),
step_outputs) if step_idx == 0 else map_structure(
lambda x, x_array: x_array.append(x), step_outputs,
outputs)
inputs, states, finished, sequence_lengths = (
next_inputs, next_states, next_finished,
next_sequence_lengths)
layers.increment(x=step_idx_tensor, value=1.0, in_place=True)
step_idx += 1
layers.logical_not(layers.reduce_all(finished), cond)
if self.max_step_num is not None and step_idx > self.max_step_num:
break
final_outputs = map_structure(
lambda x: fluid.layers.stack(x.array, axis=0), outputs)
final_states = states
try:
final_outputs, final_states = self.decoder.finalize(
final_outputs, final_states, sequence_lengths)
except NotImplementedError:
pass
if not self.output_time_major:
final_outputs = map_structure(
lambda x: layers.transpose(x, [1, 0] + list(
range(2, len(x.shape)))), final_outputs)
return (final_outputs, final_states,
sequence_lengths) if self.return_length else (
final_outputs, final_states)
else:
return fluid.layers.dynamic_decode(
self.decoder,
inits,
max_step_num=self.max_step_num,
output_time_major=self.output_time_major,
impute_finished=self.impute_finished,
is_test=self.is_test,
return_length=self.return_length,
**kwargs)
class TransfomerCell(object):
"""
Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
used as RNNCell
"""
def __init__(self, decoder):
self.decoder = decoder
def __call__(self, inputs, states, trg_src_attn_bias, enc_output,
static_caches):
trg_word, trg_pos = inputs
for cache, static_cache in zip(states, static_caches):
cache.update(static_cache)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, states)
new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
return logits, new_states
class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
def __init__(self, cell, start_token, end_token, beam_size,
var_dim_in_state):
super(TransformerBeamSearchDecoder,
self).__init__(cell, start_token, end_token, beam_size)
self.cell = cell
self.var_dim_in_state = var_dim_in_state
def _merge_batch_beams_with_var_dim(self, x):
if not hasattr(self, "batch_size"):
self.batch_size = layers.shape(x)[0]
if not hasattr(self, "batch_beam_size"):
self.batch_beam_size = self.batch_size * self.beam_size
# init length of cache is 0, and it increases with decoding carrying on,
# thus need to reshape elaborately
var_dim_in_state = self.var_dim_in_state + 1 # count in beam dim
x = layers.transpose(
x,
list(range(var_dim_in_state, len(x.shape))) +
list(range(0, var_dim_in_state)))
x = layers.reshape(x, [0] * (len(x.shape) - var_dim_in_state) +
[self.batch_beam_size] +
list(x.shape[-var_dim_in_state + 2:]))
x = layers.transpose(
x,
list(range((len(x.shape) + 1 - var_dim_in_state), len(x.shape))) +
list(range(0, (len(x.shape) + 1 - var_dim_in_state))))
return x
def _split_batch_beams_with_var_dim(self, x):
var_dim_size = layers.shape(x)[self.var_dim_in_state]
x = layers.reshape(
x, [-1, self.beam_size] + list(x.shape[1:self.var_dim_in_state]) +
[var_dim_size] + list(x.shape[self.var_dim_in_state + 1:]))
return x
def step(self, time, inputs, states, **kwargs):
# compared to RNN, Transformer has 3D data at every decoding step
inputs = layers.reshape(inputs, [-1, 1]) # token
pos = layers.ones_like(inputs) * time # pos
cell_states = map_structure(self._merge_batch_beams_with_var_dim,
states.cell_states)
cell_outputs, next_cell_states = self.cell((inputs, pos), cell_states,
**kwargs)
cell_outputs = map_structure(self._split_batch_beams, cell_outputs)
next_cell_states = map_structure(self._split_batch_beams_with_var_dim,
next_cell_states)
beam_search_output, beam_search_state = self._beam_search_step(
time=time,
logits=cell_outputs,
next_cell_states=next_cell_states,
beam_state=states)
next_inputs, finished = (beam_search_output.predicted_ids,
beam_search_state.finished)
return (beam_search_output, beam_search_state, next_inputs, finished)
'''
@contextlib.contextmanager
def eager_guard(is_eager):
if is_eager:
with fluid.dygraph.guard():
yield
else:
yield
# print(flatten(np.random.rand(2,8,8)))
random_seed = 123
np.random.seed(random_seed)
# print np.random.rand(2, 8)
batch_size = 2
seq_len = 8
hidden_size = 8
vocab_size, embed_dim, num_layers, hidden_size = 100, 8, 2, 8
bos_id, eos_id, beam_size, max_step_num = 0, 1, 5, 10
time_major = False
eagar_run = False
import torch
with eager_guard(eagar_run):
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
inputs_data = np.random.rand(batch_size, seq_len,
hidden_size).astype("float32")
states_data = np.random.rand(batch_size, hidden_size).astype("float32")
lstm_cell = BasicLSTMUnit(hidden_size=8, input_size=8)
lstm = RNN(cell=lstm_cell, time_major=time_major)
inputs = to_variable(inputs_data) if eagar_run else fluid.data(
name="x", shape=[None, None, hidden_size], dtype="float32")
states = lstm_cell.get_initial_states(batch_ref=inputs,
batch_dim_idx=1 if time_major else 0)
out, _ = lstm(inputs, states)
# print states
# print layers.BeamSearchDecoder.tile_beam_merge_with_batch(out, 5)
# embedder = Embedding(size=(vocab_size, embed_dim))
# output_layer = Linear(hidden_size, vocab_size)
# decoder = layers.BeamSearchDecoder(lstm_cell,
# bos_id,
# eos_id,
# beam_size,
# embedding_fn=embedder,
# output_fn=output_layer)
# dynamic_decoder = DynamicDecode(decoder, max_step_num)
# out,_ = dynamic_decoder(inits=states)
# caches = [{
# "k":
# layers.fill_constant_batch_size_like(out,
# shape=[-1, 8, 0, 64],
# dtype="float32",
# value=0),
# "v":
# layers.fill_constant_batch_size_like(out,
# shape=[-1, 8, 0, 64],
# dtype="float32",
# value=0)
# } for i in range(6)]
cache = layers.fill_constant_batch_size_like(out,
shape=[-1, 8, 0, 64],
dtype="float32",
value=0)
print cache
# out = layers.BeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
# out = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
# batch_beam_size = layers.shape(out)[0] * 5
# print out
cell = TransfomerCell(None)
decoder = TransformerBeamSearchDecoder(cell, 0, 1, 5, 2)
cache = decoder._expand_to_beam_size(cache)
print cache
cache = decoder._merge_batch_beams_with_var_dim(cache)
print cache
cache1 = layers.fill_constant_batch_size_like(cache,
shape=[-1, 8, 1, 64],
dtype="float32",
value=0)
print cache1.shape
cache = layers.concat([cache, cache1], axis=2)
out = decoder._split_batch_beams_with_var_dim(cache)
# out = layers.transpose(out,
# list(range(3, len(out.shape))) + list(range(0, 3)))
# print out
# out = layers.reshape(out, list(out.shape[:2]) + [batch_beam_size, 8])
# print out
# out = layers.transpose(out, [2,3,0,1])
print out.shape
if eagar_run:
print "hehe" #out #.numpy()
else:
executor.run(fluid.default_startup_program())
inputs = fluid.data(name="x",
shape=[None, None, hidden_size],
dtype="float32")
out_np = executor.run(feed={"x": inputs_data},
fetch_list=[out.name])[0]
print np.array(out_np).shape
exit(0)
# dygraph
# inputs = to_variable(inputs_data)
# states = lstm_cell.get_initial_states(batch_ref=inputs,
# batch_dim_idx=1 if time_major else 0)
# print lstm(inputs, states)[0].numpy()
# graph
executor.run(fluid.default_startup_program())
inputs = fluid.data(name="x",
shape=[None, None, hidden_size],
dtype="float32")
states = lstm_cell.get_initial_states(batch_ref=inputs,
batch_dim_idx=1 if time_major else 0)
out, _ = lstm(inputs, states)
out_np = executor.run(feed={"x": inputs_data}, fetch_list=[out.name])[0]
print np.array(out_np)
#print fluid.io.save_inference_model(dirname="test_model", feeded_var_names=["x"], target_vars=[out], executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
# test_program, feed_target_names, fetch_targets = fluid.io.load_inference_model(dirname="test_model", executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
# out = executor.run(program=test_program, feed={"x": np.random.rand(2, 8, 8).astype("float32")}, fetch_list=fetch_targets)[0]
'''
\ No newline at end of file
python -u train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
--validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 1 \
--use_cuda True \
--random_seed 1000 \
--save_step 10 \
--eager_run True
#--init_from_pretrain_model base_model_dygraph/step_100000/ \
#--init_from_checkpoint trained_models/step_200/transformer
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
exit
echo `date`
python -u predict.py \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 64 \
--init_from_params base_model_dygraph/step_100000/ \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--eager_run True
#--max_length 500 \
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
echo `date`
\ No newline at end of file
...@@ -189,6 +189,15 @@ class MultiHeadAttention(Layer): ...@@ -189,6 +189,15 @@ class MultiHeadAttention(Layer):
out = self.proj_fc(out) out = self.proj_fc(out)
return out return out
def cal_kv(self, keys, values):
k = self.k_fc(keys)
v = self.v_fc(values)
k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
k = layers.transpose(x=k, perm=[0, 2, 1, 3])
v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
v = layers.transpose(x=v, perm=[0, 2, 1, 3])
return k, v
class FFN(Layer): class FFN(Layer):
""" """
...@@ -441,6 +450,14 @@ class Decoder(Layer): ...@@ -441,6 +450,14 @@ class Decoder(Layer):
return self.processer(dec_output) return self.processer(dec_output)
def prepare_static_cache(self, enc_output):
return [
dict(
zip(("static_k", "static_v"),
decoder_layer.cross_attn.cal_kv(enc_output, enc_output)))
for decoder_layer in self.decoder_layers
]
class WrapDecoder(Layer): class WrapDecoder(Layer):
""" """
...@@ -622,481 +639,96 @@ class Transformer(Model): ...@@ -622,481 +639,96 @@ class Transformer(Model):
trg_src_attn_bias, enc_output) trg_src_attn_bias, enc_output)
return predict return predict
def beam_search_v2(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=None,
alpha=0.6):
"""
Beam search with the alive and finished two queues, both have a beam size
capicity separately. It includes `grow_topk` `grow_alive` `grow_finish` as
steps.
1. `grow_topk` selects the top `2*beam_size` candidates to avoid all getting
EOS.
2. `grow_alive` selects the top `beam_size` non-EOS candidates as the inputs
of next decoding step.
3. `grow_finish` compares the already finished candidates in the finished queue
and newly added finished candidates from `grow_topk`, and selects the top
`beam_size` finished candidates.
"""
def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:])
tile_dims = [1] * len(tensor.shape)
tile_dims[1] = beam_size
return layers.expand(tensor, tile_dims)
def merge_beam_dim(tensor):
return layers.reshape(tensor, [-1] + tensor.shape[2:])
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
# constant number from rnn_api import TransformerBeamSearchDecoder, DynamicDecode
inf = float(1. * 1e7)
batch_size = enc_output.shape[0]
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
### initialize states of beam search ###
## init for the alive ##
initial_log_probs = to_variable(
np.array([[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
alive_seq = to_variable(
np.tile(np.array([[[bos_id]]], dtype="int64"),
(batch_size, beam_size, 1)))
## init for the finished ##
finished_scores = to_variable(
np.array([[-inf] * beam_size], dtype="float32"))
finished_scores = layers.expand(finished_scores, [batch_size, 1])
finished_seq = to_variable(
np.tile(np.array([[[bos_id]]], dtype="int64"),
(batch_size, beam_size, 1)))
finished_flags = layers.zeros_like(finished_scores)
### initialize inputs and states of transformer decoder ###
## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
trg_word = layers.reshape(alive_seq[:, :, -1],
[batch_size * beam_size, 1])
trg_src_attn_bias = merge_beam_dim(
expand_to_beam_size(trg_src_attn_bias, beam_size))
enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))
## init states (caches) for transformer, need to be updated according to selected beam
caches = [{
"k":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
value=0),
} for i in range(self.n_layer)]
def update_states(caches, beam_idx, beam_size):
for cache in caches:
cache["k"] = gather_2d_by_gather(cache["k"], beam_idx,
beam_size, batch_size, False)
cache["v"] = gather_2d_by_gather(cache["v"], beam_idx,
beam_size, batch_size, False)
return caches
def gather_2d_by_gather(tensor_nd,
beam_idx,
beam_size,
batch_size,
need_flat=True):
batch_idx = layers.range(0, batch_size, 1,
dtype="int64") * beam_size
flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
idx = layers.reshape(layers.elementwise_add(beam_idx, batch_idx, 0),
[-1])
new_flat_tensor = layers.gather(flat_tensor, idx)
new_tensor_nd = layers.reshape(
new_flat_tensor,
shape=[batch_size, beam_idx.shape[1]] +
tensor_nd.shape[2:]) if need_flat else new_flat_tensor
return new_tensor_nd
def early_finish(alive_log_probs, finished_scores,
finished_in_finished):
max_length_penalty = np.power(((5. + max_len) / 6.), alpha)
# The best possible score of the most likely alive sequence
lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty
# Now to compute the lowest score of a finished sequence in finished
# If the sequence isn't finished, we multiply it's score by 0. since
# scores are all -ve, taking the min will give us the score of the lowest
# finished item.
lowest_score_of_fininshed_in_finished = layers.reduce_min(
finished_scores * finished_in_finished, 1)
# If none of the sequences have finished, then the min will be 0 and
# we have to replace it by -ve INF if it is. The score of any seq in alive
# will be much higher than -ve INF and the termination condition will not
# be met.
lowest_score_of_fininshed_in_finished += (
1. - layers.reduce_max(finished_in_finished, 1)) * -inf
bound_is_met = layers.reduce_all(
layers.greater_than(lowest_score_of_fininshed_in_finished,
lower_bound_alive_scores))
return bound_is_met
def grow_topk(i, logits, alive_seq, alive_log_probs, states):
logits = layers.reshape(logits, [batch_size, beam_size, -1])
candidate_log_probs = layers.log(layers.softmax(logits, axis=2))
log_probs = layers.elementwise_add(candidate_log_probs,
alive_log_probs, 0)
length_penalty = np.power(5.0 + (i + 1.0) / 6.0, alpha)
curr_scores = log_probs / length_penalty
flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1])
topk_scores, topk_ids = layers.topk(flat_curr_scores,
k=beam_size * 2)
topk_log_probs = topk_scores * length_penalty
topk_beam_index = topk_ids // self.trg_vocab_size
topk_ids = topk_ids % self.trg_vocab_size
# use gather as gather_nd, TODO: use gather_nd
topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
beam_size, batch_size)
topk_seq = layers.concat(
[topk_seq,
layers.reshape(topk_ids, topk_ids.shape + [1])],
axis=2)
states = update_states(states, topk_beam_index, beam_size)
eos = layers.fill_constant(shape=topk_ids.shape,
dtype="int64",
value=eos_id)
topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32")
#topk_seq: [batch_size, 2*beam_size, i+1]
#topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
return topk_seq, topk_log_probs, topk_scores, topk_finished, states
def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
states):
curr_scores += curr_finished * -inf
_, topk_indexes = layers.topk(curr_scores, k=beam_size)
alive_seq = gather_2d_by_gather(curr_seq, topk_indexes,
beam_size * 2, batch_size)
alive_log_probs = gather_2d_by_gather(curr_log_probs, topk_indexes,
beam_size * 2, batch_size)
states = update_states(states, topk_indexes, beam_size * 2)
return alive_seq, alive_log_probs, states
def grow_finished(finished_seq, finished_scores, finished_flags,
curr_seq, curr_scores, curr_finished):
# finished scores
finished_seq = layers.concat([
finished_seq,
layers.fill_constant(shape=[batch_size, beam_size, 1],
dtype="int64",
value=eos_id)
],
axis=2)
# Set the scores of the unfinished seq in curr_seq to large negative
# values
curr_scores += (1. - curr_finished) * -inf
# concatenating the sequences and scores along beam axis
curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
curr_finished_scores = layers.concat([finished_scores, curr_scores],
axis=1)
curr_finished_flags = layers.concat([finished_flags, curr_finished],
axis=1)
_, topk_indexes = layers.topk(curr_finished_scores, k=beam_size)
finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
beam_size * 3, batch_size)
finished_scores = gather_2d_by_gather(curr_finished_scores,
topk_indexes, beam_size * 3,
batch_size)
finished_flags = gather_2d_by_gather(curr_finished_flags,
topk_indexes, beam_size * 3,
batch_size)
return finished_seq, finished_scores, finished_flags
for i in range(max_len):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches)
topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
i, logits, alive_seq, alive_log_probs, caches)
alive_seq, alive_log_probs, states = grow_alive(
topk_seq, topk_scores, topk_log_probs, topk_finished, states)
finished_seq, finished_scores, finished_flags = grow_finished(
finished_seq, finished_scores, finished_flags, topk_seq,
topk_scores, topk_finished)
trg_word = layers.reshape(alive_seq[:, :, -1],
[batch_size * beam_size, 1])
if early_finish(alive_log_probs, finished_scores,
finished_flags).numpy():
break
return finished_seq, finished_scores
def beam_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=256):
if beam_size == 1:
return self._greedy_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
max_len=max_len)
else:
return self._beam_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
beam_size=beam_size,
max_len=max_len)
def _beam_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=256):
def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:])
tile_dims = [1] * len(tensor.shape)
tile_dims[1] = beam_size
return layers.expand(tensor, tile_dims)
def merge_batch_beams(tensor):
return layers.reshape(tensor, [tensor.shape[0] * tensor.shape[1]] +
tensor.shape[2:])
def split_batch_beams(tensor):
return layers.reshape(tensor,
shape=[-1, beam_size] +
list(tensor.shape[1:]))
def mask_probs(probs, finished, noend_mask_tensor):
# TODO: use where_op
finished = layers.cast(finished, dtype=probs.dtype)
probs = layers.elementwise_mul(layers.expand(
layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]),
noend_mask_tensor,
axis=-1) - layers.elementwise_mul(
probs, (finished - 1), axis=0)
return probs
def gather(x, indices, batch_pos):
topk_coordinates = layers.stack([batch_pos, indices], axis=2)
return layers.gather_nd(x, topk_coordinates)
def update_states(func, caches):
for cache in caches: # no need to update static_kv
cache["k"] = func(cache["k"])
cache["v"] = func(cache["v"])
return caches
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
# constant number class TransfomerCell(object):
inf = float(1. * 1e7) """
batch_size = enc_output.shape[0] Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len used as RNNCell
vocab_size_tensor = layers.fill_constant(shape=[1], """
dtype="int64", def __init__(self, decoder):
value=self.trg_vocab_size) self.decoder = decoder
end_token_tensor = to_variable(
np.full([batch_size, beam_size], eos_id, dtype="int64")) def __call__(self, inputs, states, trg_src_attn_bias, enc_output,
noend_array = [-inf] * self.trg_vocab_size static_caches):
noend_array[eos_id] = 0 trg_word, trg_pos = inputs
noend_mask_tensor = to_variable(np.array(noend_array,dtype="float32")) for cache, static_cache in zip(states, static_caches):
batch_pos = layers.expand( cache.update(static_cache)
layers.unsqueeze( logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]), enc_output, states)
[1, beam_size]) new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
return logits, new_states
predict_ids = []
parent_ids = []
### initialize states of beam search ###
log_probs = to_variable(
np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size,
dtype="float32"))
finished = to_variable(np.full([batch_size, beam_size], 0,
dtype="bool"))
### initialize inputs and states of transformer decoder ###
## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
dtype="int64",
value=bos_id)
trg_pos = layers.zeros_like(trg_word)
trg_src_attn_bias = merge_batch_beams(
expand_to_beam_size(trg_src_attn_bias, beam_size))
enc_output = merge_batch_beams(expand_to_beam_size(enc_output, beam_size))
## init states (caches) for transformer, need to be updated according to selected beam
caches = [{
"k":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
value=0),
} for i in range(self.n_layer)]
for i in range(max_len):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
caches = update_states( # can not be reshaped since the 0 size
lambda x: x if i == 0 else merge_batch_beams(x), caches)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches)
caches = update_states(split_batch_beams, caches)
step_log_probs = split_batch_beams(
layers.log(layers.softmax(logits)))
step_log_probs = mask_probs(step_log_probs, finished,
noend_mask_tensor)
log_probs = layers.elementwise_add(x=step_log_probs,
y=log_probs,
axis=0)
log_probs = layers.reshape(log_probs,
[-1, beam_size * self.trg_vocab_size])
scores = log_probs
topk_scores, topk_indices = layers.topk(input=scores, k=beam_size)
beam_indices = layers.elementwise_floordiv(
topk_indices, vocab_size_tensor)
token_indices = layers.elementwise_mod(
topk_indices, vocab_size_tensor)
# update states
caches = update_states(lambda x: gather(x, beam_indices, batch_pos),
caches)
log_probs = gather(log_probs, topk_indices, batch_pos)
finished = gather(finished, beam_indices, batch_pos)
finished = layers.logical_or(
finished, layers.equal(token_indices, end_token_tensor))
trg_word = layers.reshape(token_indices, [-1, 1])
predict_ids.append(token_indices)
parent_ids.append(beam_indices)
if layers.reduce_all(finished).numpy():
break
predict_ids = layers.stack(predict_ids, axis=0)
parent_ids = layers.stack(parent_ids, axis=0)
finished_seq = layers.transpose(
layers.gather_tree(predict_ids, parent_ids), [1, 2, 0])
finished_scores = topk_scores
return finished_seq, finished_scores
def _greedy_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
max_len=256):
# run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
# constant number class InferTransformer(Transformer):
batch_size = enc_output.shape[0] """
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len model for prediction
end_token_tensor = layers.fill_constant(shape=[batch_size, 1], """
dtype="int64", def __init__(self,
value=eos_id) src_vocab_size,
trg_vocab_size,
predict_ids = [] max_length,
log_probs = layers.fill_constant(shape=[batch_size, 1], n_layer,
dtype="float32", n_head,
value=0) d_key,
trg_word = layers.fill_constant(shape=[batch_size, 1], d_value,
dtype="int64", d_model,
value=bos_id) d_inner_hid,
finished = layers.fill_constant(shape=[batch_size, 1], prepostprocess_dropout,
dtype="bool", attention_dropout,
value=0) relu_dropout,
preprocess_cmd,
## init states (caches) for transformer postprocess_cmd,
weight_sharing,
bos_id=0,
eos_id=1,
beam_size=4,
max_out_len=256):
args = locals()
args.pop("self")
self.beam_size = args.pop("beam_size")
self.max_out_len = args.pop("max_out_len")
super(InferTransformer, self).__init__(**args)
cell = TransfomerCell(self.decoder)
self.beam_search_decoder = DynamicDecode(
TransformerBeamSearchDecoder(cell,
bos_id,
eos_id,
beam_size,
var_dim_in_state=2), max_out_len)
@shape_hints(src_word=[None, None],
src_pos=[None, None],
src_slf_attn_bias=[None, 8, None, None],
trg_src_attn_bias=[None, 8, None, None])
def forward(self, src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias):
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
## init states (caches) for transformer, need to be updated according to selected beam
caches = [{ caches = [{
"k": "k":
layers.fill_constant( layers.fill_constant_batch_size_like(
shape=[batch_size, self.n_head, 0, self.d_key], input=enc_output,
shape=[-1, self.n_head, 0, self.d_key],
dtype=enc_output.dtype, dtype=enc_output.dtype,
value=0), value=0),
"v": "v":
layers.fill_constant( layers.fill_constant_batch_size_like(
shape=[batch_size, self.n_head, 0, self.d_value], input=enc_output,
shape=[-1, self.n_head, 0, self.d_value],
dtype=enc_output.dtype, dtype=enc_output.dtype,
value=0), value=0),
} for i in range(self.n_layer)] } for i in range(self.n_layer)]
enc_output = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
for i in range(max_len): enc_output, self.beam_size)
trg_pos = layers.fill_constant(shape=trg_word.shape, trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
dtype="int64", trg_src_attn_bias, self.beam_size)
value=i) static_caches = self.decoder.decoder.prepare_static_cache(
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, enc_output)
enc_output, caches) rs, _ = self.beam_search_decoder(inits=caches,
step_log_probs = layers.log(layers.softmax(logits)) enc_output=enc_output,
log_probs = layers.elementwise_add(x=step_log_probs, trg_src_attn_bias=trg_src_attn_bias,
y=log_probs, static_caches=static_caches)
axis=0) return rs
scores = log_probs
topk_scores, topk_indices = layers.topk(input=scores, k=1)
finished = layers.logical_or(
finished, layers.equal(topk_indices, end_token_tensor))
trg_word = topk_indices
log_probs = topk_scores
predict_ids.append(topk_indices)
if layers.reduce_all(finished).numpy():
break
predict_ids = layers.stack(predict_ids, axis=0)
finished_seq = layers.transpose(predict_ids, [1, 2, 0])
finished_scores = topk_scores
return finished_seq, finished_scores
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册