From b5ab897940ccd189a5ca60c425a295fb5ba66fe9 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Fri, 9 Dec 2022 17:28:44 +0800 Subject: [PATCH] [remove fluid] Remove fluid APIs (#48641) --- python/paddle/fluid/layers/rnn.py | 692 ------------------ .../fluid/tests/unittests/dist_transformer.py | 155 ---- .../unittests/test_beam_search_decode_op.py | 48 -- .../tests/unittests/test_beam_search_op.py | 117 --- .../tests/unittests/test_rnn_decode_api.py | 30 +- 5 files changed, 11 insertions(+), 1031 deletions(-) diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index 52c0d133f0..31dfc905a6 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -42,18 +42,12 @@ __all__ = [ 'rnn', 'birnn', 'dynamic_decode', - 'DecodeHelper', - 'TrainingHelper', - 'GreedyEmbeddingHelper', - 'SampleEmbeddingHelper', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru', 'gru_unit', 'lstm_unit', 'lstm', - 'beam_search', - 'beam_search_decode', ] @@ -1234,447 +1228,6 @@ def dynamic_decode( ) -class DecodeHelper: - """ - DecodeHelper is the base class for any helper instance used in `BasicDecoder`. - It provides interface to implement sampling and produce inputs for the next - time step in dynamic decoding. - """ - - def initialize(self): - r""" - DecodeHelper initialization to produce inputs for the first decoding step - and give the initial status telling whether each sequence in the batch - is finished. It is the partial of the initialization of `BasicDecoder`. - - Returns: - tuple: A tuple( :code:`(initial_inputs, initial_finished)` ). \ - `initial_inputs` is a (possibly nested structure of) tensor \ - variable[s], and the tensor's shape is `[batch_size, ...]`. \ - `initial_finished` is a bool tensor with shape `[batch_size]`. - """ - pass - - def sample(self, time, outputs, states): - """ - Perform sampling with some strategies according to `outputs`. It is the - partial of `BasicDecoder.step`. - - Parameters: - time(Variable): An `int64` tensor with shape `[1]` provided by the caller, - representing the current time step number of decoding. - outputs(Variable): A tensor variable. Usually it's data type is float32 - or float64, and it's shape is `[batch_size, vocabulary_size]`, - representing the predicted logits of current step. It is same as - `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. - states(Variable): A (possibly nested structure of) tensor variable[s]. - It is same as `new_states` returned by `BasicDecoder.cell.call()`. - - Returns: - Variable: An `int64` tensor representing the sampled ids. - """ - pass - - def next_inputs(self, time, outputs, states, sample_ids): - r""" - Produce the inputs and states for next time step and give status telling - whether each minibatch entry is finished. It is called after `sample` in - `BasicDecoder.step`. It is the partial of `BasicDecoder.step`. - - Parameters: - time(Variable): An `int64` tensor with shape `[1]` provided by the caller, - representing the current time step number of decoding. - outputs(Variable): A tensor variable. Usually it's data type is float32 - or float64, and it's shape is `[batch_size, vocabulary_size]`, - representing the predicted logits of current step. It is same as - `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. - states(Variable): A (possibly nested structure of) tensor variable[s]. - It is same as `new_states` returned by `BasicDecoder.cell.call()`. 
- sample_ids(Variable): A (possibly nested structure of) tensor variable[s]. - It is same as `sample_ids` returned by `sample()`. - - Returns: - tuple: A tuple( :code:`(finished, next_inputs, next_states)` ). \ - `next_inputs` and `next_states` both are a (possibly nested \ - structure of) tensor variable[s], and the structure, shape and \ - data type of `next_states` must be same as the input argument \ - `states`. `finished` is a bool tensor with shape `[batch_size]`. - """ - pass - - -class TrainingHelper(DecodeHelper): - """ - TrainingHelper is a subclass of DecodeHelper. It is a decoding helper - slicing from the full sequence inputs as the inputs for corresponding - step. And it uses `argmax` to sample from the outputs of `cell.call()`. - - Since the needs of sequence inputs, it is used mostly for teach-forcing MLE - (maximum likelihood) training, and the sampled would not be used. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle.fluid.layers as layers - trg_emb = fluid.data(name="trg_emb", - shape=[None, None, 128], - dtype="float32") - trg_seq_length = fluid.data(name="trg_seq_length", - shape=[None], - dtype="int64") - helper = layers.TrainingHelper(trg_emb, trg_seq_length) - decoder_cell = layers.GRUCell(hidden_size=128) - decoder = layers.BasicDecoder(decoder_cell, helper) - outputs = layers.dynamic_decode( - decoder, - inits=decoder_cell.get_initial_states(trg_emb), - is_test=False) - """ - - def __init__(self, inputs, sequence_length, time_major=False): - """ - Constructor of TrainingHelper. - - Parameters: - inputs(Variable): A (possibly nested structure of) tensor variable[s]. - The shape of tensor should be `[batch_size, sequence_length, ...]` - for `time_major == False` or `[sequence_length, batch_size, ...]` - for `time_major == True`. It represents the inputs to be sliced - from at every decoding step. - sequence_length(Variable): A tensor with shape `[batch_size]`. - It stores real length of each instance in `inputs`, by which we - can label the finished status of each instance at every decoding - step. - time_major(bool, optional): Indicate the data layout of Tensor included - in `inputs`. If `False`, the data layout would be batch major with - shape `[batch_size, sequence_length, ...]`. If `True`, the data - layout would be time major with shape `[sequence_length, batch_size, ...]`. - Default: `False`. - """ - self.inputs = inputs - self.sequence_length = sequence_length - self.time_major = time_major - # extend inputs to avoid to slice out of range in `next_inputs` - # may be easier and have better performance than condition_op - self.inputs_ = map_structure( - lambda x: paddle.nn.functional.pad( - x, - pad=([0, 1] + [0, 0] * (len(x.shape) - 1)) - if time_major - else ([0, 0, 0, 1] + [0, 0] * (len(x.shape) - 2)), - ), - self.inputs, - ) - - def initialize(self): - r""" - TrainingHelper initialization produces inputs for the first decoding - step by slicing at the first time step of full sequence inputs, and it - gives initial status telling whether each sequence in the batch is - finished. It is the partial of the initialization of `BasicDecoder`. - - Returns: - tuple: A tuple( :code:`(initial_inputs, initial_finished)` ). \ - `initial_inputs` is a (possibly nested structure of) tensor \ - variable[s], and the tensor's shape is `[batch_size, ...]`. \ - `initial_finished` is a bool tensor with shape `[batch_size]`. 
- """ - init_finished = paddle.equal( - self.sequence_length, - tensor.fill_constant( - shape=[1], dtype=self.sequence_length.dtype, value=0 - ), - ) - # TODO: support zero length - init_inputs = map_structure( - lambda x: x[0] if self.time_major else x[:, 0], self.inputs - ) - return init_inputs, init_finished - - def sample(self, time, outputs, states): - r""" - Perform sampling by using `argmax` according to the `outputs`. Mostly - the sampled ids would not be used since the inputs for next decoding - step would be got by slicing. - - Parameters: - time(Variable): An `int64` tensor with shape `[1]` provided by the - caller, representing the current time step number of decoding. - outputs(Variable): A tensor variable. Usually it's data type is float32 - or float64, and it's shape is `[batch_size, vocabulary_size]`, - representing the predicted logits of current step. It is same as - `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. - states(Variable): A (possibly nested structure of) tensor variable[s]. - It is same as `new_states` returned by `BasicDecoder.cell.call()`. - - Returns: - Variable: An `int64` tensor with shape `[batch_size]`, representing \ - the sampled ids. - """ - sample_ids = tensor.argmax(outputs, axis=-1) - return sample_ids - - def next_inputs(self, time, outputs, states, sample_ids): - r""" - Generate inputs for the next decoding step by slicing at corresponding - step of the full sequence inputs. Simultaneously, produce the states - for next time step by directly using the input `states` and emit status - telling whether each minibatch entry reaches to the corresponding length. - - Parameters: - time(Variable): An `int64` tensor with shape `[1]` provided by the - caller, representing the current time step number of decoding. - outputs(Variable): A tensor variable. Usually it's data type is float32 - or float64, and it's shape is `[batch_size, vocabulary_size]`, - representing the predicted logits of current step. It is same as - `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. - states(Variable): A (possibly nested structure of) tensor variable[s]. - It is same as `new_states` returned by `BasicDecoder.cell.call()`. - sample_ids(Variable): An `int64` tensor variable shaped `[batch_size]`. - It is same as `sample_ids` returned by `sample()`. - - Returns: - tuple: A tuple( :code:`(finished, next_inputs, next_states)` ). \ - `next_inputs` and `next_states` both are a (possibly nested \ - structure of) tensor variable[s], and the tensor's shape is \ - `[batch_size, ...]`. `next_states` is identical to the input \ - argument `states`. `finished` is a `bool` Tensor with \ - shape `[batch_size]`. - """ - # TODO: compatibility of int32 and int64 - time = ( - tensor.cast(time, "int32") - if convert_dtype(time.dtype) not in ["int32"] - else time - ) - if self.sequence_length.dtype != time.dtype: - self.sequence_length = tensor.cast(self.sequence_length, time.dtype) - next_time = time + 1 - finished = paddle.less_equal(self.sequence_length, next_time) - - def _slice(x): # TODO: use Variable.__getitem__ - axes = [0 if self.time_major else 1] - return paddle.squeeze( - paddle.slice( - x, axes=axes, starts=[next_time], ends=[next_time + 1] - ), - axis=axes, - ) - - next_inputs = map_structure(_slice, self.inputs_) - return finished, next_inputs, states - - -class GreedyEmbeddingHelper(DecodeHelper): - """ - GreedyEmbeddingHelper is a subclass of DecodeHelper. 
It is a decoding helper - uses the argmax of the output (treated as logits) and passes the results - through an embedding layer to get inputs for the next decoding step. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle.fluid.layers as layers - trg_emb = fluid.data(name="trg_emb", - shape=[None, None, 128], - dtype="float32") - - trg_embeder = lambda x: fluid.embedding( - x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding")) - output_layer = lambda x: layers.fc(x, - size=10000, - num_flatten_dims=len(x.shape) - 1, - param_attr=fluid.ParamAttr(name= - "output_w"), - bias_attr=False) - helper = layers.GreedyEmbeddingHelper(trg_embeder, start_tokens=0, end_token=1) - decoder_cell = layers.GRUCell(hidden_size=128) - decoder = layers.BasicDecoder(decoder_cell, helper, output_fn=output_layer) - outputs = layers.dynamic_decode( - decoder=decoder, inits=decoder_cell.get_initial_states(encoder_output)) - """ - - def __init__(self, embedding_fn, start_tokens, end_token): - r""" - Constructor of GreedyEmbeddingHelper. - - Parameters: - embedding_fn(callable): A functor to apply on the argmax results. - Mostly it is an embedding layer to transform ids to embeddings. - **Note that fluid.embedding should be used here rather than - fluid.layers.embedding, since shape of ids is [batch_size]. - when using fluid.layers.embedding, must unsqueeze in embedding_fn.** - start_tokens(Variable): A `int64` tensor shaped `[batch_size]`, - representing the start tokens. - end_token(int): The end token id. - - Returns: - tuple: A tuple( :code:`(initial_inputs, initial_states, finished)` ). \ - `initial_inputs` and `initial_states` both are a (possibly nested \ - structure of) tensor variable[s], and `finished` is a tensor with \ - bool data type. - """ - self.embedding_fn = embedding_fn - self.start_tokens = start_tokens - self.end_token = tensor.fill_constant( - shape=[1], dtype="int64", value=end_token - ) - - def initialize(self): - r""" - GreedyEmbeddingHelper initialization produces inputs for the first decoding - step by using `start_tokens` of the constructor, and gives initial - status telling whether each sequence in the batch is finished. - It is the partial of the initialization of `BasicDecoder`. - - Returns: - tuple: A tuple( :code:`(initial_inputs, initial_finished)` ). \ - `initial_inputs` is same as `start_tokens` of the constructor. \ - `initial_finished` is a `bool` tensor filled by False and has \ - the same shape as `start_tokens`. - """ - # TODO: remove the restriction of force_cpu - init_finished = tensor.fill_constant_batch_size_like( - input=self.start_tokens, - shape=[-1], - dtype="bool", - value=False, - force_cpu=True, - ) - init_inputs = self.embedding_fn(self.start_tokens) - return init_inputs, init_finished - - def sample(self, time, outputs, states): - r""" - Perform sampling by using `argmax` according to the `outputs`. - - Parameters: - time(Variable): An `int64` tensor with shape `[1]` provided by the - caller, representing the current time step number of decoding. - outputs(Variable): A tensor variable. Usually it's data type is float32 - or float64, and it's shape is `[batch_size, vocabulary_size]`, - representing the predicted logits of current step. It is same as - `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. - states(Variable): A (possibly nested structure of) tensor variable[s]. - It is same as `new_states` returned by `BasicDecoder.cell.call()`. 
- - Returns: - Variable: An `int64` tensor with shape `[batch_size]`, representing \ - the sampled ids. - """ - sample_ids = tensor.argmax(outputs, axis=-1) - return sample_ids - - def next_inputs(self, time, outputs, states, sample_ids): - r""" - Generate inputs for the next decoding step by applying `embedding_fn` - to `sample_ids`. Simultaneously, produce the states for next time step - by directly using the input `states` and emit status telling whether - each minibatch entry gets an `end_token` sample. - - Parameters: - time(Variable): An `int64` tensor with shape `[1]` provided by the - caller, representing the current time step number of decoding. - outputs(Variable): A tensor variable. Usually it's data type is float32 - or float64, and it's shape is `[batch_size, vocabulary_size]`, - representing the predicted logits of current step. It is same as - `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. - states(Variable): A (possibly nested structure of) tensor variable[s]. - It is same as `new_states` returned by `BasicDecoder.cell.call()`. - sample_ids(Variable): An `int64` tensor variable shaped `[batch_size]`. - It is same as `sample_ids` returned by `sample()`. - - Returns: - tuple: A tuple( :code:`(finished, next_inputs, next_states)` ). \ - `next_inputs` and `next_states` both are a (possibly nested \ - structure of) tensor variable[s], and the tensor's shape is \ - `[batch_size, ...]`. `next_states` is identical to the input \ - argument `states`. `finished` is a `bool` Tensor with \ - shape `[batch_size]`. - """ - finished = paddle.equal(sample_ids, self.end_token) - next_inputs = self.embedding_fn(sample_ids) - return finished, next_inputs, states - - -class SampleEmbeddingHelper(GreedyEmbeddingHelper): - """ - SampleEmbeddingHelper is a subclass of GreedyEmbeddingHelper. It is a decoding - helper uses sampling (from a distribution) instead of argmax of the output - (treated as logits) and passes the results through an embedding layer to get - inputs for the next decoding step. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle.fluid.layers as layers - trg_emb = fluid.data(name="trg_emb", - shape=[None, None, 128], - dtype="float32") - - trg_embeder = lambda x: fluid.embedding( - x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding")) - output_layer = lambda x: layers.fc(x, - size=10000, - num_flatten_dims=len(x.shape) - 1, - param_attr=fluid.ParamAttr(name= - "output_w"), - bias_attr=False) - helper = layers.SampleEmbeddingHelper(trg_embeder, start_tokens=0, end_token=1) - decoder_cell = layers.GRUCell(hidden_size=128) - decoder = layers.BasicDecoder(decoder_cell, helper, output_fn=output_layer) - outputs = layers.dynamic_decode( - decoder=decoder, inits=decoder_cell.get_initial_states(encoder_output)) - """ - - def __init__( - self, - embedding_fn, - start_tokens, - end_token, - softmax_temperature=None, - seed=None, - ): - r""" - Constructor of SampleEmbeddingHelper. - - Parameters: - embedding_fn(callable): A functor to apply on the argmax results. - Mostly it is an embedding layer to transform ids to embeddings. - **Note that fluid.embedding should be used here rather than - fluid.layers.embedding, since shape of ids is [batch_size]. - when using fluid.layers.embedding, must unsqueeze in embedding_fn.** - start_tokens(Variable): A `int64` tensor shaped `[batch_size]`, - representing the start tokens. - end_token(int): The end token id. 
- softmax_temperature(float, optional): the value to divide the logits - by before computing the softmax. Higher temperatures (above 1.0) - lead to more random, while lower temperatures push the sampling - distribution towards the argmax. It must be strictly greater than - 0. Defaults to None, meaning using a temperature valued 1.0. - seed: (int, optional) The sampling seed. Defaults to None, meaning not - to use fixed seed. - - Returns: - tuple: A tuple( :code:`(initial_inputs, initial_states, finished)` ). \ - `initial_inputs` and `initial_states` both are a (possibly nested \ - structure of) tensor variable[s], and `finished` is a tensor with \ - bool data type. - """ - super().__init__(embedding_fn, start_tokens, end_token) - self.softmax_temperature = ( - tensor.fill_constant( - shape=[1], dtype="float32", value=softmax_temperature - ) - if softmax_temperature is not None - else None - ) - self.seed = seed - - def dynamic_lstm( input, size, @@ -2619,251 +2172,6 @@ def gru_unit( return updated_hidden, reset_hidden_pre, gate -def beam_search( - pre_ids, - pre_scores, - ids, - scores, - beam_size, - end_id, - level=0, - is_accumulated=True, - name=None, - return_parent_idx=False, -): - r""" - - Beam search is a classical algorithm for selecting candidate words in a - machine translation task. - - Refer to `Beam search `_ - for more details. - - **This operator only supports LoDTensor.** It is used after finishing - scores calculation to perform beam search for one time step. Specifically, - after ``ids`` and ``scores`` have been produced, it selects the top-K - ( `k` is ``beam_size`` ) candidate word ids of current step from ``ids`` - according to the corresponding ``scores``. Additionally, ``pre_id`` and - ``pre_scores`` are the output of `beam_search` at previous step, they - are needed for special use to handle ended candidate translations. - - Note that if ``is_accumulated`` is True, the ``scores`` passed in should - be accumulated scores. Otherwise, the ``scores`` are - considered as the probabilities of single step and would be transformed to - the log field and added up with ``pre_scores`` for final scores in this - operator. Length penalty should be done with extra operators before calculating - the accumulated scores if needed. - - Please see the following demo for a fully beam search usage example: - - fluid/tests/book/test_machine_translation.py - - Args: - pre_ids(Variable): A LodTensor variable (lod level is 2), representing - the selected ids of previous step. It is the output of beam_search - at previous step. Its shape is `[batch_size, 1]` and its lod is - `[[0, 1, ... , batch_size], [0, 1, ..., batch_size]]` at the - first step. The data type should be int64. - pre_scores(Variable): A LodTensor variable has the same shape and lod - with ``pre_ids`` , representing the accumulated scores corresponding - to the selected ids of previous step. It is the output of - beam_search at previous step. The data type should be float32 or float64. - ids(Variable|None): A LodTensor variable containing the candidates ids. - It has the same lod with ``pre_ids`` and its shape should be - `[batch_size * beam_size, K]`, where `K` supposed to be greater than - ``beam_size`` and the first dimension size (decrease as samples reach - to the end) should be same as that of ``pre_ids`` . The data type - should be int64. It can be None, which use index in ``scores`` as - ids. - scores(Variable): A LodTensor variable containing the accumulated - scores corresponding to ``ids`` . 
Both its shape and lod are same as - those of ``ids`` . The data type should be float32 or float64. - beam_size(int): The beam width used in beam search. - end_id(int): The id of end token. - level(int): **It can be ignored and mustn't change currently.** - The 2 level lod used in this operator has the following - meaning: The first level describes how many beams each sample has, - which would change to 0 when beams of the sample all end (batch reduce); - The second level describes how many times each beam is selected. - Default 0, which shouldn't be changed currently. - is_accumulated(bool): Whether the input ``score`` is accumulated scores. - Default True. - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - return_parent_idx(bool, optional): Whether to return an extra Tensor variable - in output, which stores the selected ids' parent index in - ``pre_ids`` and can be used to update RNN's states by gather operator. - Default False. - - Returns: - tuple: The tuple contains two or three LodTensor variables. The two LodTensor, \ - representing the selected ids and the corresponding accumulated scores of \ - current step, have the same shape `[batch_size, beam_size]` and lod with 2 levels, \ - and have data types int64 and float32. If ``return_parent_idx`` is True, \ - an extra Tensor variable preserving the selected ids' parent index \ - is included, whose shape is `[batch_size * beam_size]` and data type \ - is int64. - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - - # Suppose `probs` contains predicted results from the computation - # cell and `pre_ids` and `pre_scores` is the output of beam_search - # at previous step. - beam_size = 4 - end_id = 1 - pre_ids = fluid.data( - name='pre_id', shape=[None, 1], lod_level=2, dtype='int64') - pre_scores = fluid.data( - name='pre_scores', shape=[None, 1], lod_level=2, dtype='float32') - probs = fluid.data( - name='probs', shape=[None, 10000], dtype='float32') - topk_scores, topk_indices = fluid.layers.topk(probs, k=beam_size) - accu_scores = fluid.layers.elementwise_add( - x=paddle.log(x=topk_scores), - y=paddle.reshape(pre_scores, shape=[-1]), - axis=0) - selected_ids, selected_scores = fluid.layers.beam_search( - pre_ids=pre_ids, - pre_scores=pre_scores, - ids=topk_indices, - scores=accu_scores, - beam_size=beam_size, - end_id=end_id) - """ - check_variable_and_dtype(pre_ids, 'pre_ids', ['int64'], 'beam_search') - check_variable_and_dtype( - pre_scores, 'pre_scores', ['float32', 'float64'], 'beam_search' - ) - check_type(ids, 'ids', (Variable, type(None)), 'beam_search') - check_variable_and_dtype( - scores, 'scores', ['float32', 'float64'], 'beam_search' - ) - helper = LayerHelper('beam_search', **locals()) - score_type = pre_scores.dtype - id_type = pre_ids.dtype - - inputs = {"pre_ids": pre_ids, "pre_scores": pre_scores, "scores": scores} - if ids is not None: - inputs["ids"] = ids - - selected_scores = helper.create_variable_for_type_inference( - dtype=score_type - ) - selected_ids = helper.create_variable_for_type_inference(dtype=id_type) - # parent_idx is a tensor used to gather cell states at the next time - # step. Though lod in selected_ids can also be used to gather by - # sequence_expand, it is not efficient. 
-    # gather_op's index input only supports int32 dtype currently
-    parent_idx = helper.create_variable_for_type_inference(dtype="int32")
-
-    helper.append_op(
-        type='beam_search',
-        inputs=inputs,
-        outputs={
-            'selected_ids': selected_ids,
-            'selected_scores': selected_scores,
-            'parent_idx': parent_idx,
-        },
-        attrs={
-            # TODO(ChunweiYan) to assure other value support
-            'level': level,
-            'beam_size': beam_size,
-            'end_id': end_id,
-            'is_accumulated': is_accumulated,
-        },
-    )
-    if return_parent_idx:
-        return selected_ids, selected_scores, parent_idx
-    else:
-        return selected_ids, selected_scores
-
-
-def beam_search_decode(ids, scores, beam_size, end_id, name=None):
-    r"""
-
-    This operator is used after beam search has completed. It constructs the
-    full predicted sequences for each sample by walking back along the search
-    paths stored in lod of ``ids`` . The result sequences are stored in a
-    LoDTensor, which uses the following way to parse:
-
-    .. code-block:: text
-
-        If lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
-
-        The first level of lod stands for: There are 2 samples each having 3
-        (beam width) predicted sequence.
-
-        The second level of lod stands for: The lengths of the first sample's
-        3 predicted sequences are 12, 12, 16; The lengths of the second sample's
-        3 predicted sequences are 14, 13, 15.
-
-
-    Please see the following demo for a fully beam search usage example:
-        fluid/tests/book/test_machine_translation.py
-
-    Args:
-        ids(Variable): The LoDTensorArray variable containing the selected ids
-            of all steps. Each LoDTensor in it has int64 data type and 2 level
-            lod which can be used to get the search paths.
-        scores(Variable): The LodTensorArray variable containing the accumulated
-            scores corresponding to selected ids of all steps. It has the same size
-            as ``ids`` . Each LoDTensor in it has the same shape and lod as the
-            counterpart in ``ids`` , and has a float32 data type.
-        beam_size(int): The beam width used in beam search.
-        end_id(int): The id of end token.
-        name(str, optional): For detailed information, please refer
-            to :ref:`api_guide_Name`. Usually name is no need to set and
-            None by default.
-
-    Returns:
-        tuple: The tuple contains two LodTensor variables. The two LodTensor, \
-            containing the full sequences of ids and the corresponding accumulated \
-            scores, have the same shape flattened to 1D and have the same 2 level \
-            lod. The lod can be used to get how many predicted sequences each sample \
-            has and how many ids each predicted sequence has.
-
-    Examples:
-        ..
code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - # Suppose `ids` and `scores` are LodTensorArray variables reserving - # the selected ids and scores of all steps - ids = fluid.layers.create_array(dtype='int64') - scores = fluid.layers.create_array(dtype='float32') - finished_ids, finished_scores = fluid.layers.beam_search_decode( - ids, scores, beam_size=5, end_id=0) - """ - check_variable_and_dtype(ids, 'ids', ['int64'], 'beam_search_encode') - check_variable_and_dtype( - scores, 'scores', ['float32'], 'beam_search_encode' - ) - helper = LayerHelper('beam_search_decode', **locals()) - sentence_ids = helper.create_variable_for_type_inference(dtype=ids.dtype) - sentence_scores = helper.create_variable_for_type_inference( - dtype=scores.dtype - ) - - helper.append_op( - type="beam_search_decode", - inputs={"Ids": ids, "Scores": scores}, - outputs={ - "SentenceIds": sentence_ids, - "SentenceScores": sentence_scores, - }, - attrs={"beam_size": beam_size, "end_id": end_id}, - ) - - return sentence_ids, sentence_scores - - def lstm_unit( x_t, hidden_t_prev, diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index d8548cb324..e036692de4 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -1719,161 +1719,6 @@ def wrap_decoder( return predict -def fast_decode( - src_vocab_size, - trg_vocab_size, - max_in_len, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - dropout_rate, - weight_sharing, - beam_size, - max_out_len, - eos_idx, -): - """ - Use beam search to decode. Caches will be used to store states of history - steps which can make the decoding faster. - """ - enc_output = wrap_encoder( - src_vocab_size, - max_in_len, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - dropout_rate, - weight_sharing, - ) - start_tokens, init_scores, trg_src_attn_bias = make_all_inputs( - fast_decoder_data_input_fields - ) - - def beam_search(): - max_len = layers.fill_constant( - shape=[1], dtype=start_tokens.dtype, value=max_out_len - ) - step_idx = layers.fill_constant( - shape=[1], dtype=start_tokens.dtype, value=0 - ) - cond = paddle.less_than(x=step_idx, y=max_len) - while_op = paddle.static.nn.control_flow.While(cond) - # array states will be stored for each step. - ids = layers.array_write( - paddle.reshape(start_tokens, (-1, 1)), step_idx - ) - scores = layers.array_write(init_scores, step_idx) - # cell states will be overwrited at each step. - # caches contains states of history steps to reduce redundant - # computation in decoder. - caches = [ - { - "k": layers.fill_constant_batch_size_like( - input=start_tokens, - shape=[-1, 0, d_model], - dtype=enc_output.dtype, - value=0, - ), - "v": layers.fill_constant_batch_size_like( - input=start_tokens, - shape=[-1, 0, d_model], - dtype=enc_output.dtype, - value=0, - ), - } - for i in range(n_layer) - ] - with while_op.block(): - pre_ids = layers.array_read(array=ids, i=step_idx) - pre_ids = paddle.reshape(pre_ids, (-1, 1, 1)) - pre_scores = layers.array_read(array=scores, i=step_idx) - # sequence_expand can gather sequences according to lod thus can be - # used in beam search to sift states corresponding to selected ids. 
- pre_src_attn_bias = layers.sequence_expand( - x=trg_src_attn_bias, y=pre_scores - ) - pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores) - pre_caches = [ - { - "k": layers.sequence_expand(x=cache["k"], y=pre_scores), - "v": layers.sequence_expand(x=cache["v"], y=pre_scores), - } - for cache in caches - ] - pre_pos = layers.elementwise_mul( - x=layers.fill_constant_batch_size_like( - input=pre_enc_output, # can't use pre_ids here since it has lod - value=1, - shape=[-1, 1, 1], - dtype=pre_ids.dtype, - ), - y=layers.increment(x=step_idx, value=1.0, in_place=False), - axis=0, - ) - logits = wrap_decoder( - trg_vocab_size, - max_in_len, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - dropout_rate, - weight_sharing, - dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias), - enc_output=pre_enc_output, - caches=pre_caches, - ) - logits = paddle.reshape(logits, (-1, trg_vocab_size)) - - topk_scores, topk_indices = paddle.topk( - x=paddle.nn.functional.softmax(logits), k=beam_size - ) - accu_scores = layers.elementwise_add( - x=paddle.log(topk_scores), - y=paddle.reshape(pre_scores, shape=[-1]), - axis=0, - ) - # beam_search op uses lod to distinguish branches. - topk_indices = layers.lod_reset(topk_indices, pre_ids) - selected_ids, selected_scores = layers.beam_search( - pre_ids=pre_ids, - pre_scores=pre_scores, - ids=topk_indices, - scores=accu_scores, - beam_size=beam_size, - end_id=eos_idx, - ) - - layers.increment(x=step_idx, value=1.0, in_place=True) - # update states - layers.array_write(selected_ids, i=step_idx, array=ids) - layers.array_write(selected_scores, i=step_idx, array=scores) - layers.assign(pre_src_attn_bias, trg_src_attn_bias) - layers.assign(pre_enc_output, enc_output) - for i in range(n_layer): - layers.assign(pre_caches[i]["k"], caches[i]["k"]) - layers.assign(pre_caches[i]["v"], caches[i]["v"]) - length_cond = paddle.less_than(x=step_idx, y=max_len) - finish_cond = paddle.logical_not(layers.is_empty(x=selected_ids)) - paddle.logical_and(x=length_cond, y=finish_cond, out=cond) - - finished_ids, finished_scores = layers.beam_search_decode( - ids, scores, beam_size=beam_size, end_id=eos_idx - ) - return finished_ids, finished_scores - - finished_ids, finished_scores = beam_search() - return finished_ids, finished_scores - - def get_model(is_dist, is_async): sum_cost, avg_cost, predict, token_num = transformer( ModelHyperParams.src_vocab_size, diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py index 062c00a03b..6fa06165e2 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_decode_op.py @@ -16,10 +16,7 @@ import unittest import numpy as np -import paddle -import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid.framework import Program, program_guard from paddle.fluid.op import Operator @@ -118,50 +115,5 @@ class TestBeamSearchDecodeOpGPU(TestBeamSearchDecodeOp): self.place = core.CUDAPlace(0) -class TestBeamSearchDecodeOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - - def test_id_Variable(): - # the input pre_ids must be Variable - test_ids = np.random.randint(1, 5, [5, 1]).astype("int64") - scores = paddle.tensor.create_array(dtype='float32') - fluid.layers.beam_search_decode( - test_ids, scores, beam_size=5, end_id=0 - ) - - self.assertRaises(TypeError, test_id_Variable) - - def 
test_score_Variable(): - # the input pre_scores must be Variable - ids = paddle.tensor.create_array(dtype='int64') - test_scores = np.random.uniform(1, 5, [5, 1]).astype("float32") - fluid.layers.beam_search_decode( - ids, test_scores, beam_size=5, end_id=0 - ) - - self.assertRaises(TypeError, test_score_Variable) - - def test_id_dtype(): - # the dtype of input pre_ids must be int64 - type_ids = paddle.tensor.create_array(dtype='float32') - scores = paddle.tensor.create_array(dtype='float32') - fluid.layers.beam_search_decode( - type_ids, scores, beam_size=5, end_id=0 - ) - - self.assertRaises(TypeError, test_id_dtype) - - def test_score_dtype(): - # the dtype of input pre_scores must be float32 - ids = paddle.tensor.create_array(dtype='int64') - type_scores = paddle.tensor.create_array(dtype='int64') - fluid.layers.beam_search_decode( - ids, type_scores, beam_size=5, end_id=0 - ) - - self.assertRaises(TypeError, test_score_dtype) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_op.py index d492560a50..b10b90bcdd 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_op.py @@ -16,10 +16,7 @@ import unittest import numpy as np -import paddle -import paddle.fluid as fluid import paddle.fluid.core as core -from paddle.fluid.framework import Program, program_guard from paddle.fluid.op import Operator @@ -302,119 +299,5 @@ class BeamSearchOpTester6(BeamSearchOpTester): self.output_parent_idx = np.array([0, 1, 2, 3]) -class TestBeamSearchOpError(unittest.TestCase): - def test_errors(self): - with program_guard(Program(), Program()): - pre_ids = fluid.data( - name='pre_id', shape=[1], lod_level=2, dtype='int64' - ) - pre_scores = fluid.data( - name='pre_scores', shape=[1], lod_level=2, dtype='float32' - ) - probs = fluid.data(name='probs', shape=[10000], dtype='float32') - topk_scores, topk_indices = paddle.topk(probs, k=4) - accu_scores = fluid.layers.elementwise_add( - x=paddle.log(x=topk_scores), - y=paddle.reshape(pre_scores, shape=[-1]), - axis=0, - ) - - def test_preids_Variable(): - # the input pre_ids must be Variable - preids_data = np.random.randint(1, 5, [5, 1]).astype("int64") - fluid.layers.beam_search( - pre_ids=preids_data, - pre_scores=pre_scores, - ids=topk_indices, - scores=accu_scores, - beam_size=4, - end_id=1, - ) - - self.assertRaises(TypeError, test_preids_Variable) - - def test_prescores_Variable(): - # the input pre_scores must be Variable - prescores_data = np.random.uniform(1, 5, [5, 1]).astype( - "float32" - ) - fluid.layers.beam_search( - pre_ids=pre_ids, - pre_scores=prescores_data, - ids=topk_indices, - scores=accu_scores, - beam_size=4, - end_id=1, - ) - - self.assertRaises(TypeError, test_prescores_Variable) - - def test_ids_Variable(): - # the input ids must be Variable or None - ids_data = np.random.randint(1, 5, [5, 1]).astype("int64") - fluid.layers.beam_search( - pre_ids=pre_ids, - pre_scores=pre_scores, - ids=ids_data, - scores=accu_scores, - beam_size=4, - end_id=1, - ) - - self.assertRaises(TypeError, test_ids_Variable) - - def test_scores_Variable(): - # the input scores must be Variable - scores_data = np.random.uniform(1, 5, [5, 1]).astype("float32") - fluid.layers.beam_search( - pre_ids=pre_ids, - pre_scores=pre_scores, - ids=topk_indices, - scores=scores_data, - beam_size=4, - end_id=1, - ) - - self.assertRaises(TypeError, test_scores_Variable) - - def 
test_preids_dtype():
-                # the dtype of input pre_ids must be int64
-                preids_type_data = fluid.data(
-                    name='preids_type_data',
-                    shape=[1],
-                    lod_level=2,
-                    dtype='float32',
-                )
-                fluid.layers.beam_search(
-                    pre_ids=preids_type_data,
-                    pre_scores=pre_scores,
-                    ids=topk_indices,
-                    scores=accu_scores,
-                    beam_size=4,
-                    end_id=1,
-                )
-
-            self.assertRaises(TypeError, test_preids_dtype)
-
-            def test_prescores_dtype():
-                # the dtype of input pre_scores must be float32
-                prescores_type_data = fluid.data(
-                    name='prescores_type_data',
-                    shape=[1],
-                    lod_level=2,
-                    dtype='int64',
-                )
-                fluid.layers.beam_search(
-                    pre_ids=pre_ids,
-                    pre_scores=prescores_type_data,
-                    ids=topk_indices,
-                    scores=accu_scores,
-                    beam_size=4,
-                    end_id=1,
-                )
-
-            self.assertRaises(TypeError, test_prescores_dtype)
-
-
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py
index cddc44bbf7..4677baf27f 100644
--- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py
+++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py
@@ -141,25 +141,17 @@ class Decoder:
         **kwargs
     ):
         output_layer = kwargs.pop("output_layer", None)
-        if self.decoding_strategy == "train_greedy":
-            # for teach-forcing MLE pre-training
-            helper = layers.TrainingHelper(**kwargs)
-        elif self.decoding_strategy == "infer_sample":
-            helper = layers.SampleEmbeddingHelper(**kwargs)
-        elif self.decoding_strategy == "infer_greedy":
-            helper = layers.GreedyEmbeddingHelper(**kwargs)
-
-        if self.decoding_strategy == "beam_search":
-            beam_size = kwargs.get("beam_size", 4)
-            encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch(
-                encoder_output, beam_size
-            )
-            encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch(
-                encoder_padding_mask, beam_size
-            )
-            decoder = BeamSearchDecoder(
-                cell=self.decoder_cell, output_fn=output_layer, **kwargs
-            )
+
+        beam_size = kwargs.get("beam_size", 4)
+        encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch(
+            encoder_output, beam_size
+        )
+        encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch(
+            encoder_padding_mask, beam_size
+        )
+        decoder = BeamSearchDecoder(
+            cell=self.decoder_cell, output_fn=output_layer, **kwargs
+        )

         (
             decoder_output,
--
GitLab
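
Note (not part of the patch, ignored by `git am` since it follows the signature line): the fluid APIs deleted above (`TrainingHelper`, `GreedyEmbeddingHelper`, `SampleEmbeddingHelper`, `beam_search`, `beam_search_decode`) correspond to the decoding path that the updated `test_rnn_decode_api.py` now exercises directly through `paddle.nn.BeamSearchDecoder` driven by `paddle.nn.dynamic_decode`. Below is a minimal sketch of that replacement usage; the vocabulary/hidden sizes, the random "encoder output", and all variable names are illustrative assumptions, not values taken from this patch.

import paddle
import paddle.nn as nn

# Illustrative sizes only (assumptions, not from the patch).
vocab_size, hidden_size, beam_size = 10000, 128, 4

trg_embedder = nn.Embedding(vocab_size, hidden_size)
output_layer = nn.Linear(hidden_size, vocab_size)
decoder_cell = nn.GRUCell(input_size=hidden_size, hidden_size=hidden_size)

# BeamSearchDecoder bundles the cell with embedding_fn/output_fn, covering the
# roles of the removed embedding helpers plus the beam_search/beam_search_decode ops.
decoder = nn.BeamSearchDecoder(
    decoder_cell,
    start_token=0,
    end_token=1,
    beam_size=beam_size,
    embedding_fn=trg_embedder,
    output_fn=output_layer,
)

# Placeholder "encoder output" used only to derive initial decoder states.
encoder_output = paddle.rand([4, 8, hidden_size])  # [batch, src_len, hidden]

# dynamic_decode runs the beam search loop; `outputs` holds the decoded results
# and `final_states` the decoder states after the last step.
outputs, final_states = nn.dynamic_decode(
    decoder,
    inits=decoder_cell.get_initial_states(encoder_output),
    max_step_num=10,
)

This mirrors the decoder construction kept in `test_rnn_decode_api.py` and the documented `paddle.nn.dynamic_decode` usage.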