From 31b546463234a43a2e71554386ae41f5fc8afbfd Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Wed, 12 Feb 2020 23:02:46 +0800 Subject: [PATCH] Add support for dynamic_decode(while) training. (#22231) * Add support for dynamic_decode(while) training. test=develop * Fix assign_op and tensor_array_read_write_op after solving conflict. test=develop * Fix test_rnn_decode_api.py. test=develop * Refine docs for apis in rnn.py. test=develop * Adjust outputs of dynamic_decode. test=develop * Remove the force_cpu update in assign_op. test=develop * Remove the force_cpu update in assign_op. test=develop * Make RNNCell.get_initial_states support batch_dim_idx argument. test=develop * Rename _create_array_outof_while as _create_array_out_of_while in rnn.py. test=develop --- .../controlflow/tensor_array_read_write_op.cc | 22 + python/paddle/fluid/layers/rnn.py | 786 ++++++++++++++++-- python/paddle/fluid/layers/tensor.py | 2 + .../tests/unittests/test_rnn_decode_api.py | 570 ++++++++++--- 4 files changed, 1208 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc index 1afc8f6d73..537b811230 100644 --- a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc +++ b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/array_operator.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/math/math_function.h" + namespace paddle { namespace operators { @@ -152,6 +154,21 @@ class ReadFromArrayOp : public ArrayOp { out_tensor->set_lod(x_array[offset].lod()); } else { VLOG(10) << "offset " << offset << " >= " << x_array.size(); + // set grad of the writed tensor to 0 when used as write_to_array_grad + auto *fw_var = scope.FindVar(Input("X_W")); + if (fw_var == nullptr) return; + auto &fw_var_tensor = fw_var->Get(); + + framework::AttributeMap attrs; + attrs["dtype"] = fw_var_tensor.type(); + attrs["shape"] = framework::vectorize(fw_var_tensor.dims()); + attrs["value"] = 0.0f; + + auto zero_op = framework::OpRegistry::CreateOp( + "fill_constant", {}, {{"Out", {Output("Out")}}}, attrs); + zero_op->Run(scope, place); + auto *out_tensor = out->GetMutable(); + out_tensor->set_lod(fw_var_tensor.lod()); } } }; @@ -163,6 +180,10 @@ class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker { AddInput("I", "(Tensor) the subscript index in tensor array. The number of " "element should be 1"); + AddInput("X_W", + "(Tensor) the writed tensor when used as the grad op of " + "write_to_array. We use this to fill zero gradient.") + .AsDispensable(); AddOutput("Out", "(LoDTensor) the tensor will be read from."); AddComment(R"DOC( ReadFromArray Operator. @@ -199,6 +220,7 @@ class WriteToArrayGradMaker : public framework::SingleGradOpMaker { grad_op->SetType("read_from_array"); grad_op->SetInput("I", this->Input("I")); grad_op->SetInput("X", this->OutputGrad("Out")); + grad_op->SetInput("X_W", this->Input("X")); grad_op->SetOutput("Out", this->InputGrad("X")); grad_op->SetAttrMap(this->Attrs()); return std::unique_ptr(grad_op); diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index 1055703232..43409c80f3 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -14,6 +14,7 @@ from __future__ import print_function +import sys from functools import partial, reduce from . import nn @@ -22,6 +23,8 @@ from . import control_flow from . import utils from . import sequence_lod from .utils import * +from ..framework import default_main_program +from ..data_feeder import convert_dtype from ..layer_helper import LayerHelper from ..framework import in_dygraph_mode from ..param_attr import ParamAttr @@ -34,6 +37,11 @@ __all__ = [ 'BeamSearchDecoder', 'rnn', 'dynamic_decode', + 'DecodeHelper', + 'TrainingHelper', + 'GreedyEmbeddingHelper', + 'SampleEmbeddingHelper', + 'BasicDecoder', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru', @@ -81,7 +89,8 @@ class RNNCell(object): batch_ref, shape=None, dtype=None, - init_value=0): + init_value=0, + batch_dim_idx=0): """ Generate initialized states according to provided shape, data type and value. @@ -100,6 +109,8 @@ class RNNCell(object): property `cell.state_shape` is not available, float32 will be used as the data type. The default value is None. init_value: A float value used to initialize states. + batch_dim_idx: An integer indicating which dimension of the tensor in + inputs represents batch size. The default value is 0. Returns: Variable: tensor variable[s] packed in the same structure provided \ @@ -109,10 +120,16 @@ class RNNCell(object): batch_ref = flatten(batch_ref)[0] def _is_shape_sequence(seq): + if sys.version_info < (3, ): + integer_types = ( + int, + long, ) + else: + integer_types = (int, ) """For shape, list/tuple of integer is the finest-grained objection""" if (isinstance(seq, list) or isinstance(seq, tuple)): - if reduce(lambda flag, x: isinstance(x, int) and flag, seq, - True): + if reduce(lambda flag, x: isinstance(x, integer_types) and flag, + seq, True): return False # TODO: Add check for the illegal if isinstance(seq, dict): @@ -145,12 +162,14 @@ class RNNCell(object): input=batch_ref, shape=shape.shape, dtype=dtype, - value=init_value), states_shapes, states_dtypes) + value=init_value, + input_dim_idx=batch_dim_idx), states_shapes, states_dtypes) return init_states @property def state_shape(self): """ + Abstract method (property). Used to initialize states. A (possiblely nested structure of) shape[s], where a shape is represented as a list/tuple of integers (-1 for batch size would be automatically @@ -159,11 +178,13 @@ class RNNCell(object): `get_initial_states` or the `shape` argument is provided when using `get_initial_states`. """ - raise NotImplementedError + raise NotImplementedError( + "Please add implementaion for `state_shape` in the used cell.") @property def state_dtype(self): """ + Abstract method (property). Used to initialize states. A (possiblely nested structure of) data types[s]. The structure must be same as that of `shape`, except when all tensors' in states has the same @@ -172,7 +193,8 @@ class RNNCell(object): by `get_initial_states` or the `dtype` argument is provided when using `get_initial_states`. """ - raise NotImplementedError + raise NotImplementedError( + "Please add implementaion for `state_dtype` in the used cell.") class GRUCell(RNNCell): @@ -437,7 +459,8 @@ def rnn(cell, return x if initial_states is None: - initial_states = cell.get_initial_states(batch_ref=inputs) + initial_states = cell.get_initial_states( + batch_ref=inputs, batch_dim_idx=1 if time_major else 0) initial_states = map_structure(_switch_grad, initial_states) if not time_major: @@ -499,7 +522,7 @@ class Decoder(object): 1. :code:`(initial_input, initial_state, finished) = initialize(inits)` , which generates the input and state for the first decoding step, and gives the - inintial status telling whether each sequence in the batch is finished. + initial status telling whether each sequence in the batch is finished. It would be called once before the decoding iterations. 2. :code:`(output, next_state, next_input, finished) = step(time, input, state)` , @@ -528,14 +551,14 @@ class Decoder(object): inits: Argument provided by the caller. Returns: - tuple: A tuple( :code:(initial_inputs, initial_states, finished)` ). \ + tuple: A tuple( :code:`(initial_inputs, initial_states, finished)` ). \ `initial_inputs` and `initial_states` both are a (possibly nested \ structure of) tensor variable[s], and `finished` is a tensor with \ bool data type. """ raise NotImplementedError - def step(self, time, inputs, states): + def step(self, time, inputs, states, **kwargs): """ Called per step of decoding. @@ -544,6 +567,7 @@ class Decoder(object): The data type is int64. inputs(Variable): A (possibly nested structure of) tensor variable[s]. states(Variable): A (possibly nested structure of) tensor variable[s]. + **kwargs: Additional keyword arguments, provided by the caller. Returns: tuple: A tuple( :code:(outputs, next_states, next_inputs, finished)` ). \ @@ -555,14 +579,6 @@ class Decoder(object): """ raise NotImplementedError - @property - def output_dtype(self): - """ - A (possiblely nested structure of) data type[s]. The structure must be - same as `outputs` returned by `decoder.step`. - """ - raise NotImplementedError - def finalize(self, outputs, final_states, sequence_lengths): """ Called once after the decoding iterations if implemented. @@ -796,12 +812,14 @@ class BeamSearchDecoder(Decoder): batch_size = tensor.cast( batch_size, indices.dtype) if batch_size.dtype != indices.dtype else batch_size + batch_size.stop_gradient = True # TODO: remove this batch_pos = nn.expand( nn.unsqueeze( tensor.range( 0, batch_size, 1, dtype=indices.dtype), [1]), [1, self.beam_size]) topk_coordinates = nn.stack([batch_pos, indices], axis=2) + topk_coordinates.stop_gradient = True return nn.gather_nd(x, topk_coordinates) class OutputWrapper( @@ -834,7 +852,7 @@ class BeamSearchDecoder(Decoder): Returns: tuple: A tuple( :code:`(initial_inputs, initial_states, finished)` ). \ `initial_inputs` is a tensor t filled by `start_token` with shape \ - `[batch_size, beam_size, 1]` when `embedding_fn` is None, or the \ + `[batch_size, beam_size]` when `embedding_fn` is None, or the \ returned value of `embedding_fn(t)` when `embedding_fn` is provided. \ `initial_states` is a nested structure(namedtuple including cell_states, \ log_probs, finished, lengths as fields) of tensor variables, where \ @@ -895,7 +913,7 @@ class BeamSearchDecoder(Decoder): beam_state(Variable): A structure of tensor variables. It is same as the `initial_states` returned by `initialize()` for the first decoding step and `beam_search_state` returned by - `initialize()` for the others. + `step()` for the others. Returns: tuple: A tuple( :code:`(beam_search_output, beam_search_state)` ). \ @@ -921,6 +939,7 @@ class BeamSearchDecoder(Decoder): # TODO: length penalty scores = log_probs scores = nn.reshape(scores, [-1, self.beam_size * self.vocab_size]) + # TODO: add grad for topk then this beam search can be used to train topk_scores, topk_indices = nn.topk(input=scores, k=self.beam_size) beam_indices = nn.elementwise_floordiv(topk_indices, self.vocab_size_tensor) @@ -993,6 +1012,7 @@ class BeamSearchDecoder(Decoder): beam_state=states) finished = beam_search_state.finished sample_ids = beam_search_output.predicted_ids + sample_ids.stop_gradient = True next_inputs = self.embedding_fn( sample_ids) if self.embedding_fn else sample_ids @@ -1027,20 +1047,14 @@ class BeamSearchDecoder(Decoder): # TODO: use FinalBeamSearchDecoderOutput as output return predicted_ids, final_states - @property - def output_dtype(self): - """ - The nested structure of data types for beam search output. It is a namedtuple - including scores, predicted_ids, parent_ids as fields. - """ - return self.OutputWrapper( - scores="float32", predicted_ids="int64", parent_ids="int64") - def dynamic_decode(decoder, inits=None, max_step_num=None, output_time_major=False, + impute_finished=False, + is_test=False, + return_length=False, **kwargs): """ Dynamic decoding performs :code:`decoder.step()` repeatedly until the returned @@ -1058,24 +1072,40 @@ def dynamic_decode(decoder, max_step_num(int, optional): The maximum number of steps. If not provided, decode until the decoder is fully done, or in other words, the returned Tensor by :code:`decoder.step()` indicating finished status contains - all True). Default `None`. + all True. Default `None`. output_time_major(bool, optional): Indicate the data layout of Tensor included in the final outpus(the first returned value of this method). If attr:`False`, the data layout would be batch major with shape `[batch_size, seq_len, ...]`. If attr:`True`, the data layout would be time major with shape `[seq_len, batch_size, ...]`. Default: `False`. + impute_finished(bool, optional): If `True`, then states get copied through + for batch entries which are marked as finished, which differs with the + unfinished using the new states returned by :code:`decoder.step()` and + ensures that the final states have the correct values. Otherwise, states + wouldn't be copied through when finished. If the returned `final_states` + is needed, it should be set as True, which causes some slowdown. + Default `False`. + is_test(bool, optional): A flag indicating whether to use test mode. In + test mode, it is more memory saving. Default `False`. + return_length(bool, optional): A flag indicating whether to return an + extra Tensor variable in the output tuple, which stores the actual + lengths of all decoded sequences. Default `False`. **kwargs: Additional keyword arguments. Arguments passed to `decoder.step`. Returns: - tuple: A tuple( :code:`(final_outputs, final_states)` ) including the final \ - outputs and states, both are Tensor or nested structure of Tensor. \ - `final_outputs` has the same structure and data types as \ - :code:`decoder.output_dtype` , and each Tenser in `final_outputs` \ + tuple: A tuple( :code:`(final_outputs, final_states, sequence_lengths)` ) \ + when `return_length` is True, otherwise a tuple( :code:`(final_outputs, final_states)` ). \ + The final outputs and states, both are Tensor or nested structure of Tensor. \ + `final_outputs` has the same structure and data types as the :code:`outputs` \ + returned by :code:`decoder.step()` , and each Tenser in `final_outputs` \ is the stacked of all decoding steps' outputs, which might be revised \ - by :code:`decoder.finalize` . `final_states` is the counterpart \ - at last time step of initial states returned by :code:`decoder.initialize` , \ - thus has the same structure with it and has tensors with same shapes \ - and data types. + by :code:`decoder.finalize()` if the decoder has implemented `finalize`. \ + `final_states` is the counterpart at last time step of initial states \ + returned by :code:`decoder.initialize()` , thus has the same structure \ + with it and has tensors with same shapes and data types. `sequence_lengths` \ + is an `int64` tensor with the same shape as `finished` returned \ + by :code:`decoder.initialize()` , and it stores the actual lengths of \ + all decoded sequences. Examples: @@ -1111,70 +1141,720 @@ def dynamic_decode(decoder, initial_inputs, initial_states, initial_finished = decoder.initialize(inits) global_inputs, global_states, global_finished = ( initial_inputs, initial_states, initial_finished) - + global_finished.stop_gradient = True step_idx = tensor.fill_constant(shape=[1], dtype="int64", value=0) + cond = control_flow.logical_not((nn.reduce_all(initial_finished))) if max_step_num is not None: max_step_num = tensor.fill_constant( shape=[1], dtype="int64", value=max_step_num) - while_op = control_flow.While(cond) + while_op = control_flow.While(cond, is_test=is_test) - inputs = map_structure(lambda x: x, initial_inputs) - states = map_structure(lambda x: x, initial_states) - outputs_arrays = map_structure( - lambda dtype: control_flow.create_array(dtype), decoder.output_dtype) sequence_lengths = tensor.cast(tensor.zeros_like(initial_finished), "int64") + sequence_lengths.stop_gradient = True + + if is_test: + # for test, reuse inputs and states variables to save memory + inputs = map_structure(lambda x: x, initial_inputs) + states = map_structure(lambda x: x, initial_states) + else: + # inputs and states of all steps must be saved for backward and training + inputs_arrays = map_structure( + lambda x: control_flow.array_write(x, step_idx), initial_inputs) + states_arrays = map_structure( + lambda x: control_flow.array_write(x, step_idx), initial_states) def _maybe_copy(state, new_state, step_mask): # TODO: use where_op + state_dtype = state.dtype + if convert_dtype(state_dtype) in ["bool"]: + state = tensor.cast(state, dtype="float32") + new_state = tensor.cast(new_state, dtype="float32") + if step_mask.dtype != state.dtype: + step_mask = tensor.cast(step_mask, dtype=state.dtype) + # otherwise, renamed bool gradients of would be summed up leading + # to sum(bool) error. + step_mask.stop_gradient = True new_state = nn.elementwise_mul( - new_state, step_mask, axis=0) - nn.elementwise_mul( - state, (step_mask - 1), axis=0) + state, step_mask, axis=0) - nn.elementwise_mul( + new_state, (step_mask - 1), axis=0) + if convert_dtype(state_dtype) in ["bool"]: + new_state = tensor.cast(new_state, dtype=state_dtype) return new_state def _transpose_batch_time(x): return nn.transpose(x, [1, 0] + list(range(2, len(x.shape)))) + def _create_array_out_of_while(dtype): + current_block_idx = default_main_program().current_block_idx + default_main_program().current_block_idx = default_main_program( + ).current_block().parent_idx + tensor_array = control_flow.create_array(dtype) + default_main_program().current_block_idx = current_block_idx + return tensor_array + # While with while_op.block(): + if not is_test: + inputs = map_structure( + lambda array: control_flow.array_read(array, step_idx), + inputs_arrays) + states = map_structure( + lambda array: control_flow.array_read(array, step_idx), + states_arrays) (outputs, next_states, next_inputs, next_finished) = decoder.step(step_idx, inputs, states, **kwargs) + next_finished = control_flow.logical_or(next_finished, global_finished) next_sequence_lengths = nn.elementwise_add( sequence_lengths, tensor.cast( control_flow.logical_not(global_finished), sequence_lengths.dtype)) + if impute_finished: # rectify the states for the finished. + next_states = map_structure( + lambda x, y: _maybe_copy(x, y, global_finished), + states, + next_states, ) + + # create tensor array in global block after dtype[s] of outputs can be got + outputs_arrays = map_structure( + lambda x: _create_array_out_of_while(x.dtype), outputs) + map_structure( lambda x, x_array: control_flow.array_write( x, i=step_idx, array=x_array), outputs, outputs_arrays) control_flow.increment(x=step_idx, value=1.0, in_place=True) - map_structure(tensor.assign, next_inputs, global_inputs) - map_structure(tensor.assign, next_states, global_states) + if is_test: + map_structure(tensor.assign, next_inputs, global_inputs) + map_structure(tensor.assign, next_states, global_states) + else: + map_structure( + lambda x, x_array: control_flow.array_write( + x, i=step_idx, array=x_array), next_inputs, inputs_arrays) + map_structure( + lambda x, x_array: control_flow.array_write( + x, i=step_idx, array=x_array), next_states, states_arrays) tensor.assign(next_finished, global_finished) tensor.assign(next_sequence_lengths, sequence_lengths) if max_step_num is not None: control_flow.logical_and( - control_flow.logical_not(nn.reduce_all(next_finished)), + control_flow.logical_not(nn.reduce_all(global_finished)), control_flow.less_equal(step_idx, max_step_num), cond) else: - control_flow.logical_not(nn.reduce_all(next_finished), cond) + control_flow.logical_not(nn.reduce_all(global_finished), cond) final_outputs = map_structure( lambda array: tensor.tensor_array_to_tensor( array, axis=0, use_stack=True)[0], outputs_arrays) - final_states = global_states + if is_test: + final_states = global_states + else: + final_states = map_structure( + lambda array: control_flow.array_read(array, step_idx), + states_arrays) try: final_outputs, final_states = decoder.finalize( - final_outputs, global_states, sequence_lengths) + final_outputs, final_states, sequence_lengths) except NotImplementedError: pass if not output_time_major: final_outputs = map_structure(_transpose_batch_time, final_outputs) - return final_outputs, final_states + return (final_outputs, final_states, + sequence_lengths) if return_length else (final_outputs, + final_states) + + +class DecodeHelper(object): + """ + DecodeHelper is the base class for any helper instance used in `BasicDecoder`. + It provides interface to implement sampling and produce inputs for the next + time step in dynamic decoding. + """ + + def initialize(self): + """ + DecodeHelper initialization to produce inputs for the first decoding step + and give the initial status telling whether each sequence in the batch + is finished. It is the partial of the initialization of `BasicDecoder`. + + Returns: + tuple: A tuple( :code:`(initial_inputs, initial_finished)` ). \ + `initial_inputs` is a (possibly nested structure of) tensor \ + variable[s], and the tensor's shape is `[batch_size, ...]`. \ + `initial_finished` is a bool tensor with shape `[batch_size]`. + """ + pass + + def sample(self, time, outputs, states): + """ + Perform sampling with some strategies according to `outputs`. It is the + partial of `BasicDecoder.step`. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the caller, + representing the current time step number of decoding. + outputs(Variable): A tensor variable. Usually it's data type is float32 + or float64, and it's shape is `[batch_size, vocabulary_size]`, + representing the predicted logits of current step. It is same as + `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. + states(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `new_states` returned by `BasicDecoder.cell.call()`. + + Returns: + Variable: An `int64` tensor representing the sampled ids. + """ + pass + + def next_inputs(self, time, outputs, states, sample_ids): + """ + Produce the inputs and states for next time step and give status telling + whether each minibatch entry is finished. It is called after `sample` in + `BasicDecoder.step`. It is the partial of `BasicDecoder.step`. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the caller, + representing the current time step number of decoding. + outputs(Variable): A tensor variable. Usually it's data type is float32 + or float64, and it's shape is `[batch_size, vocabulary_size]`, + representing the predicted logits of current step. It is same as + `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. + states(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `new_states` returned by `BasicDecoder.cell.call()`. + sample_ids(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `sample_ids` returned by `sample()`. + + Returns: + tuple: A tuple( :code:`(finished, next_inputs, next_states)` ). \ + `next_inputs` and `next_states` both are a (possibly nested \ + structure of) tensor variable[s], and the structure, shape and \ + data type of `next_states` must be same as the input argument \ + `states`. `finished` is a bool tensor with shape `[batch_size]`. + """ + pass + + +class TrainingHelper(DecodeHelper): + """ + TrainingHelper is a subclass of DecodeHelper. It is a decoding helper + slicing from the full sequence inputs as the inputs for corresponding + step. And it uses `argmax` to sample from the outputs of `cell.call()`. + + Since the needs of sequence inputs, it is used mostly for teach-forcing MLE + (maximum likelihood) training, and the sampled would not be used. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.fluid.layers as layers + trg_emb = fluid.data(name="trg_emb", + shape=[None, None, 128], + dtype="float32") + trg_seq_length = fluid.data(name="trg_seq_length", + shape=[None], + dtype="int64") + helper = layers.TrainingHelper(trg_emb, trg_seq_length) + decoder_cell = layers.GRUCell(hidden_size=128) + decoder = layers.BasicDecoder(decoder_cell, helper) + outputs = layers.dynamic_decode( + decoder, + inits=decoder_cell.get_initial_states(trg_emb), + is_test=False) + """ + + def __init__(self, inputs, sequence_length, time_major=False): + """ + Constructor of TrainingHelper. + + Parameters: + inputs(Variable): A (possibly nested structure of) tensor variable[s]. + The shape of tensor should be `[batch_size, sequence_length, ...]` + for `time_major == False` or `[sequence_length, batch_size, ...]` + for `time_major == True`. It represents the inputs to be sliced + from at every decoding step. + sequence_length(Variable): A tensor with shape `[batch_size]`. + It stores real length of each instance in `inputs`, by which we + can label the finished status of each instance at every decoding + step. + time_major(bool, optional): Indicate the data layout of Tensor included + in `inputs`. If `False`, the data layout would be batch major with + shape `[batch_size, sequence_length, ...]`. If `True`, the data + layout would be time major with shape `[sequence_length, batch_size, ...]`. + Default: `False`. + """ + self.inputs = inputs + self.sequence_length = sequence_length + self.time_major = time_major + # extend inputs to avoid to slice out of range in `next_inputs` + # may be easier and have better performance than condition_op + self.inputs_ = map_structure( + lambda x: nn.pad(x, + paddings=([0, 1] + [0, 0] * (len(x.shape) - 1)) + if time_major else ([0, 0, 0, 1] + [0, 0] * + (len(x.shape) - 2))), + self.inputs) + + def initialize(self): + """ + TrainingHelper initialization produces inputs for the first decoding + step by slicing at the first time step of full sequence inputs, and it + gives initial status telling whether each sequence in the batch is + finished. It is the partial of the initialization of `BasicDecoder`. + + Returns: + tuple: A tuple( :code:`(initial_inputs, initial_finished)` ). \ + `initial_inputs` is a (possibly nested structure of) tensor \ + variable[s], and the tensor's shape is `[batch_size, ...]`. \ + `initial_finished` is a bool tensor with shape `[batch_size]`. + """ + init_finished = control_flow.equal( + self.sequence_length, + tensor.fill_constant( + shape=[1], dtype=self.sequence_length.dtype, value=0)) + # TODO: support zero length + init_inputs = map_structure( + lambda x: x[0] if self.time_major else x[:, 0], self.inputs) + return init_inputs, init_finished + + def sample(self, time, outputs, states): + """ + Perform sampling by using `argmax` according to the `outputs`. Mostly + the sampled ids would not be used since the inputs for next decoding + step would be got by slicing. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the + caller, representing the current time step number of decoding. + outputs(Variable): A tensor variable. Usually it's data type is float32 + or float64, and it's shape is `[batch_size, vocabulary_size]`, + representing the predicted logits of current step. It is same as + `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. + states(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `new_states` returned by `BasicDecoder.cell.call()`. + + Returns: + Variable: An `int64` tensor with shape `[batch_size]`, representing \ + the sampled ids. + """ + sample_ids = tensor.argmax(outputs, axis=-1) + return sample_ids + + def next_inputs(self, time, outputs, states, sample_ids): + """ + Generate inputs for the next decoding step by slicing at corresponding + step of the full sequence inputs. Simultaneously, produce the states + for next time step by directly using the input `states` and emit status + telling whether each minibatch entry reaches to the corresponding length. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the + caller, representing the current time step number of decoding. + outputs(Variable): A tensor variable. Usually it's data type is float32 + or float64, and it's shape is `[batch_size, vocabulary_size]`, + representing the predicted logits of current step. It is same as + `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. + states(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `new_states` returned by `BasicDecoder.cell.call()`. + sample_ids(Variable): An `int64` tensor variable shaped `[batch_size]`. + It is same as `sample_ids` returned by `sample()`. + + Returns: + tuple: A tuple( :code:`(finished, next_inputs, next_states)` ). \ + `next_inputs` and `next_states` both are a (possibly nested \ + structure of) tensor variable[s], and the tensor's shape is \ + `[batch_size, ...]`. `next_states` is identical to the input \ + argument `states`. `finished` is a `bool` Tensor with \ + shape `[batch_size]`. + """ + # TODO: compatibility of int32 and int64 + time = tensor.cast( + time, + "int32") if convert_dtype(time.dtype) not in ["int32"] else time + if self.sequence_length.dtype != time.dtype: + self.sequence_length = tensor.cast(self.sequence_length, time.dtype) + next_time = time + 1 + finished = control_flow.less_equal(self.sequence_length, next_time) + + def _slice(x): # TODO: use Variable.__getitem__ + axes = [0 if self.time_major else 1] + return nn.squeeze( + nn.slice( + x, axes=axes, starts=[next_time], ends=[next_time + 1]), + axes=axes) + + next_inputs = map_structure(_slice, self.inputs_) + return finished, next_inputs, states + + +class GreedyEmbeddingHelper(DecodeHelper): + """ + GreedyEmbeddingHelper is a subclass of DecodeHelper. It is a decoding helper + uses the argmax of the output (treated as logits) and passes the results + through an embedding layer to get inputs for the next decoding step. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.fluid.layers as layers + trg_emb = fluid.data(name="trg_emb", + shape=[None, None, 128], + dtype="float32") + + trg_embeder = lambda x: fluid.embedding( + x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding")) + output_layer = lambda x: layers.fc(x, + size=10000, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name= + "output_w"), + bias_attr=False) + helper = layers.GreedyEmbeddingHelper(trg_embeder, start_tokens=0, end_token=1) + decoder_cell = layers.GRUCell(hidden_size=128) + decoder = layers.BasicDecoder(decoder_cell, helper, output_fn=output_layer) + outputs = layers.dynamic_decode( + decoder=decoder, inits=decoder_cell.get_initial_states(encoder_output)) + """ + + def __init__(self, embedding_fn, start_tokens, end_token): + """ + Constructor of GreedyEmbeddingHelper. + + Parameters: + embedding_fn(callable): A functor to apply on the argmax results. + Mostly it is an embedding layer to transform ids to embeddings. + **Note that fluid.embedding should be used here rather than + fluid.layers.embedding, since shape of ids is [batch_size]. + when using fluid.layers.embedding, must unsqueeze in embedding_fn.** + start_tokens(Variable): A `int64` tensor shaped `[batch_size]`, + representing the start tokens. + end_token(int): The end token id. + + Returns: + tuple: A tuple( :code:`(initial_inputs, initial_states, finished)` ). \ + `initial_inputs` and `initial_states` both are a (possibly nested \ + structure of) tensor variable[s], and `finished` is a tensor with \ + bool data type. + """ + self.embedding_fn = embedding_fn + self.start_tokens = start_tokens + self.end_token = tensor.fill_constant( + shape=[1], dtype="int64", value=end_token) + + def initialize(self): + """ + GreedyEmbeddingHelper initialization produces inputs for the first decoding + step by using `start_tokens` of the constructor, and gives initial + status telling whether each sequence in the batch is finished. + It is the partial of the initialization of `BasicDecoder`. + + Returns: + tuple: A tuple( :code:`(initial_inputs, initial_finished)` ). \ + `initial_inputs` is same as `start_tokens` of the constructor. \ + `initial_finished` is a `bool` tensor filled by False and has \ + the same shape as `start_tokens`. + """ + # TODO: remove the restriction of force_cpu + init_finished = tensor.fill_constant_batch_size_like( + input=self.start_tokens, + shape=[-1], + dtype="bool", + value=False, + force_cpu=True) + init_inputs = self.embedding_fn(self.start_tokens) + return init_inputs, init_finished + + def sample(self, time, outputs, states): + """ + Perform sampling by using `argmax` according to the `outputs`. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the + caller, representing the current time step number of decoding. + outputs(Variable): A tensor variable. Usually it's data type is float32 + or float64, and it's shape is `[batch_size, vocabulary_size]`, + representing the predicted logits of current step. It is same as + `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. + states(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `new_states` returned by `BasicDecoder.cell.call()`. + + Returns: + Variable: An `int64` tensor with shape `[batch_size]`, representing \ + the sampled ids. + """ + sample_ids = tensor.argmax(outputs, axis=-1) + return sample_ids + + def next_inputs(self, time, outputs, states, sample_ids): + """ + Generate inputs for the next decoding step by applying `embedding_fn` + to `sample_ids`. Simultaneously, produce the states for next time step + by directly using the input `states` and emit status telling whether + each minibatch entry gets an `end_token` sample. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the + caller, representing the current time step number of decoding. + outputs(Variable): A tensor variable. Usually it's data type is float32 + or float64, and it's shape is `[batch_size, vocabulary_size]`, + representing the predicted logits of current step. It is same as + `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. + states(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `new_states` returned by `BasicDecoder.cell.call()`. + sample_ids(Variable): An `int64` tensor variable shaped `[batch_size]`. + It is same as `sample_ids` returned by `sample()`. + + Returns: + tuple: A tuple( :code:`(finished, next_inputs, next_states)` ). \ + `next_inputs` and `next_states` both are a (possibly nested \ + structure of) tensor variable[s], and the tensor's shape is \ + `[batch_size, ...]`. `next_states` is identical to the input \ + argument `states`. `finished` is a `bool` Tensor with \ + shape `[batch_size]`. + """ + finished = control_flow.equal(sample_ids, self.end_token) + next_inputs = self.embedding_fn(sample_ids) + return finished, next_inputs, states + + +class SampleEmbeddingHelper(GreedyEmbeddingHelper): + """ + SampleEmbeddingHelper is a subclass of GreedyEmbeddingHelper. It is a decoding + helper uses sampling (from a distribution) instead of argmax of the output + (treated as logits) and passes the results through an embedding layer to get + inputs for the next decoding step. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.fluid.layers as layers + trg_emb = fluid.data(name="trg_emb", + shape=[None, None, 128], + dtype="float32") + + trg_embeder = lambda x: fluid.embedding( + x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding")) + output_layer = lambda x: layers.fc(x, + size=10000, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name= + "output_w"), + bias_attr=False) + helper = layers.SampleEmbeddingHelper(trg_embeder, start_tokens=0, end_token=1) + decoder_cell = layers.GRUCell(hidden_size=128) + decoder = layers.BasicDecoder(decoder_cell, helper, output_fn=output_layer) + outputs = layers.dynamic_decode( + decoder=decoder, inits=decoder_cell.get_initial_states(encoder_output)) + """ + + def __init__(self, + embedding_fn, + start_tokens, + end_token, + softmax_temperature=None, + seed=None): + """ + Constructor of SampleEmbeddingHelper. + + Parameters: + embedding_fn(callable): A functor to apply on the argmax results. + Mostly it is an embedding layer to transform ids to embeddings. + **Note that fluid.embedding should be used here rather than + fluid.layers.embedding, since shape of ids is [batch_size]. + when using fluid.layers.embedding, must unsqueeze in embedding_fn.** + start_tokens(Variable): A `int64` tensor shaped `[batch_size]`, + representing the start tokens. + end_token(int): The end token id. + softmax_temperature(float, optional): the value to divide the logits + by before computing the softmax. Higher temperatures (above 1.0) + lead to more random, while lower temperatures push the sampling + distribution towards the argmax. It must be strictly greater than + 0. Defaults to None, meaning using a temperature valued 1.0. + seed: (int, optional) The sampling seed. Defaults to None, meaning not + to use fixed seed. + + Returns: + tuple: A tuple( :code:`(initial_inputs, initial_states, finished)` ). \ + `initial_inputs` and `initial_states` both are a (possibly nested \ + structure of) tensor variable[s], and `finished` is a tensor with \ + bool data type. + """ + super(SampleEmbeddingHelper, self).__init__(embedding_fn, start_tokens, + end_token) + self.softmax_temperature = tensor.fill_constant( + shape=[1], dtype="float32", value=softmax_temperature + ) if softmax_temperature is not None else None + self.seed = seed + + def sample(self, time, outputs, states): + """ + Perform sampling from a categorical distribution, and the distribution + is computed by `softmax(outputs/softmax_temperature)`. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the + caller, representing the current time step number of decoding. + outputs(Variable): A tensor variable. Usually it's data type is float32 + or float64, and it's shape is `[batch_size, vocabulary_size]`, + representing the predicted logits of current step. It is same as + `outputs` returned by `BasicDecoder.output_fn(BasicDecoder.cell.call())`. + states(Variable): A (possibly nested structure of) tensor variable[s]. + It is same as `new_states` returned by `BasicDecoder.cell.call()`. + + Returns: + Variable: An `int64` tensor with shape `[batch_size]`, representing \ + the sampled ids. + """ + logits = (outputs / self.softmax_temperature + ) if self.softmax_temperature is not None else outputs + probs = nn.softmax(logits) + # TODO: remove this stop_gradient. The stop_gradient of sample_ids can + # not pass to probs, since sampling_id op does not have corresponding + # grad op and thus can not pass. + probs.stop_gradient = True + sample_ids = nn.sampling_id( + probs, seed=self.seed, dtype=self.start_tokens.dtype) + return sample_ids + + +class BasicDecoder(Decoder): + """ + BasicDecoder is a subclass of Decoder and assembles a RNNCell and DecodeHelper + instance as members, where the DecodeHelper helps to implement customed + decoding strategies.. It performs one decoding step as following steps: + + 1. Perform `cell_outputs, cell_states = cell.call(inputs, states)` + to get outputs and new states from cell. + + 2. Perform `sample_ids = helper.sample(time, cell_outputs, cell_states)` + to sample ids as decoded results of the current time step. + + 3. Perform `finished, next_inputs, next_states = helper.next_inputs(time, + cell_outputs, cell_states, sample_ids)` to generate inputs, states and + finished status for the next decoding step. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.fluid.layers as layers + trg_emb = fluid.data(name="trg_emb", + shape=[None, None, 128], + dtype="float32") + + trg_embeder = lambda x: fluid.embedding( + x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding")) + output_layer = lambda x: layers.fc(x, + size=10000, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name= + "output_w"), + bias_attr=False) + helper = layers.SampleEmbeddingHelper(trg_embeder, start_tokens=0, end_token=1) + decoder_cell = layers.GRUCell(hidden_size=128) + decoder = layers.BasicDecoder(decoder_cell, helper, output_fn=output_layer) + outputs = layers.dynamic_decode( + decoder=decoder, inits=decoder_cell.get_initial_states(encoder_output)) + """ + + def __init__(self, cell, helper, output_fn=None): + """ + Constructor of BasicDecoder. + + Parameters: + cell(RNNCell): An instance of `RNNCell` or object with the same interface. + helper(DecodeHelper): An instance of `DecodeHelper`. + output_fn(optional): A callable to apply to the cell's output prior to + sampling. Default None. + """ + self.cell = cell + self.helper = helper + self.output_fn = output_fn + + def initialize(self, initial_cell_states): + """ + BasicDecoder initialization includes helper initialization and cell + initialization, and cell initialization uses `initial_cell_states` as + the result directly. + + Parameters: + initial_cell_states(Variable): A (possibly nested structure of) + tensor variable[s]. An argument provided by the caller `dynamic_decode`. + + Returns: + tuple: A tuple( :code:(initial_inputs, initial_cell_states, finished)` ). \ + `initial_inputs` and `initial_states` both are a (possibly nested \ + structure of) tensor variable[s], and `finished` is a tensor with \ + bool data type. `initial_inputs` and `finished` are the results \ + of `helper.initialize()`, and `initial_cell_states` is same as \ + the input argument counterpart. + """ + (initial_inputs, initial_finished) = self.helper.initialize() + return initial_inputs, initial_cell_states, initial_finished + + class OutputWrapper( + collections.namedtuple("OutputWrapper", + ("cell_outputs", "sample_ids"))): + """ + The structure for the returned value `outputs` of `decoder.step`. + A namedtuple includes cell_outputs, sample_ids as fields. + """ + pass + + def step(self, time, inputs, states, **kwargs): + """ + Perform one decoding step as following steps: + + 1. Perform `cell_outputs, cell_states = cell.call(inputs, states)` + to get outputs and new states from cell. + + 2. Perform `sample_ids = helper.sample(time, cell_outputs, cell_states)` + to sample ids as decoded results of the current time step. + + 3. Perform `finished, next_inputs, next_states = helper.next_inputs(time, + cell_outputs, cell_states, sample_ids)` to generate inputs, states and + finished status for the next decoding step. + + Parameters: + time(Variable): An `int64` tensor with shape `[1]` provided by the caller, + representing the current time step number of decoding. + inputs(Variable): A tensor variable. It is same as `initial_inputs` + returned by `initialize()` for the first decoding step and + `next_inputs` returned by `step()` for the others. + states(Variable): A structure of tensor variables. + It is same as the `initial_cell_states` returned by `initialize()` + for the first decoding step and `next_states` returned by + `step()` for the others. + **kwargs: Additional keyword arguments, provided by the caller + `dynamic_decode`. + + Returns: + tuple: A tuple( :code:`(outputs, next_states, next_inputs, finished)` ). \ + `outputs` is a namedtuple(including cell_outputs, sample_ids, \ + as fields) of tensor variables, where `cell_outputs` is the result \ + fof `cell.call()` and `sample_ids` is the result of `helper.sample()`. \ + `next_states` and `next_inputs` have the same structure, shape \ + and data type as the input arguments `states` and `inputs` separately. \ + `finished` is a `bool` tensor with shape `[batch_size]`. + """ + cell_outputs, cell_states = self.cell(inputs, states, **kwargs) + if self.output_fn is not None: + cell_outputs = self.output_fn(cell_outputs) + sample_ids = self.helper.sample( + time=time, outputs=cell_outputs, states=cell_states) + sample_ids.stop_gradient = True + (finished, next_inputs, next_states) = self.helper.next_inputs( + time=time, + outputs=cell_outputs, + states=cell_states, + sample_ids=sample_ids) + outputs = self.OutputWrapper(cell_outputs, sample_ids) + return (outputs, next_states, next_inputs, finished) def dynamic_lstm(input, diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 5f2579b103..c8b8e63413 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -780,6 +780,7 @@ def argmin(x, axis=0): inputs={'X': x}, outputs={'Out': [out]}, attrs={'axis': axis}) + out.stop_gradient = True return out @@ -839,6 +840,7 @@ def argmax(x, axis=0): inputs={'X': x}, outputs={'Out': [out]}, attrs={'axis': axis}) + out.stop_gradient = True return out diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index 55365abd49..8b0b144c6d 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -15,7 +15,7 @@ from __future__ import print_function import unittest -import numpy +import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers @@ -24,22 +24,15 @@ import paddle.fluid.core as core from paddle.fluid.executor import Executor from paddle.fluid import framework -from paddle.fluid.layers.rnn import LSTMCell, GRUCell, RNNCell, BeamSearchDecoder, dynamic_decode -from paddle.fluid.layers import rnn as dynamic_rnn -from paddle.fluid import contrib -from paddle.fluid.contrib.layers import basic_lstm - -import numpy as np - -class EncoderCell(RNNCell): +class EncoderCell(layers.RNNCell): def __init__(self, num_layers, hidden_size, dropout_prob=0.): self.num_layers = num_layers self.hidden_size = hidden_size self.dropout_prob = dropout_prob - self.lstm_cells = [] - for i in range(num_layers): - self.lstm_cells.append(LSTMCell(hidden_size)) + self.lstm_cells = [ + layers.LSTMCell(hidden_size) for i in range(num_layers) + ] def call(self, step_input, states): new_states = [] @@ -55,14 +48,14 @@ class EncoderCell(RNNCell): return [cell.state_shape for cell in self.lstm_cells] -class DecoderCell(RNNCell): +class DecoderCell(layers.RNNCell): def __init__(self, num_layers, hidden_size, dropout_prob=0.): self.num_layers = num_layers self.hidden_size = hidden_size self.dropout_prob = dropout_prob - self.lstm_cells = [] - for i in range(num_layers): - self.lstm_cells.append(LSTMCell(hidden_size)) + self.lstm_cells = [ + layers.LSTMCell(hidden_size) for i in range(num_layers) + ] def attention(self, hidden, encoder_output, encoder_padding_mask): query = layers.fc(hidden, @@ -97,117 +90,456 @@ class DecoderCell(RNNCell): return out, [new_lstm_states, out] -class TestDynamicDecode(unittest.TestCase): - def setUp(self): - self.batch_size = 4 - self.input_size = 16 - self.hidden_size = 16 - self.seq_len = 4 - - def test_run(self): - start_token = 0 - end_token = 1 - src_vocab_size = 10 - trg_vocab_size = 10 - num_layers = 1 - hidden_size = self.hidden_size - beam_size = 8 - max_length = self.seq_len - - src = layers.data(name="src", shape=[-1, 1], dtype='int64') - src_len = layers.data(name="src_len", shape=[-1], dtype='int64') - - trg = layers.data(name="trg", shape=[-1, 1], dtype='int64') - trg_len = layers.data(name="trg_len", shape=[-1], dtype='int64') - - src_embeder = lambda x: fluid.embedding( - x, - size=[src_vocab_size, hidden_size], - param_attr=fluid.ParamAttr(name="src_embedding")) +class Encoder(object): + def __init__(self, num_layers, hidden_size, dropout_prob=0.): + self.encoder_cell = EncoderCell(num_layers, hidden_size, dropout_prob) - trg_embeder = lambda x: fluid.embedding( - x, - size=[trg_vocab_size, hidden_size], - param_attr=fluid.ParamAttr(name="trg_embedding")) - - # use basic_lstm - encoder_cell = EncoderCell(num_layers, hidden_size) - encoder_output, encoder_final_state = dynamic_rnn( - cell=encoder_cell, - inputs=src_embeder(src), - sequence_length=src_len, + def __call__(self, src_emb, src_sequence_length): + encoder_output, encoder_final_state = layers.rnn( + cell=self.encoder_cell, + inputs=src_emb, + sequence_length=src_sequence_length, is_reverse=False) + return encoder_output, encoder_final_state + + +class Decoder(object): + def __init__(self, + num_layers, + hidden_size, + dropout_prob, + decoding_strategy="infer_sample", + max_decoding_length=20): + self.decoder_cell = DecoderCell(num_layers, hidden_size, dropout_prob) + self.decoding_strategy = decoding_strategy + self.max_decoding_length = None if ( + self.decoding_strategy == "train_greedy") else max_decoding_length + + def __call__(self, decoder_initial_states, encoder_output, + encoder_padding_mask, **kwargs): + output_layer = kwargs.pop("output_layer", None) + if self.decoding_strategy == "train_greedy": + # for teach-forcing MLE pre-training + helper = layers.TrainingHelper(**kwargs) + elif self.decoding_strategy == "infer_sample": + helper = layers.SampleEmbeddingHelper(**kwargs) + elif self.decoding_strategy == "infer_greedy": + helper = layers.GreedyEmbeddingHelper(**kwargs) + + if self.decoding_strategy == "beam_search": + beam_size = kwargs.get("beam_size", 4) + encoder_output = layers.BeamSearchDecoder.tile_beam_merge_with_batch( + encoder_output, beam_size) + encoder_padding_mask = layers.BeamSearchDecoder.tile_beam_merge_with_batch( + encoder_padding_mask, beam_size) + decoder = layers.BeamSearchDecoder( + cell=self.decoder_cell, output_fn=output_layer, **kwargs) + else: + decoder = layers.BasicDecoder( + self.decoder_cell, helper, output_fn=output_layer) + + (decoder_output, decoder_final_state, + dec_seq_lengths) = layers.dynamic_decode( + decoder, + inits=decoder_initial_states, + max_step_num=self.max_decoding_length, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask, + impute_finished=False # for test coverage + if self.decoding_strategy == "beam_search" else True, + is_test=True if self.decoding_strategy == "beam_search" else False, + return_length=True) + return decoder_output, decoder_final_state, dec_seq_lengths + + +class Seq2SeqModel(object): + """Seq2Seq model: RNN encoder-decoder with attention""" + + def __init__(self, + num_layers, + hidden_size, + dropout_prob, + src_vocab_size, + trg_vocab_size, + start_token, + end_token, + decoding_strategy="infer_sample", + max_decoding_length=20, + beam_size=4): + self.start_token, self.end_token = start_token, end_token + self.max_decoding_length, self.beam_size = max_decoding_length, beam_size + self.src_embeder = lambda x: fluid.embedding( + input=x, + size=[src_vocab_size, hidden_size], + dtype="float32", + param_attr=fluid.ParamAttr(name="source_embedding")) + self.trg_embeder = lambda x: fluid.embedding( + input=x, + size=[trg_vocab_size, hidden_size], + dtype="float32", + param_attr=fluid.ParamAttr(name="target_embedding")) + self.encoder = Encoder(num_layers, hidden_size, dropout_prob) + self.decoder = Decoder(num_layers, hidden_size, dropout_prob, + decoding_strategy, max_decoding_length) + self.output_layer = lambda x: layers.fc( + x, + size=trg_vocab_size, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name="output_w"), + bias_attr=False) - src_mask = layers.sequence_mask( - src_len, maxlen=layers.shape(src)[1], dtype='float32') - encoder_padding_mask = (src_mask - 1.0) * 1000000000 - encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) + def __call__(self, src, src_length, trg=None, trg_length=None): + # encoder + encoder_output, encoder_final_state = self.encoder( + self.src_embeder(src), src_length) - decoder_cell = DecoderCell(num_layers, hidden_size) decoder_initial_states = [ - encoder_final_state, decoder_cell.get_initial_states( - batch_ref=encoder_output, shape=[hidden_size]) + encoder_final_state, self.decoder.decoder_cell.get_initial_states( + batch_ref=encoder_output, shape=[encoder_output.shape[-1]]) ] + src_mask = layers.sequence_mask( + src_length, maxlen=layers.shape(src)[1], dtype="float32") + encoder_padding_mask = (src_mask - 1.0) * 1e9 + encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) - decoder_output, _ = dynamic_rnn( - cell=decoder_cell, - inputs=trg_embeder(trg), - initial_states=decoder_initial_states, - sequence_length=None, - encoder_output=encoder_output, - encoder_padding_mask=encoder_padding_mask) - - output_layer = lambda x: layers.fc(x, - size=trg_vocab_size, - num_flatten_dims=len(x.shape) - 1, - param_attr=fluid.ParamAttr( - name="output_w"), - bias_attr=False) - - # inference - encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch( - encoder_output, beam_size) - encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch( - encoder_padding_mask, beam_size) - beam_search_decoder = BeamSearchDecoder( - decoder_cell, - start_token, - end_token, - beam_size, - embedding_fn=trg_embeder, - output_fn=output_layer) - outputs, _ = dynamic_decode( - beam_search_decoder, - inits=decoder_initial_states, - max_step_num=max_length, - encoder_output=encoder_output, - encoder_padding_mask=encoder_padding_mask) - - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) + # decoder + decoder_kwargs = { + "inputs": self.trg_embeder(trg), + "sequence_length": trg_length, + } if self.decoder.decoding_strategy == "train_greedy" else ({ + "embedding_fn": self.trg_embeder, + "beam_size": self.beam_size, + "start_token": self.start_token, + "end_token": self.end_token + } if self.decoder.decoding_strategy == "beam_search" else { + "embedding_fn": self.trg_embeder, + "start_tokens": layers.fill_constant_batch_size_like( + input=encoder_output, + shape=[-1], + dtype=src.dtype, + value=self.start_token), + "end_token": self.end_token + }) + decoder_kwargs["output_layer"] = self.output_layer + + (decoder_output, decoder_final_state, + dec_seq_lengths) = self.decoder(decoder_initial_states, encoder_output, + encoder_padding_mask, **decoder_kwargs) + if self.decoder.decoding_strategy == "beam_search": # for inference + return decoder_output + logits, samples, sample_length = (decoder_output.cell_outputs, + decoder_output.sample_ids, + dec_seq_lengths) + probs = layers.softmax(logits) + return probs, samples, sample_length + + +class PolicyGradient(object): + """policy gradient""" + + def __init__(self, lr=None): + self.lr = lr + + def learn(self, act_prob, action, reward, length=None): + """ + update policy model self.model with policy gradient algorithm + """ + self.reward = fluid.layers.py_func( + func=reward_func, x=[action, length], out=reward) + neg_log_prob = layers.cross_entropy(act_prob, action) + cost = neg_log_prob * reward + cost = (layers.reduce_sum(cost) / layers.reduce_sum(length) + ) if length is not None else layers.reduce_mean(cost) + optimizer = fluid.optimizer.Adam(self.lr) + optimizer.minimize(cost) + return cost + + +def reward_func(samples, sample_length): + """toy reward""" + + def discount_reward(reward, sequence_length, discount=1.): + return discount_reward_1d(reward, sequence_length, discount) + + def discount_reward_1d(reward, sequence_length, discount=1., dtype=None): + if sequence_length is None: + raise ValueError( + 'sequence_length must not be `None` for 1D reward.') + reward = np.array(reward) + sequence_length = np.array(sequence_length) + batch_size = reward.shape[0] + max_seq_length = np.max(sequence_length) + dtype = dtype or reward.dtype + if discount == 1.: + dmat = np.ones([batch_size, max_seq_length], dtype=dtype) else: - place = core.CPUPlace() - exe = Executor(place) - exe.run(framework.default_startup_program()) - - src_np = np.random.randint( - 0, src_vocab_size, (self.batch_size, max_length)).astype('int64') - src_len_np = np.ones(self.batch_size, dtype='int64') * max_length - trg_np = np.random.randint( - 0, trg_vocab_size, (self.batch_size, max_length)).astype('int64') - trg_len_np = np.ones(self.batch_size, dtype='int64') * max_length - - out = exe.run(feed={ - 'src': src_np, - 'src_len': src_len_np, - 'trg': trg_np, - 'trg_len': trg_len_np - }, - fetch_list=[outputs]) - - self.assertTrue(out[0].shape[0] == self.batch_size) - self.assertTrue(out[0].shape[1] <= max_length + 1) - self.assertTrue(out[0].shape[2] == beam_size) + steps = np.tile(np.arange(max_seq_length), [batch_size, 1]) + mask = np.asarray( + steps < (sequence_length - 1)[:, None], dtype=dtype) + # Make each row = [discount, ..., discount, 1, ..., 1] + dmat = mask * discount + (1 - mask) + dmat = np.cumprod(dmat[:, ::-1], axis=1)[:, ::-1] + disc_reward = dmat * reward[:, None] + disc_reward = mask_sequences(disc_reward, sequence_length, dtype=dtype) + return disc_reward + + def mask_sequences(sequence, sequence_length, dtype=None, time_major=False): + sequence = np.array(sequence) + sequence_length = np.array(sequence_length) + rank = sequence.ndim + if rank < 2: + raise ValueError("`sequence` must be 2D or higher order.") + batch_size = sequence.shape[0] + max_time = sequence.shape[1] + dtype = dtype or sequence.dtype + if time_major: + sequence = np.transpose(sequence, axes=[1, 0, 2]) + steps = np.tile(np.arange(max_time), [batch_size, 1]) + mask = np.asarray(steps < sequence_length[:, None], dtype=dtype) + for _ in range(2, rank): + mask = np.expand_dims(mask, -1) + sequence = sequence * mask + if time_major: + sequence = np.transpose(sequence, axes=[1, 0, 2]) + return sequence + + samples = np.array(samples) + sample_length = np.array(sample_length) + # length reward + reward = (5 - np.abs(sample_length - 5)).astype("float32") + # repeat punishment to trapped into local minima getting all same words + # beam search to get more than one sample may also can avoid this + for i in range(reward.shape[0]): + reward[i] += -10 if sample_length[i] > 1 and np.all( + samples[i][:sample_length[i] - 1] == samples[i][0]) else 0 + return discount_reward(reward, sample_length, discount=1.).astype("float32") + + +class MLE(object): + """teacher-forcing MLE training""" + + def __init__(self, lr=None): + self.lr = lr + + def learn(self, probs, label, weight=None, length=None): + loss = layers.cross_entropy(input=probs, label=label, soft_label=False) + max_seq_len = layers.shape(probs)[1] + mask = layers.sequence_mask(length, maxlen=max_seq_len, dtype="float32") + loss = loss * mask + loss = layers.reduce_mean(loss, dim=[0]) + loss = layers.reduce_sum(loss) + optimizer = fluid.optimizer.Adam(self.lr) + optimizer.minimize(loss) + return loss + + +class SeqPGAgent(object): + def __init__(self, + model_cls, + alg_cls=PolicyGradient, + model_hparams={}, + alg_hparams={}, + executor=None, + main_program=None, + startup_program=None, + seed=None): + self.main_program = fluid.Program( + ) if main_program is None else main_program + self.startup_program = fluid.Program( + ) if startup_program is None else startup_program + if seed is not None: + self.main_program.random_seed = seed + self.startup_program.random_seed = seed + self.build_program(model_cls, alg_cls, model_hparams, alg_hparams) + self.executor = executor + + def build_program(self, model_cls, alg_cls, model_hparams, alg_hparams): + with fluid.program_guard(self.main_program, self.startup_program): + source = fluid.data(name="src", shape=[None, None], dtype="int64") + source_length = fluid.data( + name="src_sequence_length", shape=[None], dtype="int64") + # only for teacher-forcing MLE training + target = fluid.data(name="trg", shape=[None, None], dtype="int64") + target_length = fluid.data( + name="trg_sequence_length", shape=[None], dtype="int64") + label = fluid.data( + name="label", shape=[None, None, 1], dtype="int64") + self.model = model_cls(**model_hparams) + self.alg = alg_cls(**alg_hparams) + self.probs, self.samples, self.sample_length = self.model( + source, source_length, target, target_length) + self.samples.stop_gradient = True + self.reward = fluid.layers.create_global_var( + name="reward", + shape=[-1, -1], # batch_size, seq_len + value="1", + dtype=self.probs.dtype) + self.cost = self.alg.learn(self.probs, self.samples, self.reward, + self.sample_length) + + # to define the same parameters between different programs + self.pred_program = self.main_program._prune_with_input( + [source.name, source_length.name], + [self.probs, self.samples, self.sample_length]) + + def predict(self, feed_dict): + samples, sample_length = self.executor.run( + self.pred_program, + feed=feed_dict, + fetch_list=[self.samples, self.sample_length]) + return samples, sample_length + + def learn(self, feed_dict, fetch_list): + results = self.executor.run(self.main_program, + feed=feed_dict, + fetch_list=fetch_list) + return results + + +class TestDynamicDecode(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.model_hparams = { + "num_layers": 2, + "hidden_size": 32, + "dropout_prob": 0.1, + "src_vocab_size": 100, + "trg_vocab_size": 100, + "start_token": 0, + "end_token": 1, + "decoding_strategy": "infer_greedy", + "max_decoding_length": 10 + } + + self.iter_num = iter_num = 2 + self.batch_size = batch_size = 4 + src_seq_len = 10 + trg_seq_len = 12 + self.data = { + "src": np.random.randint( + 2, self.model_hparams["src_vocab_size"], + (iter_num * batch_size, src_seq_len)).astype("int64"), + "src_sequence_length": np.random.randint( + 1, src_seq_len, (iter_num * batch_size, )).astype("int64"), + "trg": np.random.randint( + 2, self.model_hparams["src_vocab_size"], + (iter_num * batch_size, trg_seq_len)).astype("int64"), + "trg_sequence_length": np.random.randint( + 1, trg_seq_len, (iter_num * batch_size, )).astype("int64"), + "label": np.random.randint( + 2, self.model_hparams["src_vocab_size"], + (iter_num * batch_size, trg_seq_len, 1)).astype("int64"), + } + + place = core.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else core.CPUPlace() + self.exe = Executor(place) + + def test_mle_train(self): + self.model_hparams["decoding_strategy"] = "train_greedy" + agent = SeqPGAgent( + model_cls=Seq2SeqModel, + alg_cls=MLE, + model_hparams=self.model_hparams, + alg_hparams={"lr": 0.001}, + executor=self.exe, + main_program=fluid.Program(), + startup_program=fluid.Program(), + seed=123) + self.exe.run(agent.startup_program) + for iter_idx in range(self.iter_num): + reward, cost = agent.learn( + { + "src": self.data["src"][iter_idx * self.batch_size:( + iter_idx + 1) * self.batch_size, :], + "src_sequence_length": self.data["src_sequence_length"][ + iter_idx * self.batch_size:(iter_idx + 1 + ) * self.batch_size], + "trg": self.data["trg"][iter_idx * self.batch_size:( + iter_idx + 1) * self.batch_size, :], + "trg_sequence_length": self.data["trg_sequence_length"] + [iter_idx * self.batch_size:(iter_idx + 1) * + self.batch_size], + "label": self.data["label"][iter_idx * self.batch_size:( + iter_idx + 1) * self.batch_size] + }, + fetch_list=[agent.cost, agent.cost]) + print("iter_idx: %d, reward: %f, cost: %f" % + (iter_idx, reward.mean(), cost)) + + def test_greedy_train(self): + self.model_hparams["decoding_strategy"] = "infer_greedy" + agent = SeqPGAgent( + model_cls=Seq2SeqModel, + alg_cls=PolicyGradient, + model_hparams=self.model_hparams, + alg_hparams={"lr": 0.001}, + executor=self.exe, + main_program=fluid.Program(), + startup_program=fluid.Program(), + seed=123) + self.exe.run(agent.startup_program) + for iter_idx in range(self.iter_num): + reward, cost = agent.learn( + { + "src": self.data["src"][iter_idx * self.batch_size:( + iter_idx + 1) * self.batch_size, :], + "src_sequence_length": self.data["src_sequence_length"] + [iter_idx * self.batch_size:(iter_idx + 1) * + self.batch_size] + }, + fetch_list=[agent.reward, agent.cost]) + print("iter_idx: %d, reward: %f, cost: %f" % + (iter_idx, reward.mean(), cost)) + + def test_sample_train(self): + self.model_hparams["decoding_strategy"] = "infer_sample" + agent = SeqPGAgent( + model_cls=Seq2SeqModel, + alg_cls=PolicyGradient, + model_hparams=self.model_hparams, + alg_hparams={"lr": 0.001}, + executor=self.exe, + main_program=fluid.Program(), + startup_program=fluid.Program(), + seed=123) + self.exe.run(agent.startup_program) + for iter_idx in range(self.iter_num): + reward, cost = agent.learn( + { + "src": self.data["src"][iter_idx * self.batch_size:( + iter_idx + 1) * self.batch_size, :], + "src_sequence_length": self.data["src_sequence_length"] + [iter_idx * self.batch_size:(iter_idx + 1) * + self.batch_size] + }, + fetch_list=[agent.reward, agent.cost]) + print("iter_idx: %d, reward: %f, cost: %f" % + (iter_idx, reward.mean(), cost)) + + def test_beam_search_infer(self): + self.model_hparams["decoding_strategy"] = "beam_search" + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + source = fluid.data(name="src", shape=[None, None], dtype="int64") + source_length = fluid.data( + name="src_sequence_length", shape=[None], dtype="int64") + model = Seq2SeqModel(**self.model_hparams) + output = model(source, source_length) + + self.exe.run(startup_program) + for iter_idx in range(self.iter_num): + trans_ids = self.exe.run( + program=main_program, + feed={ + "src": self.data["src"][iter_idx * self.batch_size:( + iter_idx + 1) * self.batch_size, :], + "src_sequence_length": self.data["src_sequence_length"] + [iter_idx * self.batch_size:(iter_idx + 1) * + self.batch_size] + }, + fetch_list=[output])[0] if __name__ == '__main__': -- GitLab