rnn_numpy.py 18.8 KB
Newer Older
F
Feiyu Chan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import math


class LayerMixin(object):
    """Make a layer callable: ``layer(...)`` delegates to ``layer.forward(...)``."""

    def __call__(self, *inputs, **options):
        forward_fn = self.forward
        return forward_fn(*inputs, **options)


class LayerListMixin(LayerMixin):
    """Callable container that keeps an ordered list of sub-layers."""

    def __init__(self, layers=None):
        # A falsy ``layers`` (None or empty) starts an empty list; otherwise
        # the given iterable is copied so later appends do not mutate it.
        if not layers:
            self._layers = []
        else:
            self._layers = list(layers)

    def append(self, layer):
        """Add ``layer`` to the end of the container."""
        self._layers.append(layer)

    def __iter__(self):
        """Iterate over the sub-layers in insertion order."""
        return iter(self._layers)


class SimpleRNNCell(LayerMixin):
    """NumPy reference for one step of a simple (Elman) RNN.

    One call computes::

        h = act(x @ weight_ih.T + bias_ih + h_prev @ weight_hh.T + bias_hh)

    where ``act`` is ``np.tanh`` when ``nonlinearity == "RNN_TANH"`` and
    ReLU for any other value. Weights and biases are drawn from
    U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)) using the global NumPy RNG
    and cast to ``dtype``.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        bias=True,
        nonlinearity="RNN_TANH",
        dtype="float64",
    ):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        if nonlinearity == 'RNN_TANH':
            self.nonlinearity = np.tanh
        else:
            # Every value other than "RNN_TANH" selects ReLU.
            self.nonlinearity = lambda x: np.maximum(x, 0.0)

        self.parameters = dict()
        std = 1.0 / math.sqrt(hidden_size)
        # The draw order (weight_ih, weight_hh, bias_ih, bias_hh) decides
        # which values a seeded global RNG yields; keep it stable.
        self.weight_ih = np.random.uniform(
            -std, std, (hidden_size, input_size)
        ).astype(dtype)
        self.weight_hh = np.random.uniform(
            -std, std, (hidden_size, hidden_size)
        ).astype(dtype)
        self.parameters['weight_ih'] = self.weight_ih
        self.parameters['weight_hh'] = self.weight_hh
        if bias:
            self.bias_ih = np.random.uniform(-std, std, (hidden_size,)).astype(
                dtype
            )
            self.bias_hh = np.random.uniform(-std, std, (hidden_size,)).astype(
                dtype
            )
            self.parameters['bias_ih'] = self.bias_ih
            self.parameters['bias_hh'] = self.bias_hh
        else:
            # Without bias, forward() skips the additions and ``parameters``
            # holds only the two weight matrices.
            self.bias_ih = None
            self.bias_hh = None

    def init_state(self, inputs, batch_dim_index=0):
        """Return a zero state of shape ``(batch, hidden_size)``.

        ``batch`` is ``inputs.shape[batch_dim_index]``; dtype follows ``inputs``.
        """
        batch_size = inputs.shape[batch_dim_index]
        return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)

    def forward(self, inputs, hx=None):
        """Run one RNN step.

        Args:
            inputs: array whose last axis has ``input_size`` features;
                assumed ``(batch, input_size)`` (the default state is built
                with batch on axis 0).
            hx: previous hidden state, or None for a zero state.

        Returns:
            ``(h, h)``: the step output and the new state are the same array.
        """
        if hx is None:
            hx = self.init_state(inputs)
        pre_h = hx
        i2h = np.matmul(inputs, self.weight_ih.T)
        if self.bias_ih is not None:
            i2h += self.bias_ih
        h2h = np.matmul(pre_h, self.weight_hh.T)
        if self.bias_hh is not None:
            h2h += self.bias_hh
        h = self.nonlinearity(i2h + h2h)
        return h, h


class GRUCell(LayerMixin):
    """NumPy reference for one GRU step.

    ``weight_ih``/``weight_hh`` stack three gate projections along axis 0,
    split in order as reset (r), update (z) and candidate (c). One call
    computes::

        r = sigmoid(x_r + h_r)
        z = sigmoid(x_z + h_z)
        c = tanh(x_c + r * h_c)
        h = z * h_prev + (1 - z) * c

    where ``x_*``/``h_*`` are the (biased) input/hidden projections.
    Parameters are drawn from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))
    using the global NumPy RNG.
    """

    def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.parameters = dict()
        std = 1.0 / math.sqrt(hidden_size)
        # Draw order (weight_ih, weight_hh, bias_ih, bias_hh) is significant
        # for reproducibility under a seeded global RNG.
        self.weight_ih = np.random.uniform(
            -std, std, (3 * hidden_size, input_size)
        ).astype(dtype)
        self.weight_hh = np.random.uniform(
            -std, std, (3 * hidden_size, hidden_size)
        ).astype(dtype)
        self.parameters['weight_ih'] = self.weight_ih
        self.parameters['weight_hh'] = self.weight_hh
        if bias:
            self.bias_ih = np.random.uniform(
                -std, std, (3 * hidden_size)
            ).astype(dtype)
            self.bias_hh = np.random.uniform(
                -std, std, (3 * hidden_size)
            ).astype(dtype)
            self.parameters['bias_ih'] = self.bias_ih
            self.parameters['bias_hh'] = self.bias_hh
        else:
            self.bias_ih = None
            self.bias_hh = None

    def init_state(self, inputs, batch_dim_index=0):
        """Return a zero state of shape ``(batch, hidden_size)``.

        ``batch`` is ``inputs.shape[batch_dim_index]``; dtype follows ``inputs``.
        """
        batch_size = inputs.shape[batch_dim_index]
        return np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)

    def forward(self, inputs, hx=None):
        """Run one GRU step.

        Args:
            inputs: assumed ``(batch, input_size)``.
            hx: previous hidden state, or None for a zero state.

        Returns:
            ``(h, h)``: the step output and the new state are the same array.
        """
        if hx is None:
            hx = self.init_state(inputs)
        pre_hidden = hx
        x_gates = np.matmul(inputs, self.weight_ih.T)
        if self.bias_ih is not None:
            x_gates = x_gates + self.bias_ih
        h_gates = np.matmul(pre_hidden, self.weight_hh.T)
        if self.bias_hh is not None:
            h_gates = h_gates + self.bias_hh
        x_r, x_z, x_c = np.split(x_gates, 3, 1)
        h_r, h_z, h_c = np.split(h_gates, 3, 1)

        r = 1.0 / (1.0 + np.exp(-(x_r + h_r)))
        z = 1.0 / (1.0 + np.exp(-(x_z + h_z)))
        # The reset gate scales the already-projected hidden term h_c
        # (i.e. it is applied after the matmul), not h_prev itself.
        c = np.tanh(x_c + r * h_c)
        # Equivalent to z * h_prev + (1 - z) * c.
        h = (pre_hidden - c) * z + c
        return h, h


class LSTMCell(LayerMixin):
    """NumPy reference for one LSTM step.

    ``weight_ih``/``weight_hh`` stack four gate projections along axis 0.
    After projection the gates are split along the last axis and used as
    input gate (chunk 0), forget gate (1), cell candidate (2) and output
    gate (3)::

        c = sigmoid(f) * c_prev + sigmoid(i) * tanh(g)
        h = sigmoid(o) * tanh(c)

    Parameters are drawn from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))
    using the global NumPy RNG.
    """

    def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.parameters = dict()
        std = 1.0 / math.sqrt(hidden_size)
        # Draw order (weight_ih, weight_hh, bias_ih, bias_hh) is significant
        # for reproducibility under a seeded global RNG.
        self.weight_ih = np.random.uniform(
            -std, std, (4 * hidden_size, input_size)
        ).astype(dtype)
        self.weight_hh = np.random.uniform(
            -std, std, (4 * hidden_size, hidden_size)
        ).astype(dtype)
        self.parameters['weight_ih'] = self.weight_ih
        self.parameters['weight_hh'] = self.weight_hh
        if bias:
            self.bias_ih = np.random.uniform(
                -std, std, (4 * hidden_size)
            ).astype(dtype)
            self.bias_hh = np.random.uniform(
                -std, std, (4 * hidden_size)
            ).astype(dtype)
            self.parameters['bias_ih'] = self.bias_ih
            self.parameters['bias_hh'] = self.bias_hh
        else:
            self.bias_ih = None
            self.bias_hh = None

    def init_state(self, inputs, batch_dim_index=0):
        """Return zero ``(h, c)`` states, each ``(batch, hidden_size)``.

        ``batch`` is ``inputs.shape[batch_dim_index]``; dtype follows ``inputs``.
        """
        batch_size = inputs.shape[batch_dim_index]
        init_h = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
        init_c = np.zeros((batch_size, self.hidden_size), dtype=inputs.dtype)
        return init_h, init_c

    def forward(self, inputs, hx=None):
        """Run one LSTM step.

        Args:
            inputs: assumed ``(batch, input_size)``.
            hx: ``(h_prev, c_prev)`` tuple, or None for zero states.

        Returns:
            ``(h, (h, c))``: the step output and the new ``(h, c)`` state.
        """
        if hx is None:
            hx = self.init_state(inputs)
        pre_hidden, pre_cell = hx
        gates = np.matmul(inputs, self.weight_ih.T)
        if self.bias_ih is not None:
            gates = gates + self.bias_ih
        gates += np.matmul(pre_hidden, self.weight_hh.T)
        if self.bias_hh is not None:
            gates = gates + self.bias_hh

        chunked_gates = np.split(gates, 4, -1)

        i = 1.0 / (1.0 + np.exp(-chunked_gates[0]))
        f = 1.0 / (1.0 + np.exp(-chunked_gates[1]))
        o = 1.0 / (1.0 + np.exp(-chunked_gates[3]))
        c = f * pre_cell + i * np.tanh(chunked_gates[2])
        h = o * np.tanh(c)

        return h, (h, c)


def sequence_mask(lengths, max_len=None):
    """Return a boolean mask of shape ``lengths.shape + (max_len,)``.

    Entry ``[..., t]`` is True when ``t < lengths[...]``. ``max_len``
    defaults to ``max(lengths)``; when supplied, it must be at least that
    large.
    """
    longest = np.max(lengths)
    if max_len is not None:
        assert max_len >= longest
    else:
        max_len = longest
    positions = np.arange(max_len)
    return positions < np.expand_dims(lengths, -1)


def update_state(mask, new, old):
    """Select ``new`` where ``mask`` is True and keep ``old`` elsewhere.

    When ``old`` is a tuple/list (e.g. LSTM ``(h, c)``), the selection is
    applied component-wise and a tuple is returned.
    """
    if isinstance(old, (tuple, list)):
        return tuple(
            np.where(mask, fresh, stale) for fresh, stale in zip(new, old)
        )
    return np.where(mask, new, old)


def rnn(
    cell,
    inputs,
    initial_states,
    sequence_length=None,
    time_major=False,
    is_reverse=False,
):
    """Unroll ``cell`` over the time axis of a rank-3 ``inputs`` array.

    Args:
        cell: callable ``cell(x_t, state) -> (y_t, new_state)`` that also
            provides ``init_state(inputs, batch_dim_index)``.
        inputs: ``(batch, time, feat)``, or ``(time, batch, feat)`` when
            ``time_major`` is True.
        initial_states: initial cell state, or None to use
            ``cell.init_state``.
        sequence_length: optional valid length per batch entry. Steps at or
            past an entry's length emit zeros and leave its state unchanged.
        time_major: whether time is axis 0 of ``inputs`` and ``outputs``.
        is_reverse: process steps last-to-first; outputs are flipped back to
            the original time order before returning.

    Returns:
        ``(outputs, final_state)``; ``outputs`` uses the same layout as
        ``inputs``.
    """
    if not time_major:
        inputs = np.transpose(inputs, [1, 0, 2])
    if is_reverse:
        inputs = np.flip(inputs, 0)

    if initial_states is None:
        # After the transpose above, batch is always axis 1.
        initial_states = cell.init_state(inputs, 1)

    time_steps = inputs.shape[0]
    if sequence_length is None:
        mask = None
    else:
        # Size the mask to the full padded time axis. Defaulting to
        # max(sequence_length) makes the mask too short whenever all
        # sequences are padded, so mask[t] raised IndexError (and the
        # reverse flip misaligned it).
        mask = np.transpose(
            sequence_mask(sequence_length, max_len=time_steps), [1, 0]
        )
        # (time, batch, 1) so each step's mask broadcasts over features.
        mask = np.expand_dims(mask, -1)
        if is_reverse:
            mask = np.flip(mask, 0)

    state = initial_states
    outputs = []
    for t in range(time_steps):
        y, new_state = cell(inputs[t], state)
        if mask is None:
            state = new_state
        else:
            m_t = mask[t]
            y = np.where(m_t, y, 0.0)
            state = update_state(m_t, new_state, state)
        outputs.append(y)

    outputs = np.stack(outputs)
    if is_reverse:
        outputs = np.flip(outputs, 0)
    if not time_major:
        outputs = np.transpose(outputs, [1, 0, 2])
    return outputs, state


def birnn(
    cell_fw,
    cell_bw,
    inputs,
    initial_states,
    sequence_length=None,
    time_major=False,
):
    """Run ``cell_fw`` forward and ``cell_bw`` in reverse over ``inputs``.

    Args:
        cell_fw, cell_bw: cells accepted by ``rnn``.
        inputs: rank-3 input in the layout selected by ``time_major``.
        initial_states: pair ``(states_fw, states_bw)``; either entry may be
            None to use that cell's ``init_state``. Passing None for the
            whole argument is treated as ``(None, None)``.
        sequence_length: optional per-entry valid lengths, forwarded to ``rnn``.
        time_major: forwarded to ``rnn``.

    Returns:
        ``(outputs, (states_fw, states_bw))`` where ``outputs`` concatenates
        the forward and backward outputs along the last axis.
    """
    if initial_states is None:
        initial_states = (None, None)
    states_fw, states_bw = initial_states
    outputs_fw, states_fw = rnn(
        cell_fw, inputs, states_fw, sequence_length, time_major=time_major
    )

    outputs_bw, states_bw = rnn(
        cell_bw,
        inputs,
        states_bw,
        sequence_length,
        time_major=time_major,
        is_reverse=True,
    )

    outputs = np.concatenate((outputs_fw, outputs_bw), -1)
    final_states = (states_fw, states_bw)
    return outputs, final_states


def flatten(nested):
    """Return the leaves of an arbitrarily nested list/tuple as a flat list.

    Leaves are yielded depth-first, left to right; anything that is not a
    list or tuple (e.g. an ndarray) counts as a leaf.
    """
    leaves = []
    leaves.extend(_flatten(nested))
    return leaves


def _flatten(nested):
    for item in nested:
        if isinstance(item, (list, tuple)):
            for subitem in _flatten(item):
                yield subitem
        else:
            yield item


def unstack(array, axis=0):
    """Split ``array`` along ``axis`` into a list of slices with that axis removed."""
    return [
        np.squeeze(piece, axis)
        for piece in np.split(array, array.shape[axis], axis)
    ]


def dropout(array, p=0.5):
    """Apply inverted dropout with drop probability ``p``.

    Each element is kept with probability ``1 - p`` and scaled by
    ``1 / (1 - p)``; dropped elements become 0. ``p == 0`` returns ``array``
    unchanged (same object). ``p == 1`` drops everything and returns zeros.
    The mask uses the global NumPy RNG.
    """
    if p == 0.0:
        return array
    if p == 1.0:
        # The keep-and-rescale path would divide by zero (0/0 -> NaN) here.
        return np.zeros_like(array)
    keep_prob = 1 - p
    mask = (np.random.uniform(size=array.shape) < keep_prob).astype(array.dtype)
    return array * (mask / keep_prob)


def split_states(states, bidirectional=False, state_components=1):
    """Split stacked initial states into per-layer (and per-direction) states.

    With one component, ``states`` is an array stacked on axis 0. With
    several, it is a sequence of ``state_components`` such arrays, and each
    per-layer entry becomes a tuple of components. When ``bidirectional``,
    consecutive entries are paired as ``(forward, backward)``.
    """
    if state_components != 1:
        assert len(states) == state_components
        per_entry = list(zip(*[unstack(component) for component in states]))
    else:
        per_entry = unstack(states)

    if bidirectional:
        return list(zip(per_entry[::2], per_entry[1::2]))
    return per_entry


def concat_states(states, bidirectional=False, state_components=1):
    """Inverse of ``split_states``: stack per-layer states back together.

    The nested states are flattened. With one component the leaves are
    stacked into a single array. Otherwise leaves are interleaved by
    component (``[c0, c1, c0, c1, ...]``), and a list with one stacked array
    per component is returned. ``bidirectional`` is accepted for API
    symmetry and not used.
    """
    leaves = flatten(states)
    if state_components == 1:
        return np.stack(leaves)
    return [
        np.stack(leaves[k::state_components]) for k in range(state_components)
    ]


class RNN(LayerMixin):
    """Layer that unrolls a single cell over a sequence via ``rnn``."""

    def __init__(self, cell, is_reverse=False, time_major=False):
        super(RNN, self).__init__()
        self.cell = cell
        if not hasattr(self.cell, "call"):
            # Original note: for non-dygraph mode, Paddle's `rnn` api uses
            # cell.call. The NumPy `rnn` in this module calls the cell
            # directly instead.
            self.cell.call = self.cell.forward
        self.is_reverse = is_reverse
        self.time_major = time_major

    def forward(self, inputs, initial_states=None, sequence_length=None):
        """Return ``(outputs, final_states)`` from ``rnn`` with this layer's settings."""
        final_outputs, final_states = rnn(
            self.cell,
            inputs,
            initial_states=initial_states,
            sequence_length=sequence_length,
            time_major=self.time_major,
            is_reverse=self.is_reverse,
        )
        return final_outputs, final_states


