# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from functools import partial, reduce
import warnings


import paddle
from paddle.utils import deprecated
from . import nn
from . import tensor
from . import control_flow
from . import utils
from . import sequence_lod
from .utils import *
from .. import core
from ..framework import default_main_program
from ..data_feeder import convert_dtype
from ..layer_helper import LayerHelper
from ..framework import _non_static_mode
from ..param_attr import ParamAttr
from ..data_feeder import check_variable_and_dtype, check_type, check_dtype

from collections.abc import Sequence

__all__ = [
    'dynamic_decode',
]


class ArrayWrapper:
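    """Collects per-step decoder outputs in a plain Python list so that they
    can later be stacked into a single Tensor along the time axis."""
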
    def __init__(self, x):
        self.array = [x]

    def append(self, x):
        self.array.append(x)
        return self

    def __getitem__(self, item):
        return self.array.__getitem__(item)


def _dynamic_decode_imperative(
    decoder,
    inits=None,
    max_step_num=None,
    output_time_major=False,
    impute_finished=False,
    is_test=False,
    return_length=False,
    **kwargs
):
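    # Dynamic-graph (imperative) implementation: runs the decoding loop as a
    # plain Python while loop, collecting per-step outputs in ArrayWrapper
    # objects and stacking them once decoding finishes.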
    def _maybe_copy(state, new_state, step_mask):
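        # Keep the old state for entries whose step_mask is 1 (already finished)
        # and take the new state otherwise, emulating a select/where operation.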
        # TODO: use where_op
        state_dtype = state.dtype
        if convert_dtype(state_dtype) in ["bool"]:
            state = tensor.cast(state, dtype="float32")
            new_state = tensor.cast(new_state, dtype="float32")
        if step_mask.dtype != state.dtype:
            step_mask = tensor.cast(step_mask, dtype=state.dtype)
            # otherwise the renamed bool gradients would be summed up, leading
            # to a sum(bool) error.
            step_mask.stop_gradient = True
        new_state = paddle.tensor.math._multiply_with_axis(
            state, step_mask, axis=0
        ) - paddle.tensor.math._multiply_with_axis(
            new_state, (step_mask - 1), axis=0
        )
        if convert_dtype(state_dtype) in ["bool"]:
            new_state = tensor.cast(new_state, dtype=state_dtype)
        return new_state

    initial_inputs, initial_states, initial_finished = decoder.initialize(inits)
    inputs, states, finished = (
        initial_inputs,
        initial_states,
        initial_finished,
    )
    cond = paddle.logical_not((paddle.all(initial_finished)))
    sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64")
    outputs = None

    step_idx = 0
    step_idx_tensor = tensor.fill_constant(
        shape=[1], dtype="int64", value=step_idx
    )
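    # `step_idx_tensor` is fed to `decoder.step()`, while the plain Python
    # `step_idx` is used for Python-side bookkeeping (first-step detection and
    # the `max_step_num` check).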
    while cond.numpy():
        (step_outputs, next_states, next_inputs, next_finished) = decoder.step(
            step_idx_tensor, inputs, states, **kwargs
        )
        if not decoder.tracks_own_finished:
            # BeamSearchDecoder tracks its own finished status, since
            # beams are reordered and the finished status of each entry
            # might change. Otherwise, perform a logical OR, which does
            # not change entries that are already finished.
            next_finished = paddle.logical_or(next_finished, finished)
            # Keep states.finished/finished consistent with
            # next_finished.
            tensor.assign(next_finished, finished)
            next_sequence_lengths = paddle.add(
                sequence_lengths,
                tensor.cast(
                    paddle.logical_not(finished), sequence_lengths.dtype
                ),
            )
            if impute_finished:  # rectify the states of the finished entries.
                next_states = map_structure(
                    lambda x, y: _maybe_copy(x, y, finished),
                    states,
                    next_states,
                )
        else:
            if not hasattr(next_states, "lengths"):
                warnings.warn(
                    "`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros."
                )
            next_sequence_lengths = getattr(
                next_states, "lengths", sequence_lengths
            )

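        # accumulate this step's outputs: wrap them on the first step and
        # append to the wrappers on subsequent steps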
        outputs = (
            map_structure(lambda x: ArrayWrapper(x), step_outputs)
            if step_idx == 0
            else map_structure(
                lambda x, x_array: x_array.append(x), step_outputs, outputs
            )
        )
        inputs, states, finished, sequence_lengths = (
            next_inputs,
            next_states,
            next_finished,
            next_sequence_lengths,
        )

        paddle.increment(x=step_idx_tensor, value=1.0)
        step_idx += 1

        cond = paddle.logical_not(paddle.all(finished))
        if max_step_num is not None and step_idx > max_step_num:
            break

    final_outputs = map_structure(
        lambda x: paddle.stack(x.array, axis=0), outputs
    )
    final_states = states

    try:
        final_outputs, final_states = decoder.finalize(
            final_outputs, final_states, sequence_lengths
        )
    except NotImplementedError:
        pass

    if not output_time_major:
        final_outputs = map_structure(
            lambda x: paddle.transpose(
                x, [1, 0] + list(range(2, len(x.shape)))
            ),
            final_outputs,
        )

    return (
        (final_outputs, final_states, sequence_lengths)
        if return_length
        else (final_outputs, final_states)
    )


def _dynamic_decode_declarative(
    decoder,
    inits=None,
    max_step_num=None,
    output_time_major=False,
    impute_finished=False,
    is_test=False,
    return_length=False,
    **kwargs
):
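    # Static-graph (declarative) implementation: builds a While op whose block
    # performs one decoding step per iteration. Per-step inputs, states and
    # outputs are kept in tensor arrays so that they are available across
    # iterations and for backward; in test mode, inputs and states are instead
    # reused in place to save memory.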
    initial_inputs, initial_states, initial_finished = decoder.initialize(inits)
    global_inputs, global_states, global_finished = (
        initial_inputs,
        initial_states,
        initial_finished,
    )
    global_finished.stop_gradient = True
    step_idx = tensor.fill_constant(shape=[1], dtype="int64", value=0)

    cond = paddle.logical_not((paddle.all(initial_finished)))
    if max_step_num is not None:
        max_step_num = tensor.fill_constant(
            shape=[1], dtype="int64", value=max_step_num
        )
    while_op = paddle.static.nn.control_flow.While(cond, is_test=is_test)

    sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64")
    sequence_lengths.stop_gradient = True

    if is_test:
        # for test, reuse inputs and states variables to save memory
        inputs = map_structure(lambda x: x, initial_inputs)
        states = map_structure(lambda x: x, initial_states)
    else:
        # inputs and states of all steps must be saved for backward and training
        inputs_arrays = map_structure(
            lambda x: paddle.tensor.array_write(x, step_idx), initial_inputs
        )
        states_arrays = map_structure(
            lambda x: paddle.tensor.array_write(x, step_idx), initial_states
        )

    def _maybe_copy(state, new_state, step_mask):
        # TODO: use where_op
        state_dtype = state.dtype
        if convert_dtype(state_dtype) in ["bool"]:
            state = tensor.cast(state, dtype="float32")
            new_state = tensor.cast(new_state, dtype="float32")
        if step_mask.dtype != state.dtype:
            step_mask = tensor.cast(step_mask, dtype=state.dtype)
            # otherwise the renamed bool gradients would be summed up, leading
            # to a sum(bool) error.
            step_mask.stop_gradient = True
        new_state = paddle.tensor.math._multiply_with_axis(
            state, step_mask, axis=0
        ) - paddle.tensor.math._multiply_with_axis(
            new_state, (step_mask - 1), axis=0
        )
        if convert_dtype(state_dtype) in ["bool"]:
            new_state = tensor.cast(new_state, dtype=state_dtype)
        return new_state

    def _transpose_batch_time(x):
        return paddle.transpose(x, [1, 0] + list(range(2, len(x.shape))))

    def _create_array_out_of_while(dtype):
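        # Temporarily switch to the parent block so that the tensor array is
        # created outside the While sub-block and stays accessible after the
        # loop finishes.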
        current_block_idx = default_main_program().current_block_idx
        default_main_program().current_block_idx = (
            default_main_program().current_block().parent_idx
        )
        tensor_array = paddle.tensor.create_array(dtype)
        default_main_program().current_block_idx = current_block_idx
        return tensor_array

    # the While loop body: each iteration performs one decoding step
    with while_op.block():
        if not is_test:
            inputs = map_structure(
                lambda array: paddle.tensor.array_read(array, step_idx),
                inputs_arrays,
            )
            states = map_structure(
                lambda array: paddle.tensor.array_read(array, step_idx),
                states_arrays,
            )
        (outputs, next_states, next_inputs, next_finished) = decoder.step(
            step_idx, inputs, states, **kwargs
        )
        if not decoder.tracks_own_finished:
            # BeamSearchDecoder tracks its own finished status, since beams are
            # reordered and the finished status of each entry might change.
            # Otherwise, perform a logical OR, which does not change entries
            # that are already finished.
            next_finished = paddle.logical_or(next_finished, global_finished)
            next_sequence_lengths = paddle.add(
                sequence_lengths,
                tensor.cast(
                    paddle.logical_not(global_finished),
                    sequence_lengths.dtype,
                ),
            )
            if impute_finished:  # rectify the states of the finished entries.
                next_states = map_structure(
                    lambda x, y: _maybe_copy(x, y, global_finished),
                    states,
                    next_states,
                )
        else:
            if not hasattr(next_states, "lengths"):
                warnings.warn(
                    "`next_states` has no `lengths` attribute, the returned `sequence_lengths` would be all zeros."
                )
            next_sequence_lengths = getattr(
                next_states, "lengths", sequence_lengths
            )

        # create tensor arrays in the global block once the dtypes of the outputs are known
        outputs_arrays = map_structure(
            lambda x: _create_array_out_of_while(x.dtype), outputs
        )

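        # write this step's outputs into the corresponding tensor arrays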
        map_structure(
            lambda x, x_array: paddle.tensor.array_write(
                x, i=step_idx, array=x_array
            ),
            outputs,
            outputs_arrays,
        )

        paddle.increment(x=step_idx, value=1.0)
        # update global_finished first, since it might also be in the decoder's
        # states, which would otherwise write a stale finished status to the array
        tensor.assign(next_finished, global_finished)
        tensor.assign(next_sequence_lengths, sequence_lengths)
        if is_test:
            map_structure(tensor.assign, next_inputs, global_inputs)
            map_structure(tensor.assign, next_states, global_states)
        else:
            map_structure(
                lambda x, x_array: paddle.tensor.array_write(
                    x, i=step_idx, array=x_array
                ),
                next_inputs,
                inputs_arrays,
            )
            map_structure(
                lambda x, x_array: paddle.tensor.array_write(
                    x, i=step_idx, array=x_array
                ),
                next_states,
                states_arrays,
            )
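        # update the loop condition in place (through the `out` argument) so
        # that the While op re-checks it before the next iteration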
        if max_step_num is not None:
            paddle.logical_and(
                paddle.logical_not(paddle.all(global_finished)),
                paddle.less_equal(step_idx, max_step_num),
                cond,
            )
        else:
            paddle.logical_not(paddle.all(global_finished), cond)

    final_outputs = map_structure(
        lambda array: tensor.tensor_array_to_tensor(
            array, axis=0, use_stack=True
        )[0],
        outputs_arrays,
    )
    if is_test:
        final_states = global_states
    else:
        final_states = map_structure(
            lambda array: paddle.tensor.array_read(array, step_idx),
            states_arrays,
        )

    try:
        final_outputs, final_states = decoder.finalize(
            final_outputs, final_states, sequence_lengths
        )
    except NotImplementedError:
        pass

    if not output_time_major:
        final_outputs = map_structure(_transpose_batch_time, final_outputs)

    return (
        (final_outputs, final_states, sequence_lengths)
        if return_length
        else (final_outputs, final_states)
    )


def dynamic_decode(
    decoder,
    inits=None,
    max_step_num=None,
    output_time_major=False,
    impute_finished=False,
    is_test=False,
    return_length=False,
    **kwargs
):
    r"""
    Dynamic decoding calls :code:`decoder.step()` repeatedly until the returned
    Tensor indicating the finished status contains all True values or the number
    of decoding steps reaches :attr:`max_step_num`.

    :code:`decoder.initialize()` is called once before the decoding loop.
    If the `decoder` has implemented the `finalize` method, :code:`decoder.finalize()`
    is called once after the decoding loop.

    Parameters:
        decoder(Decoder): An instance of `Decoder`.
        inits(object, optional): Argument passed to `decoder.initialize`.
            Default `None`.
        max_step_num(int, optional): The maximum number of steps. If not provided,
            decode until the decoder is fully done, that is, until the Tensor
            returned by :code:`decoder.step()` indicating the finished status
            contains all True values. Default `None`.
        output_time_major(bool, optional): Indicates the data layout of the Tensors
            included in the final outputs (the first returned value of this method).
            If `False`, the data layout is batch major with shape
            `[batch_size, seq_len, ...]`. If `True`, the data layout is time major
            with shape `[seq_len, batch_size, ...]`. Default: `False`.
        impute_finished(bool, optional): If `True` and `decoder.tracks_own_finished`
            is `False`, the states of batch entries that are marked as finished are
            copied through unchanged, while unfinished entries use the new states
            returned by :code:`decoder.step()`; this ensures that the final states
            have the correct values. Otherwise, finished states are not copied
            through. Set it to `True` if the returned `final_states` is needed,
            at the cost of some slowdown. Default `False`.
        is_test(bool, optional): A flag indicating whether to use test mode. In
            test mode, less memory is used. Default `False`.
        return_length(bool, optional): A flag indicating whether to return an
            extra Tensor in the output tuple, which stores the actual lengths of
            all decoded sequences. Default `False`.
        **kwargs: Additional keyword arguments passed to `decoder.step`.

    Returns:

        - final_outputs (Tensor, nested structure of Tensor), each Tensor in :code:`final_outputs`
            is stacked from the outputs of all decoding steps, and might be revised
            by :code:`decoder.finalize()` if the decoder has implemented finalize.
            :code:`final_outputs` has the same structure and data types as the
            :code:`outputs` returned by :code:`decoder.step()`.

        - final_states (Tensor, nested structure of Tensor), :code:`final_states` is the counterpart
            of the initial states at the last decoding step, thus it has the same structure
            as the states returned by :code:`decoder.initialize()`, with tensors of the
            same shapes and data types.

        - sequence_lengths (Tensor), stores the actual lengths of all decoded sequences.
            sequence_lengths is provided only if :code:`return_length` is True.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.nn import BeamSearchDecoder, dynamic_decode
            from paddle.nn import GRUCell, Linear, Embedding
            trg_embeder = Embedding(100, 32)
            output_layer = Linear(32, 32)
            decoder_cell = GRUCell(input_size=32, hidden_size=32)
            decoder = BeamSearchDecoder(decoder_cell,
                                        start_token=0,
                                        end_token=1,
                                        beam_size=4,
                                        embedding_fn=trg_embeder,
                                        output_fn=output_layer)
            encoder_output = paddle.ones((4, 8, 32), dtype=paddle.get_default_dtype())
            outputs = dynamic_decode(decoder=decoder,
                                    inits=decoder_cell.get_initial_states(encoder_output),
                                    max_step_num=10)
    """
    if _non_static_mode():
        return _dynamic_decode_imperative(
            decoder,
            inits,
            max_step_num,
            output_time_major,
            impute_finished,
            is_test,
            return_length,
            **kwargs
        )
    else:
        return _dynamic_decode_declarative(
            decoder,
            inits,
            max_step_num,
            output_time_major,
            impute_finished,
            is_test,
            return_length,
            **kwargs
        )