diff --git a/transformer/predict.py b/transformer/predict.py
index d33d4e5c909d9565dccbddaf3181e5a0b56c0d88..2fae99fc7510a7d65d2c3c765b5dfa0a11ef26e2 100644
--- a/transformer/predict.py
+++ b/transformer/predict.py
@@ -16,7 +16,9 @@ import logging
 import os
 import six
 import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import time
+import contextlib
 
 import numpy as np
 import paddle
@@ -27,10 +29,11 @@ from utils.check import check_gpu, check_version
 
 # include task-specific libs
 import reader
-from model import Transformer, position_encoding_init
+from transformer import InferTransformer, position_encoding_init
 
 
-def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
+def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
+                     output_eos=False):
     """
     Post-process the decoded sequence.
     """
@@ -47,10 +50,13 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
 
 
 def do_predict(args):
-    if args.use_cuda:
-        place = fluid.CUDAPlace(0)
-    else:
-        place = fluid.CPUPlace()
+    device_ids = list(range(args.num_devices))
+
+    @contextlib.contextmanager
+    def null_guard():
+        yield
+
+    guard = fluid.dygraph.guard() if args.eager_run else null_guard()
 
     # define the data generator
     processor = reader.DataProcessor(fpattern=args.predict_file,
@@ -69,68 +75,61 @@ def do_predict(args):
                                      unk_mark=args.special_token[2],
                                      max_length=args.max_length,
                                      n_head=args.n_head)
-    batch_generator = processor.data_generator(phase="predict", place=place)
+    batch_generator = processor.data_generator(phase="predict")
     args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
         args.unk_idx = processor.get_vocab_summary()
     trg_idx2word = reader.DataProcessor.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)
-    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-        args.unk_idx = processor.get_vocab_summary()
-
-    with fluid.dygraph.guard(place):
+
+    with guard:
        # define data loader
-        test_loader = fluid.io.DataLoader.from_generator(capacity=10)
-        test_loader.set_batch_generator(batch_generator, places=place)
+        test_loader = batch_generator
 
         # define model
-        transformer = Transformer(
-            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
-            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
-            args.d_inner_hid, args.prepostprocess_dropout,
-            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
-            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
-            args.eos_idx)
+        transformer = InferTransformer(args.src_vocab_size,
+                                       args.trg_vocab_size,
+                                       args.max_length + 1,
+                                       args.n_layer,
+                                       args.n_head,
+                                       args.d_key,
+                                       args.d_value,
+                                       args.d_model,
+                                       args.d_inner_hid,
+                                       args.prepostprocess_dropout,
+                                       args.attention_dropout,
+                                       args.relu_dropout,
+                                       args.preprocess_cmd,
+                                       args.postprocess_cmd,
+                                       args.weight_sharing,
+                                       args.bos_idx,
+                                       args.eos_idx,
+                                       beam_size=args.beam_size,
+                                       max_out_len=args.max_out_len)
 
         # load the trained model
         assert args.init_from_params, (
             "Please set init_from_params to load the infer model.")
-        model_dict, _ = fluid.load_dygraph(
-            os.path.join(args.init_from_params, "transformer"))
-        # to avoid a longer length than training, reset the size of position
-        # encoding to max_length
-        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
-            args.max_length + 1, args.d_model)
-        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
-            args.max_length + 1, args.d_model)
-        transformer.load_dict(model_dict)
-
-        # set evaluate mode
-        transformer.eval()
+        transformer.load(os.path.join(args.init_from_params, "transformer"))
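+        # NOTE: `load` replaces the fluid.load_dygraph path above; it is
+        # assumed to restore parameters from a "transformer"-prefixed file
+        # saved by the high-level model API. The manual position-encoding
+        # resize is gone, so max_length presumably must not exceed the size
+        # used at training time.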
 
         f = open(args.output_file, "wb")
         for input_data in test_loader():
             (src_word, src_pos, src_slf_attn_bias, trg_word,
              trg_src_attn_bias) = input_data
-            finished_seq, finished_scores = transformer.beam_search(
-                src_word,
-                src_pos,
-                src_slf_attn_bias,
-                trg_word,
-                trg_src_attn_bias,
-                bos_id=args.bos_idx,
-                eos_id=args.eos_idx,
-                beam_size=args.beam_size,
-                max_len=args.max_out_len)
-            finished_seq = finished_seq.numpy()
-            finished_scores = finished_scores.numpy()
+            finished_seq = transformer.test(inputs=(src_word, src_pos,
+                                                    src_slf_attn_bias,
+                                                    trg_src_attn_bias),
+                                            device='gpu',
+                                            device_ids=device_ids)[0]
+            finished_seq = np.transpose(finished_seq, [0, 2, 1])
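+            # Model.test is assumed to return ids laid out as
+            # [batch, seq_len, beam]; the transpose makes each `ins` below a
+            # [beam, seq_len] array ordered best-first for the n_best loop.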
             for ins in finished_seq:
                 for beam_idx, beam in enumerate(ins):
                     if beam_idx >= args.n_best:
                         break
-                    id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
+                    id_list = post_process_seq(beam, args.bos_idx,
+                                               args.eos_idx)
                     word_list = [trg_idx2word[id] for id in id_list]
                     sequence = b" ".join(word_list) + b"\n"
                     f.write(sequence)
+            break
 
 
 if __name__ == "__main__":
diff --git a/transformer/reader.py b/transformer/reader.py
index ef23c5e1e32fa4cee1ba5a42bb970a1a135879a0..5e1a5dd77c128ae6978e44bafe48a028b11dac1c 100644
--- a/transformer/reader.py
+++ b/transformer/reader.py
@@ -114,7 +114,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
     return data_inputs
 
 
-def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head, place):
+def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
     """
     Put all padded data needed by beam search decoder into a list.
     """
@@ -517,7 +517,7 @@ class DataProcessor(object):
 
         return __impl__
 
-    def data_generator(self, phase, place=None):
+    def data_generator(self, phase):
         # Any token included in dict can be used to pad, since the paddings' loss
         # will be masked out by weights and make no effect on parameter gradients.
         src_pad_idx = trg_pad_idx = self._eos_idx
@@ -540,7 +540,7 @@ class DataProcessor(object):
         def __for_predict__():
             for data in data_reader():
                 data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx,
-                                                  n_head, place)
+                                                  n_head)
                 yield data_inputs
 
         return __for_train__ if phase == "train" else __for_predict__
diff --git a/transformer/rnn_api.py b/transformer/rnn_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2c231050aca257781894e0d2e4b59ba17272a3f
--- /dev/null
+++ b/transformer/rnn_api.py
@@ -0,0 +1,778 @@
+import collections
+import contextlib
+import inspect
+import six
+import sys
+from functools import partial, reduce
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.layers.utils as utils
+from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
+from paddle.fluid.dygraph import to_variable, Embedding, Linear
+from paddle.fluid.data_feeder import convert_dtype
+
+from paddle.fluid import layers
+from paddle.fluid.dygraph import Layer
+
+
+class RNNUnit(Layer):
+    def get_initial_states(self,
+                           batch_ref,
+                           shape=None,
+                           dtype=None,
+                           init_value=0,
+                           batch_dim_idx=0):
+        """
+        Generate initialized states according to the provided shape, data type
+        and value.
+
+        Parameters:
+            batch_ref: A (possibly nested structure of) tensor variable[s].
+                The first dimension of the tensor will be used as batch size to
+                initialize states.
+            shape: A (possibly nested structure of) shape[s], where a shape is
+                represented as a list/tuple of integers. -1 (for batch size)
+                will be automatically inserted if a shape does not start with
+                it. If None, property `state_shape` will be used. The default
+                value is None.
+            dtype: A (possibly nested structure of) data type[s]. The structure
+                must be the same as that of `shape`, except when all tensors in
+                the states have the same data type, in which case a single data
+                type can be used. If None and property `cell.state_shape` is
+                not available, float32 will be used as the data type. The
+                default value is None.
+            init_value: A float value used to initialize states.
+
+        Returns:
+            Variable: tensor variable[s] packed in the same structure provided \
+                by shape, representing the initialized states.
+        """
+        # TODO: use inputs and batch_size
+        batch_ref = flatten(batch_ref)[0]
+
+        def _is_shape_sequence(seq):
+            if sys.version_info < (3, ):
+                integer_types = (
+                    int,
+                    long,
+                )
+            else:
+                integer_types = (int, )
+            """For shape, a list/tuple of integers is the finest-grained object"""
+            if (isinstance(seq, list) or isinstance(seq, tuple)):
+                if reduce(
+                        lambda flag, x: isinstance(x, integer_types) and flag,
+                        seq, True):
+                    return False
+            # TODO: Add check for the illegal
+            if isinstance(seq, dict):
+                return True
+            return (isinstance(seq, collections.Sequence)
+                    and not isinstance(seq, six.string_types))
+
+        class Shape(object):
+            def __init__(self, shape):
+                self.shape = shape if shape[0] == -1 else ([-1] + list(shape))
+
+        # nested structure of shapes
+        states_shapes = self.state_shape if shape is None else shape
+        is_sequence_ori = utils.is_sequence
+        utils.is_sequence = _is_shape_sequence
+        states_shapes = map_structure(lambda shape: Shape(shape),
+                                      states_shapes)
+        utils.is_sequence = is_sequence_ori
+
+        # nested structure of dtypes
+        try:
+            states_dtypes = self.state_dtype if dtype is None else dtype
+        except NotImplementedError:  # use fp32 as default
+            states_dtypes = "float32"
+        if len(flatten(states_dtypes)) == 1:
+            dtype = flatten(states_dtypes)[0]
+            states_dtypes = map_structure(lambda shape: dtype, states_shapes)
+
+        init_states = map_structure(
+            lambda shape, dtype: fluid.layers.fill_constant_batch_size_like(
+                input=batch_ref,
+                shape=shape.shape,
+                dtype=dtype,
+                value=init_value,
+                input_dim_idx=batch_dim_idx), states_shapes, states_dtypes)
+        return init_states
+
+    @property
+    def state_shape(self):
+        """
+        Abstract method (property).
+        Used to initialize states.
+        A (possibly nested structure of) shape[s], where a shape is represented
+        as a list/tuple of integers (-1 for batch size would be automatically
+        inserted into a shape if the shape does not start with it).
+        Not necessary to be implemented if states are not initialized by
+        `get_initial_states` or the `shape` argument is provided when using
+        `get_initial_states`.
+        """
+        raise NotImplementedError(
+            "Please add implementation for `state_shape` in the used cell.")
+
+    @property
+    def state_dtype(self):
+        """
+        Abstract method (property).
+        Used to initialize states.
+        A (possibly nested structure of) data type[s]. The structure must be
+        the same as that of `shape`, except when all tensors in the states
+        have the same data type, in which case a single data type can be used.
+        Not necessary to be implemented if states are not initialized
+        by `get_initial_states` or the `dtype` argument is provided when using
+        `get_initial_states`.
+        """
+        raise NotImplementedError(
+            "Please add implementation for `state_dtype` in the used cell.")
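+
+# Illustrative example (assumed API): for a cell whose `state_shape` is
+# [[hidden_size], [hidden_size]], get_initial_states(batch_ref) returns two
+# zero-filled [batch_size, hidden_size] tensors in the same nested structure.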
+ """ + raise NotImplementedError( + "Please add implementaion for `state_dtype` in the used cell.") + + +class BasicLSTMUnit(RNNUnit): + """ + **** + BasicLSTMUnit class, Using basic operator to build LSTM + The algorithm can be described as the code below. + .. math:: + i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i) + f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias ) + o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o) + \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c) + c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t} + h_t &= o_t \odot tanh(c_t) + - $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix + of weights from the input gate to the input) + - The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector). + - sigmoid is the logistic sigmoid function. + - $i, f, o$ and $c$ are the input gate, forget gate, output gate, + and cell activation vectors, respectively, all of which have the same size as + the cell output activation vector $h$. + - The :math:`\odot` is the element-wise product of the vectors. + - :math:`tanh` is the activation functions. + - :math:`\\tilde{c_t}` is also called candidate hidden state, + which is computed based on the current input and the previous hidden state. + Args: + name_scope(string) : The name scope used to identify parameter and bias name + hidden_size (integer): The hidden size used in the Unit. + param_attr(ParamAttr|None): The parameter attribute for the learnable + weight matrix. Note: + If it is set to None or one attribute of ParamAttr, lstm_unit will + create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. Default: None. + bias_attr (ParamAttr|None): The parameter attribute for the bias + of LSTM unit. + If it is set to None or one attribute of ParamAttr, lstm_unit will + create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized as zero. Default: None. + gate_activation (function|None): The activation function for gates (actGate). + Default: 'fluid.layers.sigmoid' + activation (function|None): The activation function for cells (actNode). 
+    def forward(self, input, state):
+        pre_hidden, pre_cell = state
+        concat_input_hidden = layers.concat([input, pre_hidden], 1)
+        gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
+
+        gate_input = layers.elementwise_add(gate_input, self._bias)
+        i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
+        new_cell = layers.elementwise_add(
+            layers.elementwise_mul(
+                pre_cell,
+                layers.sigmoid(layers.elementwise_add(f, self._forget_bias))),
+            layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
+        new_hidden = layers.tanh(new_cell) * layers.sigmoid(o)
+
+        return new_hidden, [new_hidden, new_cell]
+
+    @property
+    def state_shape(self):
+        return [[self._hidden_size], [self._hidden_size]]
+
+
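+# Illustrative usage (assumed shapes, dygraph mode):
+#   cell = BasicLSTMUnit(hidden_size=8, input_size=8)
+#   states = cell.get_initial_states(batch_ref=step_input)  # [h0, c0]
+#   out, new_states = cell(step_input, states)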
+class RNN(fluid.dygraph.Layer):
+    def __init__(self, cell, is_reverse=False, time_major=False):
+        super(RNN, self).__init__()
+        self.cell = cell
+        if not hasattr(self.cell, "call"):
+            self.cell.call = self.cell.forward
+        self.is_reverse = is_reverse
+        self.time_major = time_major
+        self.batch_index, self.time_step_index = (1, 0) if time_major else (0,
+                                                                            1)
+
+    def forward(self,
+                inputs,
+                initial_states=None,
+                sequence_length=None,
+                **kwargs):
+        if fluid.in_dygraph_mode():
+
+            class ArrayWrapper(object):
+                def __init__(self, x):
+                    self.array = [x]
+
+                def append(self, x):
+                    self.array.append(x)
+                    return self
+
+            def _maybe_copy(state, new_state, step_mask):
+                # TODO: use where_op
+                new_state = fluid.layers.elementwise_mul(
+                    new_state, step_mask,
+                    axis=0) - fluid.layers.elementwise_mul(state,
+                                                           (step_mask - 1),
+                                                           axis=0)
+                return new_state
+
+            flat_inputs = flatten(inputs)
+            batch_size, time_steps = (
+                flat_inputs[0].shape[self.batch_index],
+                flat_inputs[0].shape[self.time_step_index])
+
+            if initial_states is None:
+                initial_states = self.cell.get_initial_states(
+                    batch_ref=inputs, batch_dim_idx=self.batch_index)
+
+            if not self.time_major:
+                inputs = map_structure(
+                    lambda x: fluid.layers.transpose(x, [1, 0] + list(
+                        range(2, len(x.shape)))), inputs)
+
+            if sequence_length:
+                mask = fluid.layers.sequence_mask(
+                    sequence_length,
+                    maxlen=time_steps,
+                    dtype=flatten(initial_states)[0].dtype)
+                mask = fluid.layers.transpose(mask, [1, 0])
+
+            if self.is_reverse:
+                inputs = map_structure(
+                    lambda x: fluid.layers.reverse(x, axis=[0]), inputs)
+                mask = fluid.layers.reverse(
+                    mask, axis=[0]) if sequence_length else None
+
+            states = initial_states
+            outputs = []
+            for i in range(time_steps):
+                step_inputs = map_structure(lambda x: x[i], inputs)
+                step_outputs, new_states = self.cell(step_inputs, states,
+                                                     **kwargs)
+                if sequence_length:
+                    new_states = map_structure(
+                        partial(_maybe_copy, step_mask=mask[i]), states,
+                        new_states)
+                states = new_states
+                outputs = map_structure(
+                    lambda x: ArrayWrapper(x),
+                    step_outputs) if i == 0 else map_structure(
+                        lambda x, x_array: x_array.append(x), step_outputs,
+                        outputs)
+
+            final_outputs = map_structure(
+                lambda x: fluid.layers.stack(x.array,
+                                             axis=self.time_step_index),
+                outputs)
+
+            if self.is_reverse:
+                final_outputs = map_structure(
+                    lambda x: fluid.layers.reverse(x,
+                                                   axis=self.time_step_index),
+                    final_outputs)
+
+            final_states = new_states
+        else:
+            final_outputs, final_states = fluid.layers.rnn(
+                self.cell,
+                inputs,
+                initial_states=initial_states,
+                sequence_length=sequence_length,
+                time_major=self.time_major,
+                is_reverse=self.is_reverse,
+                **kwargs)
+        return final_outputs, final_states
+
+
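+# NOTE: RNN unrolls the cell with a Python loop in dygraph mode and falls
+# back to fluid.layers.rnn when building a static graph, so one cell
+# implementation serves both execution modes.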
+from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
+place = fluid.CPUPlace()
+executor = fluid.Executor(place)
+
+
+class EncoderCell(RNNUnit):
+    def __init__(self, num_layers, input_size, hidden_size, dropout_prob=0.):
+        super(EncoderCell, self).__init__()
+        self.num_layers = num_layers
+        self.dropout_prob = dropout_prob
+
+        self.lstm_cells = list()
+        for i in range(self.num_layers):
+            self.lstm_cells.append(
+                self.add_sublayer(
+                    "layer_%d" % i,
+                    BasicLSTMUnit(hidden_size,
+                                  input_size if i == 0 else hidden_size)))
+
+    def forward(self, step_input, states):
+        new_states = []
+        for i in range(self.num_layers):
+            out, new_state = self.lstm_cells[i](step_input, states[i])
+            step_input = layers.dropout(
+                out, self.dropout_prob) if self.dropout_prob > 0 else out
+            new_states.append(new_state)
+        return step_input, new_states
+
+    @property
+    def state_shape(self):
+        return [cell.state_shape for cell in self.lstm_cells]
+
+
+class MultiHeadAttention(Layer):
+    """
+    Multi-Head Attention
+    """
+
+    # def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
+    #     pass
+
+    # def forward(self, queries, keys, values, attn_bias, cache=None):
+    #     pass
+
+    def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+        super(MultiHeadAttention, self).__init__()
+        self.n_head = n_head
+        self.d_key = d_key
+        self.d_value = d_value
+        self.d_model = d_model
+        self.dropout_rate = dropout_rate
+        self.q_fc = Linear(input_dim=d_model,
+                           output_dim=d_key * n_head,
+                           bias_attr=False)
+        self.k_fc = Linear(input_dim=d_model,
+                           output_dim=d_key * n_head,
+                           bias_attr=False)
+        self.v_fc = Linear(input_dim=d_model,
+                           output_dim=d_value * n_head,
+                           bias_attr=False)
+        self.proj_fc = Linear(input_dim=d_value * n_head,
+                              output_dim=d_model,
+                              bias_attr=False)
+
+    def forward(self, queries, keys, values, attn_bias, cache=None):
+        # compute q, k, v
+        keys = queries if keys is None else keys
+        values = keys if values is None else values
+
+        q = self.q_fc(queries)
+        k = self.k_fc(keys)
+        v = self.v_fc(values)
+
+        # split head
+        q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
+        q = layers.transpose(x=q, perm=[0, 2, 1, 3])
+        k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+        k = layers.transpose(x=k, perm=[0, 2, 1, 3])
+        v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+        v = layers.transpose(x=v, perm=[0, 2, 1, 3])
+
+        if cache is not None:
+            cache_k, cache_v = cache["k"], cache["v"]
+            k = layers.concat([cache_k, k], axis=2)
+            v = layers.concat([cache_v, v], axis=2)
+            cache["k"], cache["v"] = k, v
+
+        # scaled dot-product attention
+        product = layers.matmul(x=q,
+                                y=k,
+                                transpose_y=True,
+                                alpha=self.d_model**-0.5)
+        if attn_bias is not None:
+            product += attn_bias
+        weights = layers.softmax(product)
+        if self.dropout_rate:
+            weights = layers.dropout(weights,
+                                     dropout_prob=self.dropout_rate,
+                                     is_test=False)
+
+        out = layers.matmul(weights, v)
+
+        # combine heads
+        out = layers.transpose(out, perm=[0, 2, 1, 3])
+        out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        # project to output
+        out = self.proj_fc(out)
+        return out
+
+
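+# NOTE: when a `cache` dict is passed to MultiHeadAttention, its K/V entries
+# grow by one step along axis 2 per call, which is what makes per-step
+# decoding incremental.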
+class DynamicDecode(Layer):
+    def __init__(self,
+                 decoder,
+                 max_step_num=None,
+                 output_time_major=False,
+                 impute_finished=False,
+                 is_test=False,
+                 return_length=False):
+        super(DynamicDecode, self).__init__()
+        self.decoder = decoder
+        self.max_step_num = max_step_num
+        self.output_time_major = output_time_major
+        self.impute_finished = impute_finished
+        self.is_test = is_test
+        self.return_length = return_length
+
+    def forward(self, inits=None, **kwargs):
+        if fluid.in_dygraph_mode():
+
+            class ArrayWrapper(object):
+                def __init__(self, x):
+                    self.array = [x]
+
+                def append(self, x):
+                    self.array.append(x)
+                    return self
+
+                def __getitem__(self, item):
+                    return self.array.__getitem__(item)
+
+            def _maybe_copy(state, new_state, step_mask):
+                # TODO: use where_op
+                state_dtype = state.dtype
+                if convert_dtype(state_dtype) in ["bool"]:
+                    state = layers.cast(state, dtype="float32")
+                    new_state = layers.cast(new_state, dtype="float32")
+                if step_mask.dtype != state.dtype:
+                    step_mask = layers.cast(step_mask, dtype=state.dtype)
+                    # otherwise, the renamed bool gradients would be summed
+                    # up, leading to a sum(bool) error
+                    step_mask.stop_gradient = True
+                new_state = layers.elementwise_mul(
+                    state, step_mask, axis=0) - layers.elementwise_mul(
+                        new_state, (step_mask - 1), axis=0)
+                if convert_dtype(state_dtype) in ["bool"]:
+                    new_state = layers.cast(new_state, dtype=state_dtype)
+                return new_state
+
+            initial_inputs, initial_states, initial_finished = self.decoder.initialize(
+                inits)
+            inputs, states, finished = (initial_inputs, initial_states,
+                                        initial_finished)
+            cond = layers.logical_not((layers.reduce_all(initial_finished)))
+            sequence_lengths = layers.cast(layers.zeros_like(initial_finished),
+                                           "int64")
+            outputs = None
+
+            step_idx = 0
+            step_idx_tensor = layers.fill_constant(shape=[1],
+                                                   dtype="int64",
+                                                   value=step_idx)
+            while cond.numpy():
+                (step_outputs, next_states, next_inputs,
+                 next_finished) = self.decoder.step(step_idx_tensor, inputs,
+                                                    states, **kwargs)
+                next_finished = layers.logical_or(next_finished, finished)
+                next_sequence_lengths = layers.elementwise_add(
+                    sequence_lengths,
+                    layers.cast(layers.logical_not(finished),
+                                sequence_lengths.dtype))
+
+                if self.impute_finished:  # rectify the states for the finished
+                    next_states = map_structure(
+                        lambda x, y: _maybe_copy(x, y, finished), states,
+                        next_states)
+                outputs = map_structure(
+                    lambda x: ArrayWrapper(x),
+                    step_outputs) if step_idx == 0 else map_structure(
+                        lambda x, x_array: x_array.append(x), step_outputs,
+                        outputs)
+                inputs, states, finished, sequence_lengths = (
+                    next_inputs, next_states, next_finished,
+                    next_sequence_lengths)
+
+                layers.increment(x=step_idx_tensor, value=1.0, in_place=True)
+                step_idx += 1
+
+                layers.logical_not(layers.reduce_all(finished), cond)
+                if self.max_step_num is not None and step_idx > self.max_step_num:
+                    break
+
+            final_outputs = map_structure(
+                lambda x: fluid.layers.stack(x.array, axis=0), outputs)
+            final_states = states
+
+            try:
+                final_outputs, final_states = self.decoder.finalize(
+                    final_outputs, final_states, sequence_lengths)
+            except NotImplementedError:
+                pass
+
+            if not self.output_time_major:
+                final_outputs = map_structure(
+                    lambda x: layers.transpose(x, [1, 0] + list(
+                        range(2, len(x.shape)))), final_outputs)
+
+            return (final_outputs, final_states,
+                    sequence_lengths) if self.return_length else (
+                        final_outputs, final_states)
+        else:
+            return fluid.layers.dynamic_decode(
+                self.decoder,
+                inits,
+                max_step_num=self.max_step_num,
+                output_time_major=self.output_time_major,
+                impute_finished=self.impute_finished,
+                is_test=self.is_test,
+                return_length=self.return_length,
+                **kwargs)
+
+
+class TransfomerCell(object):
+    """
+    Let inputs=(trg_word, trg_pos), states=cache, so that the Transformer
+    decoder can be used as an RNNCell.
+    """
+    def __init__(self, decoder):
+        self.decoder = decoder
+
+    def __call__(self, inputs, states, trg_src_attn_bias, enc_output,
+                 static_caches):
+        trg_word, trg_pos = inputs
+        for cache, static_cache in zip(states, static_caches):
+            cache.update(static_cache)
+        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
+                              enc_output, states)
+        new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
+        return logits, new_states
+
+
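+# NOTE: TransfomerCell follows the (inputs, states) -> (outputs, new_states)
+# cell contract so the beam-search decoder below can drive the Transformer
+# decoder one step at a time.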
+class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
+    def __init__(self, cell, start_token, end_token, beam_size,
+                 var_dim_in_state):
+        super(TransformerBeamSearchDecoder,
+              self).__init__(cell, start_token, end_token, beam_size)
+        self.cell = cell
+        self.var_dim_in_state = var_dim_in_state
+
+    def _merge_batch_beams_with_var_dim(self, x):
+        if not hasattr(self, "batch_size"):
+            self.batch_size = layers.shape(x)[0]
+        if not hasattr(self, "batch_beam_size"):
+            self.batch_beam_size = self.batch_size * self.beam_size
+        # init length of cache is 0, and it increases with decoding carrying on,
+        # thus need to reshape elaborately
+        var_dim_in_state = self.var_dim_in_state + 1  # count in beam dim
+        x = layers.transpose(
+            x,
+            list(range(var_dim_in_state, len(x.shape))) +
+            list(range(0, var_dim_in_state)))
+        x = layers.reshape(x, [0] * (len(x.shape) - var_dim_in_state) +
+                           [self.batch_beam_size] +
+                           list(x.shape[-var_dim_in_state + 2:]))
+        x = layers.transpose(
+            x,
+            list(range((len(x.shape) + 1 - var_dim_in_state), len(x.shape))) +
+            list(range(0, (len(x.shape) + 1 - var_dim_in_state))))
+        return x
+
+    def _split_batch_beams_with_var_dim(self, x):
+        var_dim_size = layers.shape(x)[self.var_dim_in_state]
+        x = layers.reshape(
+            x, [-1, self.beam_size] + list(x.shape[1:self.var_dim_in_state]) +
+            [var_dim_size] + list(x.shape[self.var_dim_in_state + 1:]))
+        return x
+
+    def step(self, time, inputs, states, **kwargs):
+        # compared to RNN, Transformer has 3D data at every decoding step
+        inputs = layers.reshape(inputs, [-1, 1])  # token
+        pos = layers.ones_like(inputs) * time  # pos
+        cell_states = map_structure(self._merge_batch_beams_with_var_dim,
+                                    states.cell_states)
+
+        cell_outputs, next_cell_states = self.cell((inputs, pos), cell_states,
+                                                   **kwargs)
+        cell_outputs = map_structure(self._split_batch_beams, cell_outputs)
+        next_cell_states = map_structure(self._split_batch_beams_with_var_dim,
+                                         next_cell_states)
+
+        beam_search_output, beam_search_state = self._beam_search_step(
+            time=time,
+            logits=cell_outputs,
+            next_cell_states=next_cell_states,
+            beam_state=states)
+        next_inputs, finished = (beam_search_output.predicted_ids,
+                                 beam_search_state.finished)
+
+        return (beam_search_output, beam_search_state, next_inputs, finished)
+
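+# NOTE: the *_with_var_dim helpers exist because the cache has a dynamic
+# sequence-length axis (var_dim_in_state); batch and beam dims cannot be
+# merged/split with a single -1 reshape, so axes are rotated first.
+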
+'''
+@contextlib.contextmanager
+def eager_guard(is_eager):
+    if is_eager:
+        with fluid.dygraph.guard():
+            yield
+    else:
+        yield
+
+
+# print(flatten(np.random.rand(2,8,8)))
+random_seed = 123
+np.random.seed(random_seed)
+# print np.random.rand(2, 8)
+batch_size = 2
+seq_len = 8
+hidden_size = 8
+vocab_size, embed_dim, num_layers, hidden_size = 100, 8, 2, 8
+bos_id, eos_id, beam_size, max_step_num = 0, 1, 5, 10
+time_major = False
+eagar_run = False
+
+import torch
+
+with eager_guard(eagar_run):
+    fluid.default_main_program().random_seed = random_seed
+    fluid.default_startup_program().random_seed = random_seed
+
+    inputs_data = np.random.rand(batch_size, seq_len,
+                                 hidden_size).astype("float32")
+    states_data = np.random.rand(batch_size, hidden_size).astype("float32")
+
+    lstm_cell = BasicLSTMUnit(hidden_size=8, input_size=8)
+    lstm = RNN(cell=lstm_cell, time_major=time_major)
+
+    inputs = to_variable(inputs_data) if eagar_run else fluid.data(
+        name="x", shape=[None, None, hidden_size], dtype="float32")
+
+    states = lstm_cell.get_initial_states(batch_ref=inputs,
+                                          batch_dim_idx=1 if time_major else 0)
+
+    out, _ = lstm(inputs, states)
+    # print states
+
+    # print layers.BeamSearchDecoder.tile_beam_merge_with_batch(out, 5)
+
+    # embedder = Embedding(size=(vocab_size, embed_dim))
+    # output_layer = Linear(hidden_size, vocab_size)
+    # decoder = layers.BeamSearchDecoder(lstm_cell,
+    #                                    bos_id,
+    #                                    eos_id,
+    #                                    beam_size,
+    #                                    embedding_fn=embedder,
+    #                                    output_fn=output_layer)
+    # dynamic_decoder = DynamicDecode(decoder, max_step_num)
+    # out,_ = dynamic_decoder(inits=states)
+
+    # caches = [{
+    #     "k":
+    #     layers.fill_constant_batch_size_like(out,
+    #                                          shape=[-1, 8, 0, 64],
+    #                                          dtype="float32",
+    #                                          value=0),
+    #     "v":
+    #     layers.fill_constant_batch_size_like(out,
+    #                                          shape=[-1, 8, 0, 64],
+    #                                          dtype="float32",
+    #                                          value=0)
+    # } for i in range(6)]
+    cache = layers.fill_constant_batch_size_like(out,
+                                                 shape=[-1, 8, 0, 64],
+                                                 dtype="float32",
+                                                 value=0)
+
+    print cache
+    # out = layers.BeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
+    # out = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
+    # batch_beam_size = layers.shape(out)[0] * 5
+    # print out
+    cell = TransfomerCell(None)
+    decoder = TransformerBeamSearchDecoder(cell, 0, 1, 5, 2)
+    cache = decoder._expand_to_beam_size(cache)
+    print cache
+    cache = decoder._merge_batch_beams_with_var_dim(cache)
+    print cache
+    cache1 = layers.fill_constant_batch_size_like(cache,
+                                                  shape=[-1, 8, 1, 64],
+                                                  dtype="float32",
+                                                  value=0)
+    print cache1.shape
+    cache = layers.concat([cache, cache1], axis=2)
+    out = decoder._split_batch_beams_with_var_dim(cache)
+    # out = layers.transpose(out,
+    #                        list(range(3, len(out.shape))) + list(range(0, 3)))
+    # print out
+    # out = layers.reshape(out, list(out.shape[:2]) + [batch_beam_size, 8])
+    # print out
+    # out = layers.transpose(out, [2,3,0,1])
+    print out.shape
+    if eagar_run:
+        print "hehe"  #out #.numpy()
+    else:
+        executor.run(fluid.default_startup_program())
+        inputs = fluid.data(name="x",
+                            shape=[None, None, hidden_size],
+                            dtype="float32")
+        out_np = executor.run(feed={"x": inputs_data},
+                              fetch_list=[out.name])[0]
+        print np.array(out_np).shape
+    exit(0)
+
+    # dygraph
+    # inputs = to_variable(inputs_data)
+    # states = lstm_cell.get_initial_states(batch_ref=inputs,
+    #                                       batch_dim_idx=1 if time_major else 0)
+
+    # print lstm(inputs, states)[0].numpy()
+
+    # graph
+    executor.run(fluid.default_startup_program())
+    inputs = fluid.data(name="x",
+                        shape=[None, None, hidden_size],
+                        dtype="float32")
+    states = lstm_cell.get_initial_states(batch_ref=inputs,
+                                          batch_dim_idx=1 if time_major else 0)
+    out, _ = lstm(inputs, states)
+    out_np = executor.run(feed={"x": inputs_data}, fetch_list=[out.name])[0]
+    print np.array(out_np)
+
+    #print fluid.io.save_inference_model(dirname="test_model", feeded_var_names=["x"], target_vars=[out], executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
+    # test_program, feed_target_names, fetch_targets = fluid.io.load_inference_model(dirname="test_model", executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
+    # out = executor.run(program=test_program, feed={"x": np.random.rand(2, 8, 8).astype("float32")}, fetch_list=fetch_targets)[0]
+'''
\ No newline at end of file
diff --git a/transformer/run.sh b/transformer/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a55d7a7ac747636c2c4fcdfbec1a0f5160a7be05
--- /dev/null
+++ b/transformer/run.sh
@@ -0,0 +1,41 @@
+python -u train.py \
+    --epoch 30 \
+    --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+    --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+    --special_token '<s>' '<e>' '<unk>' \
+    --training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
+    --validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+    --batch_size 4096 \
+    --print_step 1 \
+    --use_cuda True \
+    --random_seed 1000 \
+    --save_step 10 \
+    --eager_run True
+    #--init_from_pretrain_model base_model_dygraph/step_100000/ \
+    #--init_from_checkpoint trained_models/step_200/transformer
+    #--n_head 16 \
+    #--d_model 1024 \
+    #--d_inner_hid 4096 \
+    #--prepostprocess_dropout 0.3
+exit
+
+echo `date`
+
+python -u predict.py \
+    --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+    --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+    --special_token '<s>' '<e>' '<unk>' \
+    --predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+    --batch_size 64 \
+    --init_from_params base_model_dygraph/step_100000/ \
+    --beam_size 5 \
+    --max_out_len 255 \
+    --output_file predict.txt \
+    --eager_run True
+    #--max_length 500 \
+    #--n_head 16 \
+    #--d_model 1024 \
+    #--d_inner_hid 4096 \
+    #--prepostprocess_dropout 0.3
+
+echo `date`
\ No newline at end of file
diff --git a/transformer/transformer.py b/transformer/transformer.py
index e1011a55db16baa877b38044db3185a1b9def4e0..9579dc32882cc4c7d6b1c5b865d94ad7fce52907 100644
--- a/transformer/transformer.py
+++ b/transformer/transformer.py
@@ -189,6 +189,15 @@ class MultiHeadAttention(Layer):
         out = self.proj_fc(out)
         return out
 
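+    # NOTE: cal_kv exposes only the K/V projection half of attention so the
+    # encoder-side K/V can be computed once per source sentence rather than
+    # at every decoding step.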
+    def cal_kv(self, keys, values):
+        k = self.k_fc(keys)
+        v = self.v_fc(values)
+        k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+        k = layers.transpose(x=k, perm=[0, 2, 1, 3])
+        v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+        v = layers.transpose(x=v, perm=[0, 2, 1, 3])
+        return k, v
+
 
 class FFN(Layer):
     """
@@ -441,6 +450,14 @@ class Decoder(Layer):
 
         return self.processer(dec_output)
 
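+    # NOTE (assumed): "static_k"/"static_v" are the cache keys cross-attention
+    # reads for the encoder output; unlike "k"/"v" they never grow during
+    # decoding, hence "static".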
+    def prepare_static_cache(self, enc_output):
+        return [
+            dict(
+                zip(("static_k", "static_v"),
+                    decoder_layer.cross_attn.cal_kv(enc_output, enc_output)))
+            for decoder_layer in self.decoder_layers
+        ]
+
 
 class WrapDecoder(Layer):
     """
@@ -622,481 +639,96 @@ class Transformer(Model):
                                trg_src_attn_bias, enc_output)
         return predict
 
-    def beam_search_v2(self,
-                       src_word,
-                       src_pos,
-                       src_slf_attn_bias,
-                       trg_word,
-                       trg_src_attn_bias,
-                       bos_id=0,
-                       eos_id=1,
-                       beam_size=4,
-                       max_len=None,
-                       alpha=0.6):
-        """
-        Beam search with the alive and finished two queues, both have a beam size
-        capicity separately. It includes `grow_topk` `grow_alive` `grow_finish` as
-        steps.
-
-        1. `grow_topk` selects the top `2*beam_size` candidates to avoid all getting
-        EOS.
-
-        2. `grow_alive` selects the top `beam_size` non-EOS candidates as the inputs
-        of next decoding step.
-
-        3. `grow_finish` compares the already finished candidates in the finished queue
-        and newly added finished candidates from `grow_topk`, and selects the top
-        `beam_size` finished candidates.
-        """
-        def expand_to_beam_size(tensor, beam_size):
-            tensor = layers.reshape(tensor,
-                                    [tensor.shape[0], 1] + tensor.shape[1:])
-            tile_dims = [1] * len(tensor.shape)
-            tile_dims[1] = beam_size
-            return layers.expand(tensor, tile_dims)
-
-        def merge_beam_dim(tensor):
-            return layers.reshape(tensor, [-1] + tensor.shape[2:])
-
-        # run encoder
-        enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
-        # constant number
-        inf = float(1. * 1e7)
-        batch_size = enc_output.shape[0]
-        max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
-
-        ### initialize states of beam search ###
-        ## init for the alive ##
-        initial_log_probs = to_variable(
-            np.array([[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
-        alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
-        alive_seq = to_variable(
-            np.tile(np.array([[[bos_id]]], dtype="int64"),
-                    (batch_size, beam_size, 1)))
-
-        ## init for the finished ##
-        finished_scores = to_variable(
-            np.array([[-inf] * beam_size], dtype="float32"))
-        finished_scores = layers.expand(finished_scores, [batch_size, 1])
-        finished_seq = to_variable(
-            np.tile(np.array([[[bos_id]]], dtype="int64"),
-                    (batch_size, beam_size, 1)))
-        finished_flags = layers.zeros_like(finished_scores)
-
-        ### initialize inputs and states of transformer decoder ###
-        ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
-        trg_word = layers.reshape(alive_seq[:, :, -1],
-                                  [batch_size * beam_size, 1])
-        trg_src_attn_bias = merge_beam_dim(
-            expand_to_beam_size(trg_src_attn_bias, beam_size))
-        enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))
-        ## init states (caches) for transformer, need to be updated according to selected beam
-        caches = [{
-            "k":
-            layers.fill_constant(
-                shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
-                dtype=enc_output.dtype,
-                value=0),
-            "v":
-            layers.fill_constant(
-                shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
-                dtype=enc_output.dtype,
-                value=0),
-        } for i in range(self.n_layer)]
+from rnn_api import TransformerBeamSearchDecoder, DynamicDecode
 
-        def update_states(caches, beam_idx, beam_size):
-            for cache in caches:
-                cache["k"] = gather_2d_by_gather(cache["k"], beam_idx,
-                                                 beam_size, batch_size, False)
-                cache["v"] = gather_2d_by_gather(cache["v"], beam_idx,
-                                                 beam_size, batch_size, False)
-            return caches
-
-        def gather_2d_by_gather(tensor_nd,
-                                beam_idx,
-                                beam_size,
-                                batch_size,
-                                need_flat=True):
-            batch_idx = layers.range(0, batch_size, 1,
-                                     dtype="int64") * beam_size
-            flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
-            idx = layers.reshape(layers.elementwise_add(beam_idx, batch_idx, 0),
-                                 [-1])
-            new_flat_tensor = layers.gather(flat_tensor, idx)
-            new_tensor_nd = layers.reshape(
-                new_flat_tensor,
-                shape=[batch_size, beam_idx.shape[1]] +
-                tensor_nd.shape[2:]) if need_flat else new_flat_tensor
-            return new_tensor_nd
-
-        def early_finish(alive_log_probs, finished_scores,
-                         finished_in_finished):
-            max_length_penalty = np.power(((5. + max_len) / 6.), alpha)
-            # The best possible score of the most likely alive sequence
-            lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty
-
-            # Now to compute the lowest score of a finished sequence in finished
-            # If the sequence isn't finished, we multiply it's score by 0. since
-            # scores are all -ve, taking the min will give us the score of the lowest
-            # finished item.
-            lowest_score_of_fininshed_in_finished = layers.reduce_min(
-                finished_scores * finished_in_finished, 1)
-            # If none of the sequences have finished, then the min will be 0 and
-            # we have to replace it by -ve INF if it is. The score of any seq in alive
-            # will be much higher than -ve INF and the termination condition will not
-            # be met.
-            lowest_score_of_fininshed_in_finished += (
-                1. - layers.reduce_max(finished_in_finished, 1)) * -inf
-            bound_is_met = layers.reduce_all(
-                layers.greater_than(lowest_score_of_fininshed_in_finished,
-                                    lower_bound_alive_scores))
-
-            return bound_is_met
-
-        def grow_topk(i, logits, alive_seq, alive_log_probs, states):
-            logits = layers.reshape(logits, [batch_size, beam_size, -1])
-            candidate_log_probs = layers.log(layers.softmax(logits, axis=2))
-            log_probs = layers.elementwise_add(candidate_log_probs,
-                                               alive_log_probs, 0)
-
-            length_penalty = np.power(5.0 + (i + 1.0) / 6.0, alpha)
-            curr_scores = log_probs / length_penalty
-            flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1])
-
-            topk_scores, topk_ids = layers.topk(flat_curr_scores,
-                                                k=beam_size * 2)
-
-            topk_log_probs = topk_scores * length_penalty
-
-            topk_beam_index = topk_ids // self.trg_vocab_size
-            topk_ids = topk_ids % self.trg_vocab_size
-
-            # use gather as gather_nd, TODO: use gather_nd
-            topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
-                                           beam_size, batch_size)
-            topk_seq = layers.concat(
-                [topk_seq,
-                 layers.reshape(topk_ids, topk_ids.shape + [1])],
-                axis=2)
-            states = update_states(states, topk_beam_index, beam_size)
-            eos = layers.fill_constant(shape=topk_ids.shape,
-                                       dtype="int64",
-                                       value=eos_id)
-            topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32")
-
-            #topk_seq: [batch_size, 2*beam_size, i+1]
-            #topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
-            return topk_seq, topk_log_probs, topk_scores, topk_finished, states
-
-        def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
-                       states):
-            curr_scores += curr_finished * -inf
-            _, topk_indexes = layers.topk(curr_scores, k=beam_size)
-            alive_seq = gather_2d_by_gather(curr_seq, topk_indexes,
-                                            beam_size * 2, batch_size)
-            alive_log_probs = gather_2d_by_gather(curr_log_probs, topk_indexes,
-                                                  beam_size * 2, batch_size)
-            states = update_states(states, topk_indexes, beam_size * 2)
-
-            return alive_seq, alive_log_probs, states
-
-        def grow_finished(finished_seq, finished_scores, finished_flags,
-                          curr_seq, curr_scores, curr_finished):
-            # finished scores
-            finished_seq = layers.concat([
-                finished_seq,
-                layers.fill_constant(shape=[batch_size, beam_size, 1],
-                                     dtype="int64",
-                                     value=eos_id)
-            ],
-                                         axis=2)
-            # Set the scores of the unfinished seq in curr_seq to large negative
-            # values
-            curr_scores += (1. - curr_finished) * -inf
-            # concatenating the sequences and scores along beam axis
-            curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
-            curr_finished_scores = layers.concat([finished_scores, curr_scores],
-                                                 axis=1)
-            curr_finished_flags = layers.concat([finished_flags, curr_finished],
-                                                axis=1)
-            _, topk_indexes = layers.topk(curr_finished_scores, k=beam_size)
-            finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
-                                               beam_size * 3, batch_size)
-            finished_scores = gather_2d_by_gather(curr_finished_scores,
-                                                  topk_indexes, beam_size * 3,
-                                                  batch_size)
-            finished_flags = gather_2d_by_gather(curr_finished_flags,
-                                                 topk_indexes, beam_size * 3,
-                                                 batch_size)
-            return finished_seq, finished_scores, finished_flags
-
-        for i in range(max_len):
-            trg_pos = layers.fill_constant(shape=trg_word.shape,
-                                           dtype="int64",
-                                           value=i)
-            logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
-                                  enc_output, caches)
-            topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
-                i, logits, alive_seq, alive_log_probs, caches)
-            alive_seq, alive_log_probs, states = grow_alive(
-                topk_seq, topk_scores, topk_log_probs, topk_finished, states)
-            finished_seq, finished_scores, finished_flags = grow_finished(
-                finished_seq, finished_scores, finished_flags, topk_seq,
-                topk_scores, topk_finished)
-            trg_word = layers.reshape(alive_seq[:, :, -1],
-                                      [batch_size * beam_size, 1])
-
-            if early_finish(alive_log_probs, finished_scores,
-                            finished_flags).numpy():
-                break
-
-        return finished_seq, finished_scores
-
-    def beam_search(self,
-                    src_word,
-                    src_pos,
-                    src_slf_attn_bias,
-                    trg_word,
-                    trg_src_attn_bias,
-                    bos_id=0,
-                    eos_id=1,
-                    beam_size=4,
-                    max_len=256):
-        if beam_size == 1:
-            return self._greedy_search(src_word,
-                                       src_pos,
-                                       src_slf_attn_bias,
-                                       trg_word,
-                                       trg_src_attn_bias,
-                                       bos_id=bos_id,
-                                       eos_id=eos_id,
-                                       max_len=max_len)
-        else:
-            return self._beam_search(src_word,
-                                     src_pos,
-                                     src_slf_attn_bias,
-                                     trg_word,
-                                     trg_src_attn_bias,
-                                     bos_id=bos_id,
-                                     eos_id=eos_id,
-                                     beam_size=beam_size,
-                                     max_len=max_len)
-
-    def _beam_search(self,
-                     src_word,
-                     src_pos,
-                     src_slf_attn_bias,
-                     trg_word,
-                     trg_src_attn_bias,
-                     bos_id=0,
-                     eos_id=1,
-                     beam_size=4,
-                     max_len=256):
-        def expand_to_beam_size(tensor, beam_size):
-            tensor = layers.reshape(tensor,
-                                    [tensor.shape[0], 1] + tensor.shape[1:])
-            tile_dims = [1] * len(tensor.shape)
-            tile_dims[1] = beam_size
-            return layers.expand(tensor, tile_dims)
-
-        def merge_batch_beams(tensor):
-            return layers.reshape(tensor, [tensor.shape[0] * tensor.shape[1]] +
-                                  tensor.shape[2:])
-
-        def split_batch_beams(tensor):
-            return layers.reshape(tensor,
-                                  shape=[-1, beam_size] +
-                                  list(tensor.shape[1:]))
-
-        def mask_probs(probs, finished, noend_mask_tensor):
-            # TODO: use where_op
-            finished = layers.cast(finished, dtype=probs.dtype)
-            probs = layers.elementwise_mul(layers.expand(
-                layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]),
-                                           noend_mask_tensor,
-                                           axis=-1) - layers.elementwise_mul(
-                                               probs, (finished - 1), axis=0)
-            return probs
-
-        def gather(x, indices, batch_pos):
-            topk_coordinates = layers.stack([batch_pos, indices], axis=2)
-            return layers.gather_nd(x, topk_coordinates)
-
-        def update_states(func, caches):
-            for cache in caches:  # no need to update static_kv
-                cache["k"] = func(cache["k"])
-                cache["v"] = func(cache["v"])
-            return caches
-
-        # run encoder
-        enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
-        # constant number
-        inf = float(1. * 1e7)
-        batch_size = enc_output.shape[0]
-        max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
-        vocab_size_tensor = layers.fill_constant(shape=[1],
-                                                 dtype="int64",
-                                                 value=self.trg_vocab_size)
-        end_token_tensor = to_variable(
-            np.full([batch_size, beam_size], eos_id, dtype="int64"))
-        noend_array = [-inf] * self.trg_vocab_size
-        noend_array[eos_id] = 0
-        noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32"))
-        batch_pos = layers.expand(
-            layers.unsqueeze(
-                to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]),
-            [1, beam_size])
-
-        predict_ids = []
-        parent_ids = []
-        ### initialize states of beam search ###
-        log_probs = to_variable(
-            np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size,
-                     dtype="float32"))
-        finished = to_variable(np.full([batch_size, beam_size], 0,
-                                       dtype="bool"))
-        ### initialize inputs and states of transformer decoder ###
-        ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
-        trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
-                                        dtype="int64",
-                                        value=bos_id)
-        trg_pos = layers.zeros_like(trg_word)
-        trg_src_attn_bias = merge_batch_beams(
-            expand_to_beam_size(trg_src_attn_bias, beam_size))
-        enc_output = merge_batch_beams(expand_to_beam_size(enc_output, beam_size))
-        ## init states (caches) for transformer, need to be updated according to selected beam
-        caches = [{
-            "k":
-            layers.fill_constant(
-                shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
-                dtype=enc_output.dtype,
-                value=0),
-            "v":
-            layers.fill_constant(
-                shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
-                dtype=enc_output.dtype,
-                value=0),
-        } for i in range(self.n_layer)]
+class TransfomerCell(object):
+    """
+    Let inputs=(trg_word, trg_pos), states=cache, so that the Transformer
+    decoder can be used as an RNNCell.
+    """
+    def __init__(self, decoder):
+        self.decoder = decoder
+
+    def __call__(self, inputs, states, trg_src_attn_bias, enc_output,
+                 static_caches):
+        trg_word, trg_pos = inputs
+        for cache, static_cache in zip(states, static_caches):
+            cache.update(static_cache)
+        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
+                              enc_output, states)
+        new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
+        return logits, new_states
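+# NOTE: new_states keeps only the growing "k"/"v" entries; the static encoder
+# K/V are re-attached from `static_caches` at each step in __call__ above.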
 
-        for i in range(max_len):
-            trg_pos = layers.fill_constant(shape=trg_word.shape,
-                                           dtype="int64",
-                                           value=i)
-            caches = update_states(  # can not be reshaped since the 0 size
-                lambda x: x if i == 0 else merge_batch_beams(x), caches)
-            logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
-                                  enc_output, caches)
-            caches = update_states(split_batch_beams, caches)
-            step_log_probs = split_batch_beams(
-                layers.log(layers.softmax(logits)))
-            step_log_probs = mask_probs(step_log_probs, finished,
-                                        noend_mask_tensor)
-            log_probs = layers.elementwise_add(x=step_log_probs,
-                                               y=log_probs,
-                                               axis=0)
-            log_probs = layers.reshape(log_probs,
-                                       [-1, beam_size * self.trg_vocab_size])
-            scores = log_probs
-            topk_scores, topk_indices = layers.topk(input=scores, k=beam_size)
-            beam_indices = layers.elementwise_floordiv(
-                topk_indices, vocab_size_tensor)
-            token_indices = layers.elementwise_mod(
-                topk_indices, vocab_size_tensor)
-
-            # update states
-            caches = update_states(lambda x: gather(x, beam_indices, batch_pos),
-                                   caches)
-            log_probs = gather(log_probs, topk_indices, batch_pos)
-            finished = gather(finished, beam_indices, batch_pos)
-            finished = layers.logical_or(
-                finished, layers.equal(token_indices, end_token_tensor))
-            trg_word = layers.reshape(token_indices, [-1, 1])
-
-            predict_ids.append(token_indices)
-            parent_ids.append(beam_indices)
-
-            if layers.reduce_all(finished).numpy():
-                break
-
-        predict_ids = layers.stack(predict_ids, axis=0)
-        parent_ids = layers.stack(parent_ids, axis=0)
-        finished_seq = layers.transpose(
-            layers.gather_tree(predict_ids, parent_ids), [1, 2, 0])
-        finished_scores = topk_scores
-
-        return finished_seq, finished_scores
-
-    def _greedy_search(self,
-                       src_word,
-                       src_pos,
-                       src_slf_attn_bias,
-                       trg_word,
-                       trg_src_attn_bias,
-                       bos_id=0,
-                       eos_id=1,
-                       max_len=256):
-        # run encoder
-        enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
-        # constant number
-        batch_size = enc_output.shape[0]
-        max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
-        end_token_tensor = layers.fill_constant(shape=[batch_size, 1],
-                                                dtype="int64",
-                                                value=eos_id)
-
-        predict_ids = []
-        log_probs = layers.fill_constant(shape=[batch_size, 1],
-                                         dtype="float32",
-                                         value=0)
-        trg_word = layers.fill_constant(shape=[batch_size, 1],
-                                        dtype="int64",
-                                        value=bos_id)
-        finished = layers.fill_constant(shape=[batch_size, 1],
-                                        dtype="bool",
-                                        value=0)
-
-        ## init states (caches) for transformer
+class InferTransformer(Transformer):
+    """
+    Transformer model for prediction.
+    """
+    def __init__(self,
+                 src_vocab_size,
+                 trg_vocab_size,
+                 max_length,
+                 n_layer,
+                 n_head,
+                 d_key,
+                 d_value,
+                 d_model,
+                 d_inner_hid,
+                 prepostprocess_dropout,
+                 attention_dropout,
+                 relu_dropout,
+                 preprocess_cmd,
+                 postprocess_cmd,
+                 weight_sharing,
+                 bos_id=0,
+                 eos_id=1,
+                 beam_size=4,
+                 max_out_len=256):
+        args = locals()
+        args.pop("self")
+        self.beam_size = args.pop("beam_size")
+        self.max_out_len = args.pop("max_out_len")
+        super(InferTransformer, self).__init__(**args)
+        cell = TransfomerCell(self.decoder)
+        self.beam_search_decoder = DynamicDecode(
+            TransformerBeamSearchDecoder(cell,
+                                         bos_id,
+                                         eos_id,
+                                         beam_size,
+                                         var_dim_in_state=2), max_out_len)
+
+    @shape_hints(src_word=[None, None],
+                 src_pos=[None, None],
+                 src_slf_attn_bias=[None, 8, None, None],
+                 trg_src_attn_bias=[None, 8, None, None])
+    def forward(self, src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias):
+        enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
+        ## init states (caches) for transformer, need to be updated according to selected beam
         caches = [{
             "k":
-            layers.fill_constant(
-                shape=[batch_size, self.n_head, 0, self.d_key],
+            layers.fill_constant_batch_size_like(
+                input=enc_output,
+                shape=[-1, self.n_head, 0, self.d_key],
                 dtype=enc_output.dtype,
                 value=0),
             "v":
-            layers.fill_constant(
-                shape=[batch_size, self.n_head, 0, self.d_value],
+            layers.fill_constant_batch_size_like(
+                input=enc_output,
+                shape=[-1, self.n_head, 0, self.d_value],
                 dtype=enc_output.dtype,
                 value=0),
         } for i in range(self.n_layer)]
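+        # NOTE: caches start with a zero-length axis 2 and take their batch
+        # size from enc_output via fill_constant_batch_size_like, so the same
+        # forward works for any batch size.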
-
-        for i in range(max_len):
-            trg_pos = layers.fill_constant(shape=trg_word.shape,
-                                           dtype="int64",
-                                           value=i)
-            logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
-                                  enc_output, caches)
-            step_log_probs = layers.log(layers.softmax(logits))
-            log_probs = layers.elementwise_add(x=step_log_probs,
-                                               y=log_probs,
-                                               axis=0)
-            scores = log_probs
-            topk_scores, topk_indices = layers.topk(input=scores, k=1)
-
-            finished = layers.logical_or(
-                finished, layers.equal(topk_indices, end_token_tensor))
-            trg_word = topk_indices
-            log_probs = topk_scores
-
-            predict_ids.append(topk_indices)
-
-            if layers.reduce_all(finished).numpy():
-                break
-
-        predict_ids = layers.stack(predict_ids, axis=0)
-        finished_seq = layers.transpose(predict_ids, [1, 2, 0])
-        finished_scores = topk_scores
-
-        return finished_seq, finished_scores
+        enc_output = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
+            enc_output, self.beam_size)
+        trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
+            trg_src_attn_bias, self.beam_size)
+        static_caches = self.decoder.decoder.prepare_static_cache(
+            enc_output)
+        rs, _ = self.beam_search_decoder(inits=caches,
+                                         enc_output=enc_output,
+                                         trg_src_attn_bias=trg_src_attn_bias,
+                                         static_caches=static_caches)
+        return rs
diff --git a/transformer_pr.tar.gz b/transformer_pr.tar.gz
new file mode 100644
index 0000000000000000000000000000000000000000..deef2d1b39078b1296793da269cc86fa7f76efbf
Binary files /dev/null and b/transformer_pr.tar.gz differ