class BiRNN(LayerMixin):
    """Layer that runs a forward and a backward cell over a sequence via ``birnn``."""

    def __init__(self, cell_fw, cell_bw, time_major=False):
        super(BiRNN, self).__init__()
        self.cell_fw = cell_fw
        self.cell_bw = cell_bw
        self.time_major = time_major

    def forward(
        self, inputs, initial_states=None, sequence_length=None, **kwargs
    ):
        """Return ``(outputs, (states_fw, states_bw))``.

        ``initial_states`` is either a 2-item list/tuple ``(fw, bw)`` or a
        single value (including None) shared by both directions. Extra
        keyword arguments are accepted and ignored.
        """
        if isinstance(initial_states, (list, tuple)):
            assert (
                len(initial_states) == 2
            ), "length of initial_states should be 2 when it is a list/tuple"
        else:
            initial_states = [initial_states, initial_states]

        outputs, final_states = birnn(
            self.cell_fw,
            self.cell_bw,
            inputs,
            initial_states,
            sequence_length,
            self.time_major,
        )
        return outputs, final_states


class RNNMixin(LayerListMixin):
    """Shared forward pass for stacked multi-layer RNNs.

    Subclasses must populate the layer list and set ``time_major``,
    ``num_layers``, ``num_directions``, ``hidden_size``,
    ``state_components`` and ``dropout``.
    """

    def forward(self, inputs, initial_states=None, sequence_length=None):
        """Run all layers in order and return ``(outputs, final_states)``.

        When ``initial_states`` is None, zero states of shape
        ``(num_layers * num_directions, batch, hidden_size)`` are used: one
        array, or a tuple of ``state_components`` arrays. ``final_states``
        uses the same stacked layout.
        """
        batch_index = 1 if self.time_major else 0
        batch_size = inputs.shape[batch_index]
        dtype = inputs.dtype
        if initial_states is None:
            state_shape = (
                self.num_layers * self.num_directions,
                batch_size,
                self.hidden_size,
            )
            if self.state_components == 1:
                initial_states = np.zeros(state_shape, dtype)
            else:
                initial_states = tuple(
                    [
                        np.zeros(state_shape, dtype)
                        for _ in range(self.state_components)
                    ]
                )

        states = split_states(
            initial_states, self.num_directions == 2, self.state_components
        )
        final_states = []
        layer_input = inputs
        for i, rnn_layer in enumerate(self):
            if i > 0:
                # Dropout only between stacked layers, never on the raw input.
                layer_input = dropout(layer_input, self.dropout)
            outputs, final_state = rnn_layer(
                layer_input, states[i], sequence_length
            )
            final_states.append(final_state)
            layer_input = outputs

        final_states = concat_states(
            final_states, self.num_directions == 2, self.state_components
        )
        return outputs, final_states


