#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

G
guosheng 已提交
17 18 19
import unittest
import numpy as np
import math
M
minqiyang 已提交
20
import functools
21
from op_test import OpTest
T
tensor-tang 已提交
22 23 24 25 26 27 28 29 30 31 32
from test_lstm_op import ACTIVATION


def gru(
        input,  # T x 3D
        lod,  # 1 x N
        h0,  # N x D
        weight,  # D x 3D
        bias,  # 1 x 3D
        is_reverse,
        act_state,
        act_gate,
        dtype='float32',
        origin_mode=False):
    """NumPy reference implementation of a GRU over variable-length sequences.

    The rows of ``input`` hold the time steps of all sequences, concatenated
    in sequence order; ``lod[0]`` lists the length of each sequence.  At every
    time step the still-active sequences are processed together as one batch,
    with the sequences ordered from longest to shortest.

    Args:
        input: array of shape (T, 3D) with the projected inputs of all steps.
        lod: one-level list ``[[len_0, ..., len_{N-1}]]`` of sequence lengths.
        h0: array of shape (N, D), initial hidden state of every sequence.
        weight: array of shape (D, 3D).  Its first ``2 * D * D`` flattened
            elements are read as the (D, 2D) update/reset weights, and the
            remaining ``D * D`` elements as the (D, D) candidate weights.
        bias: array of shape (1, 3D) added to every row of ``input``.
        is_reverse: if true, every sequence is walked from its last step to
            its first.
        act_state: callable applied to the candidate pre-activation.
        act_gate: callable applied to the update/reset pre-activations.
        dtype: dtype of the returned arrays.
        origin_mode: if true, ``h = (1 - u) * c + u * h_prev``; otherwise
            ``h = u * c + (1 - u) * h_prev``.

    Returns:
        A tuple ``(batch_gate, batch_reset_hidden_prev, batch_hidden,
        hidden)``.  The first three hold, in batch order (time step by time
        step, longest sequence first), the activated ``[u, r, c]`` values
        (T x 3D), ``r * h_prev`` (T x D) and the new hidden state (T x D).
        ``hidden`` (T x D) holds the hidden state at the row positions of
        ``input``.  For an empty ``lod[0]`` all four arrays have zero rows.

    Raises:
        AssertionError: if ``lod[0]`` contains a zero-length sequence.
    """

    def _seq_to_batch(lod, is_reverse):
        # Returns, for every time step, the input row indices of the
        # sequences still active at that step, plus the sequence order used.
        seq_lens = lod[0]
        seq_starts = [0]
        for seq_len in seq_lens:
            seq_starts.append(seq_starts[-1] + seq_len)
        # Longest sequence first.  sorted() is stable (also with
        # reverse=True), so sequences of equal length keep their order.
        sorted_seqs = sorted(
            range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
        num_batch = seq_lens[sorted_seqs[0]] if sorted_seqs else 0
        idx_in_seq_list = []
        for batch_idx in range(num_batch):
            idx_in_seq = []
            for seq in sorted_seqs:
                # Sequences are sorted by length, so all later ones have
                # ended as well.
                if seq_lens[seq] <= batch_idx:
                    break
                if is_reverse:
                    idx = seq_starts[seq + 1] - 1 - batch_idx
                else:
                    idx = seq_starts[seq] + batch_idx
                idx_in_seq.append(idx)
            idx_in_seq_list.append(idx_in_seq)
        return idx_in_seq_list, sorted_seqs

    def _step(x, h_p, w, b, act_state, act_gate):
        # One GRU time step for the batch of currently active sequences.
        D = w.shape[0]
        # The weight buffer is split by flattened position, not by column:
        # first the update/reset block, then the candidate block.
        flat_w = w.flatten()
        w_u_r = flat_w[:D * D * 2].reshape((D, D * 2))
        w_c = flat_w[D * D * 2:].reshape((D, D))
        g = x + b  # the bias row broadcasts over all rows of x
        u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2])
        u = u_r[:, :D]
        r = u_r[:, D:D * 2]
        r_h_p = r * h_p
        c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
        g = np.hstack((u_r, c))
        if origin_mode:
            h = (1 - u) * c + u * h_p
        else:
            h = u * c + (1 - u) * h_p
        return g, r_h_p, h

    T = sum(lod[0])
    N = len(lod[0])
    D = weight.shape[0]
    batch_gate = np.zeros((T, 3 * D), dtype=dtype)
    batch_reset_hidden_prev = np.zeros((T, D), dtype=dtype)
    batch_hidden = np.zeros((T, D), dtype=dtype)
    hidden = np.zeros((T, D), dtype=dtype)
    if N == 0:
        # No sequences at all: nothing to compute.
        return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden

    idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
    assert idx_in_seq_list and len(idx_in_seq_list[0]) == N, (
        "every sequence in lod must contain at least one step")
    h_p = h0[sorted_seqs]
    max_seq_len = len(idx_in_seq_list)
    end_idx = 0
    for batch_idx in range(max_seq_len):
        rows = idx_in_seq_list[batch_idx]
        g, r_h_p, h = _step(input[rows], h_p, weight, bias, act_state,
                            act_gate)
        if batch_idx < (max_seq_len - 1):
            # Only the sequences still active at the next step carry their
            # state forward; they are the leading rows of this batch.
            h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
        start_idx = end_idx
        end_idx = start_idx + len(rows)
        batch_gate[start_idx:end_idx] = g
        batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
        batch_hidden[start_idx:end_idx] = h
        hidden[rows] = h
    return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden
G
guosheng 已提交
101 102


T
tensor-tang 已提交
103
class TestGRUOp(OpTest):
    """Test case comparing the ``gru`` operator with the NumPy ``gru`` helper.

    ``setUp`` assigns a default configuration, lets ``set_confs`` adjust it,
    then generates random inputs and the expected outputs.  Subclasses define
    a variant by overriding ``set_confs``.
    """

    def set_confs(self):
        """Hook for subclasses; the base configuration is left untouched."""
        pass

    def setUp(self):
        """Populate ``inputs``, ``outputs`` and ``attrs`` for the op test."""
        self.op_type = "gru"

        # Default configuration; set_confs() may override any of these.
        self.lod = [[2, 4, 3]]
        self.D = 5
        self.is_reverse = False
        self.with_h0 = True
        self.with_bias = True
        self.act_state = 'tanh'
        self.act_gate = 'sigmoid'
        self.dtype = 'float64'
        self.origin_mode = False
        self.set_confs()

        seq_lens = self.lod[0]
        total_steps = sum(seq_lens)
        num_seqs = len(seq_lens)
        gate_width = 3 * self.D

        # Random data is drawn in the order input, weight, bias, h0; a
        # disabled bias / h0 is replaced by zeros without drawing.
        input_data = np.random.rand(total_steps, gate_width).astype(self.dtype)
        weight = np.random.rand(self.D, gate_width).astype(self.dtype)
        if self.with_bias:
            bias = np.random.rand(1, gate_width).astype(self.dtype)
        else:
            bias = np.zeros((1, gate_width), dtype=self.dtype)
        if self.with_h0:
            h0 = np.random.rand(num_seqs, self.D).astype(self.dtype)
        else:
            h0 = np.zeros((num_seqs, self.D), dtype=self.dtype)

        # Expected results from the NumPy reference.
        reference = gru(input_data, self.lod, h0, weight, bias,
                        self.is_reverse, ACTIVATION[self.act_state],
                        ACTIVATION[self.act_gate], self.dtype,
                        self.origin_mode)
        batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = reference

        # Optional inputs are only passed to the op when enabled.
        self.inputs = {'Input': (input_data, self.lod), 'Weight': weight}
        if self.with_bias:
            self.inputs['Bias'] = bias
        if self.with_h0:
            self.inputs['H0'] = h0

        self.outputs = {
            'Hidden': (hidden, self.lod),
            'BatchGate': batch_gate,
            'BatchResetHiddenPrev': batch_reset_hidden_prev,
            'BatchHidden': batch_hidden,
        }

        self.attrs = {
            'activation': self.act_state,
            'gate_activation': self.act_gate,
            'is_reverse': self.is_reverse,
            'origin_mode': self.origin_mode
        }

    def test_check_output(self):
        """Run the output check with an absolute tolerance of 1e-8."""
        self.check_output(atol=1e-8, check_dygraph=True)

    def test_check_grad(self):
        """Run the gradient check of 'Hidden' w.r.t. all four inputs."""
        grad_inputs = ['Input', 'H0', 'Weight', 'Bias']
        self.check_grad(grad_inputs, ['Hidden'])


Q
Qiao Longfei 已提交
165 166 167 168 169
class TestGRUOriginMode(TestGRUOp):
    """Default GRU test configuration with ``origin_mode`` enabled."""

    def set_confs(self):
        self.origin_mode = True


170 171 172 173 174 175
class TestGRUOp2(TestGRUOp):
    """GRU test with ``D`` raised to 19 and float32 data."""

    def set_confs(self):
        self.D = 19
        self.dtype = 'float32'


Q
Qiao Longfei 已提交
176 177 178 179 180 181 182
class TestGRUOp2OriginMode(TestGRUOp):
    """GRU test with ``D`` = 19, float32 data and ``origin_mode`` enabled."""

    def set_confs(self):
        self.D = 19
        self.dtype = 'float32'
        self.origin_mode = True


G
guosheng 已提交
183
class TestGRUOpNoInitial(TestGRUOp):
    """GRU test that runs without an initial hidden state input."""

    def set_confs(self):
        self.with_h0 = False

    def test_check_grad(self):
        # With with_h0 disabled, setUp does not add 'H0' to the op inputs,
        # so it is left out of the gradient check.
        grad_inputs = ['Input', 'Weight', 'Bias']
        self.check_grad(grad_inputs, ['Hidden'])


T
tensor-tang 已提交
191 192 193 194 195 196 197 198
class TestGRUOpNoBias(TestGRUOp):
    """GRU test that runs without a bias input."""

    def set_confs(self):
        self.with_bias = False

    def test_check_grad(self):
        # 'Bias' is omitted from the gradient check because this case
        # disables the bias input.
        self.check_grad(['Input', 'H0', 'Weight'], ['Hidden'])


G
guosheng 已提交
199 200 201 202 203
class TestGRUOpReverse(TestGRUOp):
    """Default GRU test configuration with ``is_reverse`` enabled."""

    def set_confs(self):
        self.is_reverse = True


Q
Qiao Longfei 已提交
204 205 206 207 208 209
class TestGRUOpReverseOriginMode(TestGRUOp):
    """GRU test with both ``is_reverse`` and ``origin_mode`` enabled."""

    def set_confs(self):
        self.is_reverse = True
        self.origin_mode = True


G
guosheng 已提交
210 211
# Run all test cases in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()