From be4ea1748e3eca41bec32f4a1aee0836edbc62cc Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Mon, 19 Dec 2022 10:08:05 +0800 Subject: [PATCH] [Migrate Fluid] Migrate dynamic_decode and tensor_array_to_tensor (#48876) --- python/paddle/fluid/layers/__init__.py | 4 - python/paddle/fluid/layers/nn.py | 2 +- python/paddle/fluid/layers/rnn.py | 473 ------------------ python/paddle/fluid/layers/tensor.py | 118 ----- .../test_dynamic_rnn_stop_gradient.py | 3 +- .../tests/unittests/test_rnn_decode_api.py | 324 +++++++++++- .../fluid/tests/unittests/test_slice_op.py | 5 +- .../unittests/test_tensor_array_to_tensor.py | 13 +- python/paddle/nn/decode.py | 442 +++++++++++++++- python/paddle/tensor/manipulation.py | 114 +++++ 10 files changed, 887 insertions(+), 611 deletions(-) delete mode 100644 python/paddle/fluid/layers/rnn.py diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py index 0e98d90773..c2e4035ea8 100644 --- a/python/paddle/fluid/layers/__init__.py +++ b/python/paddle/fluid/layers/__init__.py @@ -27,7 +27,6 @@ from .loss import * from .learning_rate_scheduler import * from .collective import * from .sequence_lod import * -from . import rnn __all__ = [] __all__ += nn.__all__ @@ -37,6 +36,3 @@ __all__ += control_flow.__all__ __all__ += learning_rate_scheduler.__all__ __all__ += sequence_lod.__all__ __all__ += loss.__all__ -__all__ += rnn.__all__ - -from .rnn import * diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 7421a05cc8..65f02001b2 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -45,7 +45,7 @@ from .layer_function_generator import ( templatedoc, _generate_doc_string_, ) -from .tensor import concat, assign, fill_constant, zeros, tensor_array_to_tensor +from .tensor import concat, assign, fill_constant, zeros from . import utils from .. import unique_name from functools import reduce diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py deleted file mode 100644 index 90cc5a6853..0000000000 --- a/python/paddle/fluid/layers/rnn.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -from functools import partial, reduce -import warnings - - -import paddle -from paddle.utils import deprecated -from . import nn -from . import tensor -from . import control_flow -from . import utils -from . import sequence_lod -from .utils import * -from .. 
import core -from ..framework import default_main_program -from ..data_feeder import convert_dtype -from ..layer_helper import LayerHelper -from ..framework import _non_static_mode -from ..param_attr import ParamAttr -from ..data_feeder import check_variable_and_dtype, check_type, check_dtype - -from collections.abc import Sequence - -__all__ = [ - 'dynamic_decode', -] - - -class ArrayWrapper: - def __init__(self, x): - self.array = [x] - - def append(self, x): - self.array.append(x) - return self - - def __getitem__(self, item): - return self.array.__getitem__(item) - - -def _dynamic_decode_imperative( - decoder, - inits=None, - max_step_num=None, - output_time_major=False, - impute_finished=False, - is_test=False, - return_length=False, - **kwargs -): - def _maybe_copy(state, new_state, step_mask): - # TODO: use where_op - state_dtype = state.dtype - if convert_dtype(state_dtype) in ["bool"]: - state = tensor.cast(state, dtype="float32") - new_state = tensor.cast(new_state, dtype="float32") - if step_mask.dtype != state.dtype: - step_mask = tensor.cast(step_mask, dtype=state.dtype) - # otherwise, renamed bool gradients of would be summed up leading - # to sum(bool) error. - step_mask.stop_gradient = True - new_state = paddle.tensor.math._multiply_with_axis( - state, step_mask, axis=0 - ) - paddle.tensor.math._multiply_with_axis( - new_state, (step_mask - 1), axis=0 - ) - if convert_dtype(state_dtype) in ["bool"]: - new_state = tensor.cast(new_state, dtype=state_dtype) - return new_state - - initial_inputs, initial_states, initial_finished = decoder.initialize(inits) - inputs, states, finished = ( - initial_inputs, - initial_states, - initial_finished, - ) - cond = paddle.logical_not((paddle.all(initial_finished))) - sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64") - outputs = None - - step_idx = 0 - step_idx_tensor = tensor.fill_constant( - shape=[1], dtype="int64", value=step_idx - ) - while cond.numpy(): - (step_outputs, next_states, next_inputs, next_finished) = decoder.step( - step_idx_tensor, inputs, states, **kwargs - ) - if not decoder.tracks_own_finished: - # BeamSearchDecoder would track it own finished, since - # beams would be reordered and the finished status of each - # entry might change. Otherwise, perform logical OR which - # would not change the already finished. - next_finished = paddle.logical_or(next_finished, finished) - # To confirm states.finished/finished be consistent with - # next_finished. - tensor.assign(next_finished, finished) - next_sequence_lengths = paddle.add( - sequence_lengths, - tensor.cast( - paddle.logical_not(finished), sequence_lengths.dtype - ), - ) - if impute_finished: # rectify the states for the finished. - next_states = map_structure( - lambda x, y: _maybe_copy(x, y, finished), - states, - next_states, - ) - else: - warnings.warn( - "`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros." 
- ) if not hasattr(next_states, "lengths") else None - next_sequence_lengths = getattr( - next_states, "lengths", sequence_lengths - ) - - outputs = ( - map_structure(lambda x: ArrayWrapper(x), step_outputs) - if step_idx == 0 - else map_structure( - lambda x, x_array: x_array.append(x), step_outputs, outputs - ) - ) - inputs, states, finished, sequence_lengths = ( - next_inputs, - next_states, - next_finished, - next_sequence_lengths, - ) - - paddle.increment(x=step_idx_tensor, value=1.0) - step_idx += 1 - - cond = paddle.logical_not(paddle.all(finished)) - if max_step_num is not None and step_idx > max_step_num: - break - - final_outputs = map_structure( - lambda x: paddle.stack(x.array, axis=0), outputs - ) - final_states = states - - try: - final_outputs, final_states = decoder.finalize( - final_outputs, final_states, sequence_lengths - ) - except NotImplementedError: - pass - - if not output_time_major: - final_outputs = map_structure( - lambda x: paddle.transpose( - x, [1, 0] + list(range(2, len(x.shape))) - ), - final_outputs, - ) - - return ( - (final_outputs, final_states, sequence_lengths) - if return_length - else (final_outputs, final_states) - ) - - -def _dynamic_decode_declarative( - decoder, - inits=None, - max_step_num=None, - output_time_major=False, - impute_finished=False, - is_test=False, - return_length=False, - **kwargs -): - initial_inputs, initial_states, initial_finished = decoder.initialize(inits) - global_inputs, global_states, global_finished = ( - initial_inputs, - initial_states, - initial_finished, - ) - global_finished.stop_gradient = True - step_idx = tensor.fill_constant(shape=[1], dtype="int64", value=0) - - cond = paddle.logical_not((paddle.all(initial_finished))) - if max_step_num is not None: - max_step_num = tensor.fill_constant( - shape=[1], dtype="int64", value=max_step_num - ) - while_op = paddle.static.nn.control_flow.While(cond, is_test=is_test) - - sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64") - sequence_lengths.stop_gradient = True - - if is_test: - # for test, reuse inputs and states variables to save memory - inputs = map_structure(lambda x: x, initial_inputs) - states = map_structure(lambda x: x, initial_states) - else: - # inputs and states of all steps must be saved for backward and training - inputs_arrays = map_structure( - lambda x: paddle.tensor.array_write(x, step_idx), initial_inputs - ) - states_arrays = map_structure( - lambda x: paddle.tensor.array_write(x, step_idx), initial_states - ) - - def _maybe_copy(state, new_state, step_mask): - # TODO: use where_op - state_dtype = state.dtype - if convert_dtype(state_dtype) in ["bool"]: - state = tensor.cast(state, dtype="float32") - new_state = tensor.cast(new_state, dtype="float32") - if step_mask.dtype != state.dtype: - step_mask = tensor.cast(step_mask, dtype=state.dtype) - # otherwise, renamed bool gradients of would be summed up leading - # to sum(bool) error. 
- step_mask.stop_gradient = True - new_state = paddle.tensor.math._multiply_with_axis( - state, step_mask, axis=0 - ) - paddle.tensor.math._multiply_with_axis( - new_state, (step_mask - 1), axis=0 - ) - if convert_dtype(state_dtype) in ["bool"]: - new_state = tensor.cast(new_state, dtype=state_dtype) - return new_state - - def _transpose_batch_time(x): - return paddle.transpose(x, [1, 0] + list(range(2, len(x.shape)))) - - def _create_array_out_of_while(dtype): - current_block_idx = default_main_program().current_block_idx - default_main_program().current_block_idx = ( - default_main_program().current_block().parent_idx - ) - tensor_array = paddle.tensor.create_array(dtype) - default_main_program().current_block_idx = current_block_idx - return tensor_array - - # While - with while_op.block(): - if not is_test: - inputs = map_structure( - lambda array: paddle.tensor.array_read(array, step_idx), - inputs_arrays, - ) - states = map_structure( - lambda array: paddle.tensor.array_read(array, step_idx), - states_arrays, - ) - (outputs, next_states, next_inputs, next_finished) = decoder.step( - step_idx, inputs, states, **kwargs - ) - if not decoder.tracks_own_finished: - # BeamSearchDecoder would track it own finished, since beams would - # be reordered and the finished status of each entry might change. - # Otherwise, perform logical OR which would not change the already - # finished. - next_finished = paddle.logical_or(next_finished, global_finished) - next_sequence_lengths = paddle.add( - sequence_lengths, - tensor.cast( - paddle.logical_not(global_finished), - sequence_lengths.dtype, - ), - ) - if impute_finished: # rectify the states for the finished. - next_states = map_structure( - lambda x, y: _maybe_copy(x, y, global_finished), - states, - next_states, - ) - else: - warnings.warn( - "`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros." 
- ) if not hasattr(next_states, "lengths") else None - next_sequence_lengths = getattr( - next_states, "lengths", sequence_lengths - ) - - # create tensor array in global block after dtype[s] of outputs can be got - outputs_arrays = map_structure( - lambda x: _create_array_out_of_while(x.dtype), outputs - ) - - map_structure( - lambda x, x_array: paddle.tensor.array_write( - x, i=step_idx, array=x_array - ), - outputs, - outputs_arrays, - ) - - paddle.increment(x=step_idx, value=1.0) - # update the global_finished first, since it might be also in states of - # decoder, which otherwise would write a stale finished status to array - tensor.assign(next_finished, global_finished) - tensor.assign(next_sequence_lengths, sequence_lengths) - if is_test: - map_structure(tensor.assign, next_inputs, global_inputs) - map_structure(tensor.assign, next_states, global_states) - else: - map_structure( - lambda x, x_array: paddle.tensor.array_write( - x, i=step_idx, array=x_array - ), - next_inputs, - inputs_arrays, - ) - map_structure( - lambda x, x_array: paddle.tensor.array_write( - x, i=step_idx, array=x_array - ), - next_states, - states_arrays, - ) - if max_step_num is not None: - paddle.logical_and( - paddle.logical_not(paddle.all(global_finished)), - paddle.less_equal(step_idx, max_step_num), - cond, - ) - else: - paddle.logical_not(paddle.all(global_finished), cond) - - final_outputs = map_structure( - lambda array: tensor.tensor_array_to_tensor( - array, axis=0, use_stack=True - )[0], - outputs_arrays, - ) - if is_test: - final_states = global_states - else: - final_states = map_structure( - lambda array: paddle.tensor.array_read(array, step_idx), - states_arrays, - ) - - try: - final_outputs, final_states = decoder.finalize( - final_outputs, final_states, sequence_lengths - ) - except NotImplementedError: - pass - - if not output_time_major: - final_outputs = map_structure(_transpose_batch_time, final_outputs) - - return ( - (final_outputs, final_states, sequence_lengths) - if return_length - else (final_outputs, final_states) - ) - - -def dynamic_decode( - decoder, - inits=None, - max_step_num=None, - output_time_major=False, - impute_finished=False, - is_test=False, - return_length=False, - **kwargs -): - r""" - Dynamic decoding performs :code:`decoder.step()` repeatedly until the returned - Tensor indicating finished status contains all True values or the number of - decoding step reaches to :attr:`max_step_num`. - - :code:`decoder.initialize()` would be called once before the decoding loop. - If the `decoder` has implemented `finalize` method, :code:`decoder.finalize()` - would be called once after the decoding loop. - - Parameters: - decoder(Decoder): An instance of `Decoder`. - inits(object, optional): Argument passed to `decoder.initialize`. - Default `None`. - max_step_num(int, optional): The maximum number of steps. If not provided, - decode until the decoder is fully done, or in other words, the returned - Tensor by :code:`decoder.step()` indicating finished status contains - all True. Default `None`. - output_time_major(bool, optional): Indicate the data layout of Tensor included - in the final outputs(the first returned value of this method). If - attr:`False`, the data layout would be batch major with shape - `[batch_size, seq_len, ...]`. If attr:`True`, the data layout would - be time major with shape `[seq_len, batch_size, ...]`. Default: `False`. 
- impute_finished(bool, optional): If `True` and `decoder.tracks_own_finished` - is False, then states get copied through for batch entries which are - marked as finished, which differs with the unfinished using the new states - returned by :code:`decoder.step()` and ensures that the final states have - the correct values. Otherwise, states wouldn't be copied through when - finished. If the returned `final_states` is needed, it should be set as - True, which causes some slowdown. Default `False`. - is_test(bool, optional): A flag indicating whether to use test mode. In - test mode, it is more memory saving. Default `False`. - return_length(bool, optional): A flag indicating whether to return an - extra Tensor variable in the output tuple, which stores the actual - lengths of all decoded sequences. Default `False`. - **kwargs: Additional keyword arguments. Arguments passed to `decoder.step`. - - Returns: - - - final_outputs (Tensor, nested structure of Tensor), each Tensor in :code:`final_outputs` is the stacked of all decoding steps' outputs, which might be revised - by :code:`decoder.finalize()` if the decoder has implemented finalize. - And :code:`final_outputs` has the same structure and data types as the :code:`outputs` - returned by :code:`decoder.step()` - - - final_states (Tensor, nested structure of Tensor), :code:`final_states` is the counterpart at last time step of initial states \ - returned by :code:`decoder.initialize()` , thus has the same structure - with it and has tensors with same shapes and data types. - - - sequence_lengths (Tensor), stores the actual lengths of all decoded sequences. - sequence_lengths is provided only if :code:`return_length` is True. - - Examples: - - .. code-block:: python - - import paddle - from paddle.nn import BeamSearchDecoder, dynamic_decode - from paddle.nn import GRUCell, Linear, Embedding - trg_embeder = Embedding(100, 32) - output_layer = Linear(32, 32) - decoder_cell = GRUCell(input_size=32, hidden_size=32) - decoder = BeamSearchDecoder(decoder_cell, - start_token=0, - end_token=1, - beam_size=4, - embedding_fn=trg_embeder, - output_fn=output_layer) - encoder_output = paddle.ones((4, 8, 32), dtype=paddle.get_default_dtype()) - outputs = dynamic_decode(decoder=decoder, - inits=decoder_cell.get_initial_states(encoder_output), - max_step_num=10) - """ - if _non_static_mode(): - return _dynamic_decode_imperative( - decoder, - inits, - max_step_num, - output_time_major, - impute_finished, - is_test, - return_length, - **kwargs - ) - else: - return _dynamic_decode_declarative( - decoder, - inits, - max_step_num, - output_time_major, - impute_finished, - is_test, - return_length, - **kwargs - ) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 99b8b44eca..15ab8ba5f6 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -42,7 +42,6 @@ from paddle import _C_ops, _legacy_C_ops __all__ = [ 'cast', - 'tensor_array_to_tensor', 'concat', 'sums', 'assign', @@ -261,123 +260,6 @@ def concat(input, axis=0, name=None): return out -def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False): - r""" - This function concatenates or stacks all tensors in the input LoDTensorArray - along the axis mentioned and returns that as the output. - - For Example: - - .. 
code-block:: text - - Case 1: - - Given: - - input.data = {[[0.6, 0.1, 0.3], - [0.5, 0.3, 0.2]], - [[1.3], - [1.8]], - [[2.3, 2.1], - [2.5, 2.4]]} - - axis = 1, use_stack = False - - Then: - - output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1], - [0.5, 0.3, 0.2, 1.8, 2.5, 2.4]] - - output_index.data = [3, 1, 2] - - Case 2: - - Given: - - input.data = {[[0.6, 0.1], - [0.5, 0.3]], - [[0.3, 1.3], - [0.2, 1.8]], - [[2.3, 2.1], - [2.5, 2.4]]} - - axis = 1, use_stack = True - - Then: - - output.data = [[[0.6, 0.1] - [0.3, 1.3] - [2.3, 2.1], - [[0.5, 0.3] - [0.2, 1.8] - [2.5, 2.4]]] - - output_index.data = [2, 2, 2] - - Args: - input(Variable): A LodTensorArray variable. - axis(int): The axis along which the tensors in attr::`input` will be - concatenated or stacked. - name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. - use_stack(bool): Act as concat_op or stack_op. For stack mode, all - tensors in the tensor array must have the same shape. - - Returns: - Variable: The concatenated or stacked tensor variable. - Variable: A 1-D tensor variable with int32 data type. The data in this \ - tensor contains all input including tensors' sizes along the axis. - - Examples: - .. code-block:: python - - import paddle - import paddle.fluid as fluid - import numpy as np - x0 = fluid.layers.assign(np.random.rand(2, 2).astype("float32")) - x1 = fluid.layers.assign(np.random.rand(2, 2).astype("float32")) - i = fluid.layers.fill_constant(shape=[1], dtype="int64", value=0) - array = paddle.tensor.create_array(dtype='float32') - paddle.tensor.array_write(x0, i, array) - paddle.tensor.array_write(x1, i + 1, array) - output, output_index = fluid.layers.tensor_array_to_tensor(input=array) - """ - if _non_static_mode(): - assert isinstance( - input, list - ), "The 'input' in tensor_array_to_tensor must be list" - from .nn import concat - from ..dygraph import to_variable - from paddle import stack - - op = stack if use_stack else concat - res = op(input, axis=axis) - sizes = to_variable( - numpy.array(list(map(lambda x: int(x.shape[axis]), input))) - ) - return res, sizes - - check_type(input, 'input', (list, Variable), 'tensor_array_to_tensor') - if isinstance(input, list): - for i, input_x in enumerate(input): - check_type( - input_x, - 'input[' + str(i) + ']', - Variable, - 'tensor_array_to_tensor', - ) - helper = LayerHelper('tensor_array_to_tensor', **locals()) - out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) - out_index = helper.create_variable_for_type_inference(dtype="int32") - helper.append_op( - type='tensor_array_to_tensor', - inputs={'X': input}, - outputs={'Out': [out], 'OutIndex': [out_index]}, - attrs={'axis': axis, 'use_stack': use_stack}, - ) - return out, out_index - - def sums(input, out=None): r""" This function computes the sum of multiple input Tensors elementwisely. 
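
With the fluid implementation removed above, `tensor_array_to_tensor` now lives in `paddle.tensor.manipulation` (added at the end of this patch), and the test updates below switch to that import. A minimal sketch of the replacement call in dynamic graph mode follows; the tensor names `x0`/`x1` are illustrative and not taken from the patch:

.. code-block:: python

    import paddle
    from paddle.tensor.manipulation import tensor_array_to_tensor

    # In dynamic graph mode the helper accepts a plain Python list of Tensors.
    x0 = paddle.rand([2, 3])
    x1 = paddle.rand([2, 3])
    # use_stack=True stacks along `axis`; use_stack=False concatenates instead.
    out, index = tensor_array_to_tensor([x0, x1], axis=0, use_stack=True)
    # out.shape == [2, 2, 3]; index holds each input's size along axis 0.
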
diff --git a/python/paddle/fluid/tests/unittests/test_dynamic_rnn_stop_gradient.py b/python/paddle/fluid/tests/unittests/test_dynamic_rnn_stop_gradient.py index 2651a20dd4..0abc9317b6 100644 --- a/python/paddle/fluid/tests/unittests/test_dynamic_rnn_stop_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_dynamic_rnn_stop_gradient.py @@ -19,6 +19,7 @@ import numpy as np import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers +from paddle.tensor.manipulation import tensor_array_to_tensor paddle.enable_static() @@ -58,7 +59,7 @@ def build_and_run_program(place, batch_size, beam_size, stop_gradient=False): length_cond = paddle.less_than(x=step_idx, y=max_len) layers.assign(length_cond, cond) - out = layers.tensor_array_to_tensor(scores, axis=0, use_stack=True)[0] + out = tensor_array_to_tensor(scores, axis=0, use_stack=True)[0] loss = paddle.mean(out) opt = fluid.optimizer.Adam(0.01) opt.minimize(loss) diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index 91b3adcb92..e89542422b 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import random import unittest @@ -22,9 +23,19 @@ import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.nn as nn from paddle import Model, set_device -from paddle.fluid.dygraph import Layer +from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.framework import _test_eager_guard -from paddle.nn import BeamSearchDecoder, dynamic_decode +from paddle.fluid.layers.utils import map_structure +from paddle.nn import ( + RNN, + BeamSearchDecoder, + Embedding, + Layer, + Linear, + LSTMCell, + SimpleRNNCell, + dynamic_decode, +) from paddle.static import InputSpec as Input paddle.enable_static() @@ -351,7 +362,7 @@ class TestBeamSearch(ModuleApiTest): beam_size=4, max_step_num=20, ): - embedder = paddle.nn.Embedding(vocab_size, embed_dim) + embedder = Embedding(vocab_size, embed_dim) output_layer = nn.Linear(hidden_size, vocab_size) cell = nn.LSTMCell(embed_dim, hidden_size) self.max_step_num = max_step_num @@ -392,5 +403,312 @@ class TestBeamSearch(ModuleApiTest): self.func_check_output() +class EncoderCell(SimpleRNNCell): + def __init__( + self, + num_layers, + input_size, + hidden_size, + dropout_prob=0.0, + init_scale=0.1, + ): + super(EncoderCell, self).__init__(input_size, hidden_size) + self.dropout_prob = dropout_prob + # use add_sublayer to add multi-layers + self.lstm_cells = [] + for i in range(num_layers): + self.lstm_cells.append( + self.add_sublayer( + "lstm_%d" % i, + LSTMCell( + input_size=input_size if i == 0 else hidden_size, + hidden_size=hidden_size, + ), + ) + ) + + def forward(self, step_input, states): + new_states = [] + for i, lstm_cell in enumerate(self.lstm_cells): + out, new_state = lstm_cell(step_input, states[i]) + step_input = ( + layers.dropout( + out, + self.dropout_prob, + dropout_implementation='upscale_in_train', + ) + if self.dropout_prob > 0 + else out + ) + new_states.append(new_state) + return step_input, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.lstm_cells] + + +class Encoder(Layer): + def __init__( + self, + vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob=0.0, + init_scale=0.1, + ): + 
super(Encoder, self).__init__() + self.embedder = Embedding(vocab_size, embed_dim) + self.stack_lstm = RNN( + EncoderCell( + num_layers, embed_dim, hidden_size, dropout_prob, init_scale + ), + is_reverse=False, + time_major=False, + ) + + def forward(self, sequence, sequence_length): + inputs = self.embedder(sequence) + encoder_output, encoder_state = self.stack_lstm( + inputs, sequence_length=sequence_length + ) + return encoder_output, encoder_state + + +DecoderCell = EncoderCell + + +class Decoder(Layer): + def __init__( + self, + vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob=0.0, + init_scale=0.1, + ): + super(Decoder, self).__init__() + self.embedder = Embedding(vocab_size, embed_dim) + self.stack_lstm = RNN( + DecoderCell( + num_layers, embed_dim, hidden_size, dropout_prob, init_scale + ), + is_reverse=False, + time_major=False, + ) + self.output_layer = Linear(hidden_size, vocab_size, bias_attr=False) + + def forward(self, target, decoder_initial_states): + inputs = self.embedder(target) + decoder_output, _ = self.stack_lstm( + inputs, initial_states=decoder_initial_states + ) + predict = self.output_layer(decoder_output) + return predict + + +class TrainingHelper: + def __init__(self, inputs, sequence_length, time_major=False): + self.inputs = inputs + self.sequence_length = sequence_length + self.time_major = time_major + self.inputs_ = map_structure( + lambda x: paddle.nn.functional.pad( + x, + pad=([0, 1] + [0, 0] * (len(x.shape) - 1)) + if time_major + else ([0, 0, 0, 1] + [0, 0] * (len(x.shape) - 2)), + ), + self.inputs, + ) + + def initialize(self): + init_finished = paddle.equal( + self.sequence_length, + paddle.full( + shape=[1], dtype=self.sequence_length.dtype, fill_value=0 + ), + ) + init_inputs = map_structure( + lambda x: x[0] if self.time_major else x[:, 0], self.inputs + ) + return init_inputs, init_finished + + def sample(self, time, outputs, states): + sample_ids = paddle.argmax(outputs, axis=-1) + return sample_ids + + def next_inputs(self, time, outputs, states, sample_ids): + time = ( + paddle.cast(time, "int32") + if convert_dtype(time.dtype) not in ["int32"] + else time + ) + if self.sequence_length.dtype != time.dtype: + self.sequence_length = paddle.cast(self.sequence_length, time.dtype) + next_time = time + 1 + finished = paddle.less_equal(self.sequence_length, next_time) + + def _slice(x): + axes = [0 if self.time_major else 1] + return paddle.squeeze( + paddle.slice( + x, axes=axes, starts=[next_time], ends=[next_time + 1] + ), + axis=axes, + ) + + next_inputs = map_structure(_slice, self.inputs_) + return finished, next_inputs, states + + +class BasicDecoder(paddle.nn.decode.Decoder): + def __init__(self, cell, helper, output_fn=None): + super().__init__() + self.cell = cell + self.helper = helper + self.output_fn = output_fn + + def initialize(self, initial_cell_states): + (initial_inputs, initial_finished) = self.helper.initialize() + return initial_inputs, initial_cell_states, initial_finished + + class OutputWrapper( + collections.namedtuple("OutputWrapper", ("cell_outputs", "sample_ids")) + ): + pass + + def step(self, time, inputs, states, **kwargs): + cell_outputs, cell_states = self.cell(inputs, states, **kwargs) + if self.output_fn is not None: + cell_outputs = self.output_fn(cell_outputs) + sample_ids = self.helper.sample( + time=time, outputs=cell_outputs, states=cell_states + ) + sample_ids.stop_gradient = True + (finished, next_inputs, next_states) = self.helper.next_inputs( + time=time, + outputs=cell_outputs, + 
states=cell_states, + sample_ids=sample_ids, + ) + outputs = self.OutputWrapper(cell_outputs, sample_ids) + return (outputs, next_states, next_inputs, finished) + + +class BaseModel(Layer): + def __init__( + self, + vocab_size=10, + embed_dim=32, + hidden_size=32, + num_layers=1, + dropout_prob=0.0, + init_scale=0.1, + ): + super(BaseModel, self).__init__() + self.hidden_size = hidden_size + self.word_embedding = Embedding(vocab_size, embed_dim) + self.encoder = Encoder( + vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob, + init_scale, + ) + self.decoder = Decoder( + vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob, + init_scale, + ) + + def forward(self, src, src_length, trg, trg_length): + encoder_output = self.encoder(src, src_length) + trg_emb = self.decoder.embedder(trg) + helper = TrainingHelper(inputs=trg_emb, sequence_length=trg_length) + decoder = BasicDecoder(self.decoder.stack_lstm.cell, helper) + ( + decoder_output, + decoder_final_state, + dec_seq_lengths, + ) = dynamic_decode( + decoder, + inits=self.decoder.stack_lstm.cell.get_initial_states( + encoder_output + ), + impute_finished=True, + is_test=False, + return_length=True, + ) + logits, samples, sample_length = ( + decoder_output.cell_outputs, + decoder_output.sample_ids, + dec_seq_lengths, + ) + return logits + + +class TestDynamicDecode(ModuleApiTest): + def setUp(self): + paddle.set_default_dtype("float64") + shape = (1, 10) + bs_shape = 1 + self.inputs = [ + np.random.randint(0, 10, size=shape).astype("int64"), + np.random.randint(0, 10, size=bs_shape).astype("int64"), + np.random.randint(0, 10, size=shape).astype("int64"), + np.random.randint(0, 10, size=bs_shape).astype("int64"), + ] + self.outputs = None + self.attrs = { + "vocab_size": 10, + "embed_dim": 32, + "hidden_size": 32, + } + self.param_states = {} + + @staticmethod + def model_init( + self, + vocab_size, + embed_dim, + hidden_size, + bos_id=0, + eos_id=1, + ): + self.model = BaseModel( + vocab_size=vocab_size, embed_dim=embed_dim, hidden_size=hidden_size + ) + + @staticmethod + def model_forward(model, src, src_length, trg, trg_length): + return model.model(src, src_length, trg, trg_length) + + def make_inputs(self): + inputs = [ + Input([None, None], "int64", "src"), + Input([None], "int64", "src_length"), + Input([None, None], "int64", "trg"), + Input([None], "int64", "trg_length"), + ] + return inputs + + def func_check_output(self): + self.setUp() + self.make_inputs() + self.check_output() + + def test_check_output(self): + with _test_eager_guard(): + self.func_check_output() + self.func_check_output() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 8f122d23ad..af76b09047 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -23,6 +23,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.layers as layers +from paddle.tensor.manipulation import tensor_array_to_tensor paddle.enable_static() @@ -703,7 +704,7 @@ class TestSliceApiWithLoDTensorArray(unittest.TestCase): paddle.tensor.array_length(arr) - 1 ) # dtype of end is int64 self.sliced_arr = slice_arr = arr[self.start : end] - output, _ = fluid.layers.tensor_array_to_tensor( + output, _ = tensor_array_to_tensor( slice_arr, axis=self.axis, use_stack=True ) elif case_num == 3: @@ -711,7 +712,7 @@ class 
TestSliceApiWithLoDTensorArray(unittest.TestCase): [1], "int64", 2147483648 ) self.sliced_arr = slice_arr = arr[self.start : value_int64] - output, _ = fluid.layers.tensor_array_to_tensor( + output, _ = tensor_array_to_tensor( slice_arr, axis=self.axis, use_stack=True ) diff --git a/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py index 4cc6d05b38..aee4c82ccc 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py @@ -20,6 +20,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid import Program, program_guard +from paddle.tensor.manipulation import tensor_array_to_tensor paddle.enable_static() @@ -32,12 +33,12 @@ class TestTensorArrayToTensorError(unittest.TestCase): input_data = np.random.random((2, 4)).astype("float32") def test_Variable(): - fluid.layers.tensor_array_to_tensor(input=input_data) + tensor_array_to_tensor(input=input_data) self.assertRaises(TypeError, test_Variable) def test_list_Variable(): - fluid.layers.tensor_array_to_tensor(input=[input_data]) + tensor_array_to_tensor(input=[input_data]) self.assertRaises(TypeError, test_list_Variable) @@ -198,7 +199,7 @@ class TestLoDTensorArrayStack(unittest.TestCase): for i, x in enumerate(self.inputs): x = fluid.layers.assign(x) paddle.tensor.array_write(x, idx + i, array) - output, output_index = fluid.layers.tensor_array_to_tensor( + output, output_index = tensor_array_to_tensor( input=array, **self.attrs ) loss = paddle.sum(output) @@ -241,15 +242,13 @@ class TestTensorArrayToTensorAPI(unittest.TestCase): array = paddle.tensor.create_array(dtype='float32') paddle.tensor.array_write(x0, i, array) paddle.tensor.array_write(x1, i + 1, array) - output_stack, output_index_stack = fluid.layers.tensor_array_to_tensor( + output_stack, output_index_stack = tensor_array_to_tensor( input=array, axis=1, use_stack=True ) ( output_concat, output_index_concat, - ) = fluid.layers.tensor_array_to_tensor( - input=array, axis=1, use_stack=False - ) + ) = tensor_array_to_tensor(input=array, axis=1, use_stack=False) return ( output_stack, output_index_stack, diff --git a/python/paddle/nn/decode.py b/python/paddle/nn/decode.py index b5eedf767e..c2e652b63d 100644 --- a/python/paddle/nn/decode.py +++ b/python/paddle/nn/decode.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. + import collections +import warnings import numpy as np import paddle +from paddle.framework import _non_static_mode +from paddle.static import default_main_program -from ..fluid.layers import dynamic_decode # noqa: F401 +from ..fluid.data_feeder import convert_dtype from ..fluid.layers.utils import flatten, map_structure __all__ = [] +class ArrayWrapper: + def __init__(self, x): + self.array = [x] + + def append(self, x): + self.array.append(x) + return self + + def __getitem__(self, item): + return self.array.__getitem__(item) + + class Decoder: """ Decoder is the base class for any decoder instance used in `dynamic_decode`. 
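
The hunk below appends the migrated `dynamic_decode` to this module. For context, a minimal decoder sketch is shown here; `GreedyEmbeddingDecoder` is hypothetical and not part of this patch, it only illustrates the `initialize()`/`step()` contract that `dynamic_decode` drives (compare the `BasicDecoder` used in the updated tests above):

.. code-block:: python

    import paddle
    from paddle.nn import Embedding, Linear, LSTMCell, dynamic_decode

    class GreedyEmbeddingDecoder(paddle.nn.decode.Decoder):
        # Illustrative greedy decoder: argmax sampling, stops on end_token.
        def __init__(self, cell, embedding, output_layer, start_token, end_token):
            super().__init__()
            self.cell = cell
            self.embedding = embedding
            self.output_layer = output_layer
            self.start_token = start_token
            self.end_token = end_token

        def initialize(self, initial_cell_states):
            # Returns (initial_inputs, initial_states, initial_finished).
            batch_size = initial_cell_states[0].shape[0]
            start_tokens = paddle.full([batch_size], self.start_token, dtype="int64")
            init_finished = paddle.full([batch_size], False, dtype="bool")
            return self.embedding(start_tokens), initial_cell_states, init_finished

        def step(self, time, inputs, states, **kwargs):
            # Returns (outputs, next_states, next_inputs, finished).
            cell_outputs, next_states = self.cell(inputs, states)
            logits = self.output_layer(cell_outputs)
            sample_ids = paddle.argmax(logits, axis=-1)
            finished = sample_ids == self.end_token
            return logits, next_states, self.embedding(sample_ids), finished

    cell = LSTMCell(input_size=32, hidden_size=32)
    decoder = GreedyEmbeddingDecoder(
        cell, Embedding(100, 32), Linear(32, 100), start_token=0, end_token=1
    )
    encoder_output = paddle.ones([4, 8, 32], dtype=paddle.get_default_dtype())
    outputs, final_states = dynamic_decode(
        decoder, inits=cell.get_initial_states(encoder_output), max_step_num=10
    )
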
@@ -650,3 +666,425 @@ class BeamSearchDecoder(Decoder): bool: A python bool `True`. """ return True + + +def _dynamic_decode_imperative( + decoder, + inits=None, + max_step_num=None, + output_time_major=False, + impute_finished=False, + is_test=False, + return_length=False, + **kwargs +): + def _maybe_copy(state, new_state, step_mask): + # TODO: use where_op + state_dtype = state.dtype + if convert_dtype(state_dtype) in ["bool"]: + state = paddle.cast(state, dtype="float32") + new_state = paddle.cast(new_state, dtype="float32") + if step_mask.dtype != state.dtype: + step_mask = paddle.cast(step_mask, dtype=state.dtype) + # otherwise, renamed bool gradients of would be summed up leading + # to sum(bool) error. + step_mask = step_mask.unsqueeze([1]) + step_mask.stop_gradient = True + new_state = paddle.multiply(state, step_mask) - paddle.multiply( + new_state, (step_mask - 1) + ) + if convert_dtype(state_dtype) in ["bool"]: + new_state = paddle.cast(new_state, dtype=state_dtype) + return new_state + + initial_inputs, initial_states, initial_finished = decoder.initialize(inits) + inputs, states, finished = ( + initial_inputs, + initial_states, + initial_finished, + ) + cond = paddle.logical_not((paddle.all(initial_finished))) + sequence_lengths = paddle.cast(paddle.zeros_like(initial_finished), "int64") + outputs = None + + step_idx = 0 + step_idx_tensor = paddle.full(shape=[1], fill_value=step_idx, dtype="int64") + while cond.numpy(): + (step_outputs, next_states, next_inputs, next_finished) = decoder.step( + step_idx_tensor, inputs, states, **kwargs + ) + if not decoder.tracks_own_finished: + # BeamSearchDecoder would track it own finished, since + # beams would be reordered and the finished status of each + # entry might change. Otherwise, perform logical OR which + # would not change the already finished. + next_finished = paddle.logical_or(next_finished, finished) + # To confirm states.finished/finished be consistent with + # next_finished. + paddle.assign(next_finished, finished) + next_sequence_lengths = paddle.add( + sequence_lengths, + paddle.cast( + paddle.logical_not(finished), sequence_lengths.dtype + ), + ) + if impute_finished: # rectify the states for the finished. + next_states = map_structure( + lambda x, y: _maybe_copy(x, y, finished), + states, + next_states, + ) + else: + warnings.warn( + "`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros." 
+ ) if not hasattr(next_states, "lengths") else None + next_sequence_lengths = getattr( + next_states, "lengths", sequence_lengths + ) + + outputs = ( + map_structure(lambda x: ArrayWrapper(x), step_outputs) + if step_idx == 0 + else map_structure( + lambda x, x_array: x_array.append(x), step_outputs, outputs + ) + ) + inputs, states, finished, sequence_lengths = ( + next_inputs, + next_states, + next_finished, + next_sequence_lengths, + ) + + step_idx_tensor = paddle.increment(x=step_idx_tensor, value=1.0) + step_idx += 1 + + cond = paddle.logical_not(paddle.all(finished)) + if max_step_num is not None and step_idx > max_step_num: + break + + final_outputs = map_structure( + lambda x: paddle.stack(x.array, axis=0), outputs + ) + final_states = states + + try: + final_outputs, final_states = decoder.finalize( + final_outputs, final_states, sequence_lengths + ) + except NotImplementedError: + pass + + if not output_time_major: + final_outputs = map_structure( + lambda x: paddle.transpose( + x, [1, 0] + list(range(2, len(x.shape))) + ), + final_outputs, + ) + + return ( + (final_outputs, final_states, sequence_lengths) + if return_length + else (final_outputs, final_states) + ) + + +def _dynamic_decode_declarative( + decoder, + inits=None, + max_step_num=None, + output_time_major=False, + impute_finished=False, + is_test=False, + return_length=False, + **kwargs +): + initial_inputs, initial_states, initial_finished = decoder.initialize(inits) + global_inputs, global_states, global_finished = ( + initial_inputs, + initial_states, + initial_finished, + ) + global_finished.stop_gradient = True + step_idx = paddle.full(shape=[1], fill_value=0, dtype="int64") + + cond = paddle.logical_not((paddle.all(initial_finished))) + if max_step_num is not None: + max_step_num = paddle.full( + shape=[1], fill_value=max_step_num, dtype="int64" + ) + + while_op = paddle.static.nn.control_flow.While(cond, is_test=is_test) + + sequence_lengths = paddle.cast(paddle.zeros_like(initial_finished), "int64") + sequence_lengths.stop_gradient = True + + if is_test: + # for test, reuse inputs and states variables to save memory + inputs = map_structure(lambda x: x, initial_inputs) + states = map_structure(lambda x: x, initial_states) + else: + # inputs and states of all steps must be saved for backward and training + inputs_arrays = map_structure( + lambda x: paddle.tensor.array.array_write(x, step_idx), + initial_inputs, + ) + states_arrays = map_structure( + lambda x: paddle.tensor.array.array_write(x, step_idx), + initial_states, + ) + + def _maybe_copy(state, new_state, step_mask): + # TODO: use where_op + state_dtype = state.dtype + if convert_dtype(state_dtype) in ["bool"]: + state = paddle.cast(state, dtype="float32") + new_state = paddle.cast(new_state, dtype="float32") + if step_mask.dtype != state.dtype: + step_mask = paddle.cast(step_mask, dtype=state.dtype) + # otherwise, renamed bool gradients of would be summed up leading + # to sum(bool) error. 
+ step_mask = step_mask.unsqueeze([1]) + step_mask.stop_gradient = True + new_state = paddle.multiply(state, step_mask) - paddle.multiply( + new_state, (step_mask - 1) + ) + if convert_dtype(state_dtype) in ["bool"]: + new_state = paddle.cast(new_state, dtype=state_dtype) + return new_state + + def _transpose_batch_time(x): + return paddle.transpose(x, [1, 0] + list(range(2, len(x.shape)))) + + def _create_array_out_of_while(dtype): + current_block_idx = default_main_program().current_block_idx + default_main_program().current_block_idx = ( + default_main_program().current_block().parent_idx + ) + tensor_array = paddle.tensor.array.create_array(dtype) + default_main_program().current_block_idx = current_block_idx + return tensor_array + + # While + with while_op.block(): + if not is_test: + inputs = map_structure( + lambda array: paddle.tensor.array.array_read(array, step_idx), + inputs_arrays, + ) + states = map_structure( + lambda array: paddle.tensor.array.array_read(array, step_idx), + states_arrays, + ) + (outputs, next_states, next_inputs, next_finished) = decoder.step( + step_idx, inputs, states, **kwargs + ) + if not decoder.tracks_own_finished: + # BeamSearchDecoder would track it own finished, since beams would + # be reordered and the finished status of each entry might change. + # Otherwise, perform logical OR which would not change the already + # finished. + next_finished = paddle.logical_or(next_finished, global_finished) + next_sequence_lengths = paddle.add( + sequence_lengths, + paddle.cast( + paddle.logical_not(global_finished), + sequence_lengths.dtype, + ), + ) + if impute_finished: # rectify the states for the finished. + next_states = map_structure( + lambda x, y: _maybe_copy(x, y, global_finished), + states, + next_states, + ) + else: + warnings.warn( + "`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros." 
+ ) if not hasattr(next_states, "lengths") else None + next_sequence_lengths = getattr( + next_states, "lengths", sequence_lengths + ) + + # create tensor array in global block after dtype[s] of outputs can be got + outputs_arrays = map_structure( + lambda x: _create_array_out_of_while(x.dtype), outputs + ) + + map_structure( + lambda x, x_array: paddle.tensor.array.array_write( + x, i=step_idx, array=x_array + ), + outputs, + outputs_arrays, + ) + step_idx = paddle.increment(x=step_idx, value=1.0) + # update the global_finished first, since it might be also in states of + # decoder, which otherwise would write a stale finished status to array + paddle.assign(next_finished, global_finished) + paddle.assign(next_sequence_lengths, sequence_lengths) + if is_test: + map_structure(paddle.assign, next_inputs, global_inputs) + map_structure(paddle.assign, next_states, global_states) + else: + map_structure( + lambda x, x_array: paddle.tensor.array.array_write( + x, i=step_idx, array=x_array + ), + next_inputs, + inputs_arrays, + ) + map_structure( + lambda x, x_array: paddle.tensor.array.array_write( + x, i=step_idx, array=x_array + ), + next_states, + states_arrays, + ) + if max_step_num is not None: + paddle.logical_and( + paddle.logical_not(paddle.all(global_finished)), + paddle.less_equal(step_idx, max_step_num), + cond, + ) + else: + paddle.logical_not(paddle.all(global_finished), cond) + + final_outputs = map_structure( + lambda array: paddle.tensor.manipulation.tensor_array_to_tensor( + array, axis=0, use_stack=True + )[0], + outputs_arrays, + ) + if is_test: + final_states = global_states + else: + final_states = map_structure( + lambda array: paddle.tensor.array.array_read(array, step_idx), + states_arrays, + ) + + try: + final_outputs, final_states = decoder.finalize( + final_outputs, final_states, sequence_lengths + ) + except NotImplementedError: + pass + + if not output_time_major: + final_outputs = map_structure(_transpose_batch_time, final_outputs) + + return ( + (final_outputs, final_states, sequence_lengths) + if return_length + else (final_outputs, final_states) + ) + + +def dynamic_decode( + decoder, + inits=None, + max_step_num=None, + output_time_major=False, + impute_finished=False, + is_test=False, + return_length=False, + **kwargs +): + r""" + Dynamic decoding performs :code:`decoder.step()` repeatedly until the returned + Tensor indicating finished status contains all True values or the number of + decoding step reaches to :attr:`max_step_num`. + + :code:`decoder.initialize()` would be called once before the decoding loop. + If the `decoder` has implemented `finalize` method, :code:`decoder.finalize()` + would be called once after the decoding loop. + + Parameters: + decoder(Decoder): An instance of `Decoder`. + inits(object, optional): Argument passed to `decoder.initialize`. + Default `None`. + max_step_num(int, optional): The maximum number of steps. If not provided, + decode until the decoder is fully done, or in other words, the returned + Tensor by :code:`decoder.step()` indicating finished status contains + all True. Default `None`. + output_time_major(bool, optional): Indicate the data layout of Tensor included + in the final outputs(the first returned value of this method). If + attr:`False`, the data layout would be batch major with shape + `[batch_size, seq_len, ...]`. If attr:`True`, the data layout would + be time major with shape `[seq_len, batch_size, ...]`. Default: `False`. 
+ impute_finished(bool, optional): If `True` and `decoder.tracks_own_finished` + is False, then states get copied through for batch entries which are + marked as finished, which differs with the unfinished using the new states + returned by :code:`decoder.step()` and ensures that the final states have + the correct values. Otherwise, states wouldn't be copied through when + finished. If the returned `final_states` is needed, it should be set as + True, which causes some slowdown. Default `False`. + is_test(bool, optional): A flag indicating whether to use test mode. In + test mode, it is more memory saving. Default `False`. + return_length(bool, optional): A flag indicating whether to return an + extra Tensor variable in the output tuple, which stores the actual + lengths of all decoded sequences. Default `False`. + **kwargs: Additional keyword arguments. Arguments passed to `decoder.step`. + + Returns: + tuple: A tuple( :code:`(final_outputs, final_states, sequence_lengths)` ) \ + when `return_length` is True, otherwise a tuple( :code:`(final_outputs, final_states)` ). \ + The final outputs and states, both are Tensor or nested structure of Tensor. \ + `final_outputs` has the same structure and data types as the :code:`outputs` \ + returned by :code:`decoder.step()` , and each Tenser in `final_outputs` \ + is the stacked of all decoding steps' outputs, which might be revised \ + by :code:`decoder.finalize()` if the decoder has implemented `finalize`. \ + `final_states` is the counterpart at last time step of initial states \ + returned by :code:`decoder.initialize()` , thus has the same structure \ + with it and has tensors with same shapes and data types. `sequence_lengths` \ + is an `int64` tensor with the same shape as `finished` returned \ + by :code:`decoder.initialize()` , and it stores the actual lengths of \ + all decoded sequences. + + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + from paddle.nn import BeamSearchDecoder, dynamic_decode + from paddle.nn import GRUCell, Linear, Embedding + trg_embeder = Embedding(100, 32) + output_layer = Linear(32, 32) + decoder_cell = GRUCell(input_size=32, hidden_size=32) + decoder = BeamSearchDecoder(decoder_cell, + start_token=0, + end_token=1, + beam_size=4, + embedding_fn=trg_embeder, + output_fn=output_layer) + encoder_output = paddle.ones((4, 8, 32), dtype=paddle.get_default_dtype()) + outputs = dynamic_decode(decoder=decoder, + inits=decoder_cell.get_initial_states(encoder_output), + max_step_num=10) + """ + if _non_static_mode(): + return _dynamic_decode_imperative( + decoder, + inits, + max_step_num, + output_time_major, + impute_finished, + is_test, + return_length, + **kwargs + ) + else: + return _dynamic_decode_declarative( + decoder, + inits, + max_step_num, + output_time_major, + impute_finished, + is_test, + return_length, + **kwargs + ) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index c255a22368..7bb9a675d7 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -44,6 +44,120 @@ from .creation import _complex_to_real_dtype, _real_to_complex_dtype, zeros __all__ = [] +def tensor_array_to_tensor(input, axis=1, use_stack=False, name=None): + r""" + This function concatenates or stacks all tensors in the input LoDTensorArray + along the axis mentioned and returns that as the output. + + For Example: + + .. 
code-block:: text + + Case 1: + + Given: + + input.data = {[[0.6, 0.1, 0.3], + [0.5, 0.3, 0.2]], + [[1.3], + [1.8]], + [[2.3, 2.1], + [2.5, 2.4]]} + + axis = 1, use_stack = False + + Then: + + output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1], + [0.5, 0.3, 0.2, 1.8, 2.5, 2.4]] + + output_index.data = [3, 1, 2] + + Case 2: + + Given: + + input.data = {[[0.6, 0.1], + [0.5, 0.3]], + [[0.3, 1.3], + [0.2, 1.8]], + [[2.3, 2.1], + [2.5, 2.4]]} + + axis = 1, use_stack = True + + Then: + + output.data = [[[0.6, 0.1] + [0.3, 1.3] + [2.3, 2.1], + [[0.5, 0.3] + [0.2, 1.8] + [2.5, 2.4]]] + + output_index.data = [2, 2, 2] + + Args: + input(TensorArray): A TensorArray variable. + axis(int): The axis along which the tensors in attr::`input` will be + concatenated or stacked. + use_stack(bool): Act as concat_op or stack_op. For stack mode, all + tensors in the tensor array must have the same shape. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Tensor: The concatenated or stacked tensor variable. + Tensor: A 1-D tensor variable with int32 data type. The data in this \ + tensor contains all input including tensors' sizes along the axis. + + Examples: + .. code-block:: python + + import numpy + import paddle + x0 = paddle.assign(numpy.random.rand(2, 2).astype("float32")) + x1 = paddle.assign(numpy.random.rand(2, 2).astype("float32")) + i = paddle.full(shape=[1], dtype="int64", fill_value=0) + array = paddle.tensor.array.create_array(dtype='float32') + paddle.tensor.array.array_write(x0, i, array) + paddle.tensor.array.array_write(x1, i + 1, array) + output, output_index = paddle.tensor.manipulation.tensor_array_to_tensor(input=array) + """ + if _non_static_mode(): + assert isinstance( + input, list + ), "The 'input' in tensor_array_to_tensor must be list" + from paddle import concat, stack + + op = stack if use_stack else concat + res = op(input, axis=axis) + sizes = paddle.to_tensor( + np.array(list(map(lambda x: int(x.shape[axis]), input))) + ) + return res, sizes + + check_type(input, 'input', (list, Variable), 'tensor_array_to_tensor') + if isinstance(input, list): + for i, input_x in enumerate(input): + check_type( + input_x, + 'input[' + str(i) + ']', + Variable, + 'tensor_array_to_tensor', + ) + helper = LayerHelper('tensor_array_to_tensor', **locals()) + out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + out_index = helper.create_variable_for_type_inference(dtype="int32") + helper.append_op( + type='tensor_array_to_tensor', + inputs={'X': input}, + outputs={'Out': [out], 'OutIndex': [out_index]}, + attrs={'axis': axis, 'use_stack': use_stack}, + ) + return out, out_index + + def cast(x, dtype): """ -- GitLab
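
After this change the two helpers are imported from their new homes: `paddle.nn.dynamic_decode` (together with `paddle.nn.BeamSearchDecoder`) and `paddle.tensor.manipulation.tensor_array_to_tensor`. A quick end-to-end sketch of the former, essentially the docstring example added to `python/paddle/nn/decode.py` above:

.. code-block:: python

    import paddle
    from paddle.nn import BeamSearchDecoder, dynamic_decode
    from paddle.nn import GRUCell, Linear, Embedding

    # Same toy setup as the dynamic_decode docstring in this patch.
    trg_embeder = Embedding(100, 32)
    output_layer = Linear(32, 32)
    decoder_cell = GRUCell(input_size=32, hidden_size=32)
    decoder = BeamSearchDecoder(decoder_cell,
                                start_token=0,
                                end_token=1,
                                beam_size=4,
                                embedding_fn=trg_embeder,
                                output_fn=output_layer)
    encoder_output = paddle.ones((4, 8, 32), dtype=paddle.get_default_dtype())
    # Returns (final_outputs, final_states); pass return_length=True to also
    # get the decoded sequence lengths.
    outputs, final_states = dynamic_decode(
        decoder=decoder,
        inits=decoder_cell.get_initial_states(encoder_output),
        max_step_num=10)
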