class SimpleRNN(RNNMixin):
    """Stacked NumPy simple RNN built from ``SimpleRNNCell`` layers.

    ``direction`` is "forward" (one ``RNN`` per layer) or
    "bidirectional"/"bidirect" (one ``BiRNN`` per layer). Layers after the
    first take ``hidden_size`` features, or ``2 * hidden_size`` when
    bidirectional, because ``BiRNN`` concatenates both directions' outputs.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers=1,
        nonlinearity="RNN_TANH",
        direction="forward",
        dropout=0.0,
        time_major=False,
        dtype="float64",
    ):
        super(SimpleRNN, self).__init__()
        bidirectional_list = ["bidirectional", "bidirect"]
        if direction in ["forward"]:
            is_reverse = False
            cell = SimpleRNNCell(
                input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype
            )
            self.append(RNN(cell, is_reverse, time_major))
            for _ in range(1, num_layers):
                cell = SimpleRNNCell(
                    hidden_size,
                    hidden_size,
                    nonlinearity=nonlinearity,
                    dtype=dtype,
                )
                self.append(RNN(cell, is_reverse, time_major))
        elif direction in bidirectional_list:
            cell_fw = SimpleRNNCell(
                input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype
            )
            cell_bw = SimpleRNNCell(
                input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype
            )
            self.append(BiRNN(cell_fw, cell_bw, time_major))
            for _ in range(1, num_layers):
                # Pass nonlinearity by keyword: SimpleRNNCell's third
                # positional parameter is ``bias``. Passing it positionally
                # silently left these layers on the default tanh.
                cell_fw = SimpleRNNCell(
                    2 * hidden_size,
                    hidden_size,
                    nonlinearity=nonlinearity,
                    dtype=dtype,
                )
                cell_bw = SimpleRNNCell(
                    2 * hidden_size,
                    hidden_size,
                    nonlinearity=nonlinearity,
                    dtype=dtype,
                )
                self.append(BiRNN(cell_fw, cell_bw, time_major))
        else:
            raise ValueError(
                "direction should be forward, backward or bidirectional, "
                "received direction = {}".format(direction)
            )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_directions = 2 if direction in bidirectional_list else 1
        self.time_major = time_major
        self.num_layers = num_layers
        # One state tensor (h) per layer/direction.
        self.state_components = 1


class LSTM(RNNMixin):
    """Stacked NumPy LSTM built from ``LSTMCell`` layers.

    ``direction`` is "forward" (one ``RNN`` per layer) or
    "bidirectional"/"bidirect" (one ``BiRNN`` per layer). Layers after the
    first take ``hidden_size`` features, or ``2 * hidden_size`` when
    bidirectional, because ``BiRNN`` concatenates both directions' outputs.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers=1,
        direction="forward",
        dropout=0.0,
        time_major=False,
        dtype="float64",
    ):
        super(LSTM, self).__init__()

        bidirectional_list = ["bidirectional", "bidirect"]
        if direction in ["forward"]:
            is_reverse = False
            cell = LSTMCell(input_size, hidden_size, dtype=dtype)
            self.append(RNN(cell, is_reverse, time_major))
            for _ in range(1, num_layers):
                cell = LSTMCell(hidden_size, hidden_size, dtype=dtype)
                self.append(RNN(cell, is_reverse, time_major))
        elif direction in bidirectional_list:
            cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype)
            cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype)
            self.append(BiRNN(cell_fw, cell_bw, time_major))
            for _ in range(1, num_layers):
                cell_fw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype)
                cell_bw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype)
                self.append(BiRNN(cell_fw, cell_bw, time_major))
        else:
            raise ValueError(
                "direction should be forward, backward or bidirectional, "
                "received direction = {}".format(direction)
            )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_directions = 2 if direction in bidirectional_list else 1
        self.time_major = time_major
        self.num_layers = num_layers
        # Two state tensors (h, c) per layer/direction.
        self.state_components = 2


