#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import math
import functools
from op_test import OpTest
from test_lstm_op import ACTIVATION


def gru(
        input,  # T x 3D
        lod,  # 1 x N
        h0,  # N x D
        weight,  # D x 3D
        bias,  # 1 x 3D
        is_reverse,
        act_state,
        act_gate,
        dtype='float32'):
    """NumPy reference GRU over N variable-length sequences packed row-wise.

    The sequences are processed one time step at a time. At each step the
    rows of all sequences that are still active are gathered into a batch,
    ordered from the longest sequence to the shortest.

    Args:
        input: packed per-step gate inputs, T x 3D, where T is the total
            number of steps over all sequences.
        lod: ``lod[0]`` holds the length of each of the N sequences.
        h0: initial hidden state per sequence, N x D.
        weight: D x 3D array. Its flattened buffer is read as a D x 2D
            update/reset matrix followed by a D x D candidate matrix.
        bias: 1 x 3D bias added to every input row.
        is_reverse: when true, each sequence is walked from its last row
            to its first.
        act_state: callable applied to the candidate pre-activation.
        act_gate: callable applied to the update/reset pre-activation.
        dtype: dtype of the returned arrays.

    Returns:
        Tuple ``(batch_gate, batch_reset_hidden_prev, batch_hidden, hidden)``.
        The three ``batch_*`` arrays (T x 3D, T x D, T x D) are stored in
        batch order: step by step, longest sequence first within a step.
        ``hidden`` (T x D) is stored in the row order of ``input``.

    Raises:
        ValueError: if ``lod[0]`` is empty or contains a non-positive length.
    """
    seq_lens = lod[0]
    # Every sequence has to contribute a row to the first step, otherwise
    # h0 could not be paired with the first batch.
    if len(seq_lens) == 0 or min(seq_lens) <= 0:
        raise ValueError(
            "lod[0] must contain at least one sequence and only positive "
            "sequence lengths")

    def _seq_to_batch(seq_lens, is_reverse):
        """Return per-step row indices and the longest-first sequence order."""
        seq_starts = [0]
        for seq_len in seq_lens:
            seq_starts.append(seq_starts[-1] + seq_len)
        # Longest sequence first. sorted() stays stable with reverse=True,
        # so sequences of equal length keep their original relative order.
        sorted_seqs = sorted(
            range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
        num_batch = seq_lens[sorted_seqs[0]]
        idx_in_seq_list = []
        for batch_idx in range(num_batch):
            idx_in_seq = []
            for seq in sorted_seqs:
                # Sequences are sorted by length, so the first exhausted one
                # means all remaining ones are exhausted too.
                if seq_lens[seq] <= batch_idx:
                    break
                if is_reverse:
                    idx_in_seq.append(seq_starts[seq + 1] - 1 - batch_idx)
                else:
                    idx_in_seq.append(seq_starts[seq] + batch_idx)
            idx_in_seq_list.append(idx_in_seq)
        return idx_in_seq_list, sorted_seqs

    D = weight.shape[0]
    # The weight buffer is two consecutive matrices, not a column split of
    # the D x 3D array: first D x 2D (update/reset), then D x D (candidate).
    flat_weight = weight.flatten()
    w_u_r = flat_weight[:D * D * 2].reshape((D, D * 2))
    w_c = flat_weight[D * D * 2:].reshape((D, D))

    def _step(x, h_p):
        """Run one GRU step for the active rows `x` with previous state `h_p`."""
        g = x + bias  # the 1 x 3D bias broadcasts over the rows of x
        u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2])
        u = u_r[:, :D]
        r = u_r[:, D:D * 2]
        r_h_p = r * h_p
        c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
        h = u * c + (1 - u) * h_p
        return np.hstack((u_r, c)), r_h_p, h

    T = sum(seq_lens)
    batch_gate = np.zeros((T, 3 * D), dtype=dtype)
    batch_reset_hidden_prev = np.zeros((T, D), dtype=dtype)
    batch_hidden = np.zeros((T, D), dtype=dtype)
    hidden = np.zeros((T, D), dtype=dtype)

    idx_in_seq_list, sorted_seqs = _seq_to_batch(seq_lens, is_reverse)
    h_p = h0[sorted_seqs]
    max_seq_len = len(idx_in_seq_list)
    end_idx = 0
    for batch_idx, rows in enumerate(idx_in_seq_list):
        g, r_h_p, h = _step(input[rows], h_p)
        if batch_idx < (max_seq_len - 1):
            # Only the leading (longer) sequences survive into the next step.
            h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
        start_idx = end_idx
        end_idx = start_idx + len(rows)
        batch_gate[start_idx:end_idx] = g
        batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
        batch_hidden[start_idx:end_idx] = h
        hidden[rows] = h
    return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden


class TestGRUOp(OpTest):
    """Test case for the "gru" operator, using `gru` as the NumPy reference.

    `setUp` assigns a default configuration, lets `set_confs` override it,
    then fills `self.inputs`, `self.outputs` and `self.attrs`. Subclasses
    vary the configuration by overriding `set_confs`.
    """

    def set_confs(self):
        """Hook for subclasses to override the defaults assigned in `setUp`."""
        pass

    def setUp(self):
        """Draw random inputs and compute the expected outputs with `gru`."""
        self.op_type = "gru"
        self.lod = [[2, 4, 3]]
        self.D = 5
        self.is_reverse = False
        self.with_h0 = True
        self.with_bias = True
        self.act_state = 'tanh'
        self.act_gate = 'sigmoid'
        self.dtype = 'float64'
        # Subclass overrides must land before any data is generated.
        self.set_confs()

        T = sum(self.lod[0])
        N = len(self.lod[0])

        input_data = np.random.rand(T, 3 * self.D).astype(self.dtype)
        weight = np.random.rand(self.D, 3 * self.D).astype(self.dtype)
        # When bias / h0 are disabled the reference still needs arrays of the
        # right shape, so zeros stand in for them.
        if self.with_bias:
            bias = np.random.rand(1, 3 * self.D).astype(self.dtype)
        else:
            bias = np.zeros((1, 3 * self.D), dtype=self.dtype)
        if self.with_h0:
            h0 = np.random.rand(N, self.D).astype(self.dtype)
        else:
            h0 = np.zeros((N, self.D), dtype=self.dtype)

        batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
            input_data, self.lod, h0, weight, bias, self.is_reverse,
            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
        self.inputs = {'Input': (input_data, self.lod), 'Weight': weight}

        # Optional inputs are only registered when enabled.
        if self.with_bias:
            self.inputs['Bias'] = bias

        if self.with_h0:
            self.inputs['H0'] = h0

        self.outputs = {
            'Hidden': (hidden, self.lod),
            'BatchGate': batch_gate,
            'BatchResetHiddenPrev': batch_reset_hidden_prev,
            'BatchHidden': batch_hidden,
        }

        self.attrs = {
            'activation': self.act_state,
            'gate_activation': self.act_gate,
            'is_reverse': self.is_reverse
        }

    def test_check_output(self):
        """Call the inherited `check_output` with an absolute tolerance of 1e-8."""
        self.check_output(atol=1e-8)

    def test_check_grad(self):
        """Call the inherited `check_grad` for 'Hidden' w.r.t. the registered inputs."""
        # 'H0' and 'Bias' are only present in self.inputs when enabled, so
        # the list is derived from the flags instead of being hard-coded.
        inputs_to_check = ['Input']
        if self.with_h0:
            inputs_to_check.append('H0')
        inputs_to_check.append('Weight')
        if self.with_bias:
            inputs_to_check.append('Bias')
        self.check_grad(inputs_to_check, ['Hidden'])


class TestGRUOp2(TestGRUOp):
    """Variant of TestGRUOp with D = 19 and float32 data."""

    def set_confs(self):
        """Override the dtype and the dimension D."""
        self.dtype = 'float32'
        self.D = 19


class TestGRUOpNoInitial(TestGRUOp):
    """Variant of TestGRUOp that sets `with_h0` to False."""

    def set_confs(self):
        """Disable the initial hidden state input."""
        self.with_h0 = False

    def test_check_grad(self):
        """Call `check_grad` for 'Hidden' without 'H0' in the input list."""
        self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])


class TestGRUOpNoBias(TestGRUOp):
    """Variant of TestGRUOp that sets `with_bias` to False."""

    def set_confs(self):
        """Disable the bias input."""
        self.with_bias = False

    def test_check_grad(self):
        """Call `check_grad` for 'Hidden' without 'Bias' in the input list."""
        inputs_to_check = ['Input', 'H0', 'Weight']
        output_names = ['Hidden']
        self.check_grad(inputs_to_check, output_names)


class TestGRUOpReverse(TestGRUOp):
    """Variant of TestGRUOp that sets `is_reverse` to True."""

    def set_confs(self):
        """Enable the reversed configuration."""
        self.is_reverse = True


# Run the test cases through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()