class GRU(RNNMixin):
    """Stacked NumPy GRU built from ``GRUCell`` layers.

    ``direction`` is "forward" (one ``RNN`` per layer) or
    "bidirectional"/"bidirect" (one ``BiRNN`` per layer). Layers after the
    first take ``hidden_size`` features, or ``2 * hidden_size`` when
    bidirectional, because ``BiRNN`` concatenates both directions' outputs.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers=1,
        direction="forward",
        dropout=0.0,
        time_major=False,
        dtype="float64",
    ):
        super(GRU, self).__init__()

        bidirectional_list = ["bidirectional", "bidirect"]
        if direction in ["forward"]:
            is_reverse = False
            cell = GRUCell(input_size, hidden_size, dtype=dtype)
            self.append(RNN(cell, is_reverse, time_major))
            for _ in range(1, num_layers):
                cell = GRUCell(hidden_size, hidden_size, dtype=dtype)
                self.append(RNN(cell, is_reverse, time_major))
        elif direction in bidirectional_list:
            cell_fw = GRUCell(input_size, hidden_size, dtype=dtype)
            cell_bw = GRUCell(input_size, hidden_size, dtype=dtype)
            self.append(BiRNN(cell_fw, cell_bw, time_major))
            for _ in range(1, num_layers):
                cell_fw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype)
                cell_bw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype)
                self.append(BiRNN(cell_fw, cell_bw, time_major))
        else:
            raise ValueError(
                "direction should be forward, backward or bidirectional, "
                "received direction = {}".format(direction)
            )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_directions = 2 if direction in bidirectional_list else 1
        self.time_major = time_major
        self.num_layers = num_layers
        # One state tensor (h) per layer/direction.
        self.state_components = 1