#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

# Numeric guards for the reference activations below: sigmoid() clamps its
# input to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] and tanh() caps the
# argument passed to np.exp at EXP_MAX_INPUT so the exponential cannot
# overflow.
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0

24 25 26 27 28 29

def identity(x):
    """Pass-through activation (the 'identity' option)."""
    return x


def sigmoid(x):
    """Numerically guarded logistic function.

    The input is clamped to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX]
    before the exponential so np.exp cannot overflow.
    """
    clipped = np.clip(x, SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX)
    return 1.0 / (1.0 + np.exp(-clipped))


def tanh(x):
    """Numerically guarded tanh via the logistic identity.

    Uses tanh(x) = 2 / (1 + exp(-2x)) - 1, capping the exponent argument at
    EXP_MAX_INPUT so np.exp cannot overflow.
    """
    exponent = np.minimum(-2.0 * x, EXP_MAX_INPUT)
    return (2.0 / (1.0 + np.exp(exponent))) - 1.0


def relu(x):
    """Elementwise rectified linear unit: max(x, 0)."""
    return np.clip(x, 0, None)


# Maps the operator's activation attribute strings to the NumPy reference
# callables defined above.
ACTIVATION = dict(
    identity=identity,
    sigmoid=sigmoid,
    tanh=tanh,
    relu=relu,
)


def lstm(
    input,  # T x 4D
    lod,  # 1 x N
    h0=None,  # N x D
    c0=None,  # N x D
    w_h=None,  # D x 4D
    w_b=None,  # 1 x 4D
    w_c=None,  # 1 x 3D
    is_reverse=False,
    act_gate=None,
    act_cell=None,
    act_cand=None,
):
    """NumPy reference implementation of the LSTM operator.

    Args:
        input: concatenated gate pre-activations for every time step of
            every sequence, shape (T, 4 * D) with T = sum(lod[0]).
        lod: level-of-detail info, [[len_0, len_1, ...]] — one length per
            sequence in the batch.
        h0: initial hidden states, shape (N, D).
        c0: initial cell states, shape (N, D).
        w_h: recurrent hidden-to-gate weights, shape (D, 4 * D).
        w_b: gate bias, shape (1, 4 * D), or None for no bias.
        w_c: peephole weights [w_ic | w_fc | w_oc], shape (1, 3 * D), or
            None to disable peepholes.
        is_reverse: process each sequence back-to-front when True.
        act_gate: activation for the input/forget/output gates.
        act_cell: activation applied to the cell state for the output.
        act_cand: activation for the candidate cell value.

    Returns:
        (hidden, cell): float64 arrays of shape (T, D), rows in the
        original (unreversed) time order.
    """

    def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
        # One LSTM time step for a single sequence.
        g = np.dot(h_pre, w_h)  # 1 x 4D
        g = g + x
        g = np.reshape(g, (1, g.size))
        c, g_i, g_f, g_o = np.split(g, 4, axis=1)
        if w_c is None:
            g_i = act_gate(g_i)  # 1 x D
            g_f = act_gate(g_f)  # 1 x D
        else:
            # Input/forget gates peek at the *previous* cell state.
            w_ic, w_fc, _ = np.split(w_c, 3, axis=1)
            g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
            g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
        c = g_f * c_pre + g_i * act_cand(c)  # 1 x D

        if w_c is None:
            g_o = act_gate(g_o)  # 1 x D
        else:
            # The output gate peeks at the freshly *updated* cell state.
            _, _, w_oc = np.split(w_c, 3, axis=1)
            g_o = act_gate(g_o + w_oc * c)  # 1 x D
        h = g_o * act_cell(c)
        return h, c

    def _reverse(x, offset):
        # Flip the rows of each sequence (offset[i]:offset[i+1]) in time.
        y = np.zeros_like(x)
        for i in range(len(offset) - 1):
            b, e = offset[i], offset[i + 1]
            y[b:e, :] = np.flip(x[b:e, :], 0)
        return y

    # Prefix sums of the sequence lengths give row offsets into `input`.
    offset = [0]
    for length in lod[0]:
        offset.append(offset[-1] + length)
    batch_size = len(lod[0])
    hidden = []
    cell = []
    input = _reverse(input, offset) if is_reverse else input
    if w_b is not None:
        input = input + np.tile(w_b, (offset[-1], 1))
    for i in range(batch_size):
        # Run one sequence through the recurrence.
        seq_len = lod[0][i]
        x = input[offset[i] : offset[i + 1], :]
        h_pre = h0[i]  # 1 x D
        c_pre = c0[i]  # 1 x D
        for j in range(seq_len):
            h_pre, c_pre = _step(
                x[j], w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand
            )
            hidden.append(h_pre.flatten())
            cell.append(c_pre.flatten())

    hidden = np.array(hidden).astype('float64')
    cell = np.array(cell).astype('float64')

    hidden = _reverse(hidden, offset) if is_reverse else hidden
    cell = _reverse(cell, offset) if is_reverse else cell

    # Integer division (not /) so the expected shape is a tuple of ints.
    assert hidden.shape == (input.shape[0], input.shape[1] // 4)
    assert cell.shape == (input.shape[0], input.shape[1] // 4)
    return hidden, cell


class TestLstmOp(OpTest):
    """Forward/backward check of the 'lstm' op against the NumPy reference.

    Subclasses override set_lod() / set_is_test() (or set_argument()) to
    cover different batch layouts and modes.
    """

    def set_is_test(self):
        # Training mode by default; the inference subclass flips this.
        self.is_test = False

    def set_lod(self):
        # One batch of three sequences with lengths 2, 3 and 2.
        self.lod = [[2, 3, 2]]

    def set_argument(self):
        self.set_is_test()
        self.set_lod()
        self.D = 16

        self.act_gate = 'sigmoid'
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

        self.has_initial_state = False
        self.is_reverse = False
        self.use_peepholes = True

    def setUp(self):
        self.set_argument()
        self.op_type = 'lstm'

        total_steps = sum(self.lod[0])
        num_seqs = len(self.lod[0])

        x = np.random.normal(size=(total_steps, 4 * self.D)).astype('float64')
        if self.has_initial_state:
            h0 = np.random.normal(size=(num_seqs, self.D)).astype('float64')
            c0 = np.random.normal(size=(num_seqs, self.D)).astype('float64')
        else:
            h0 = np.zeros((num_seqs, self.D)).astype('float64')
            c0 = np.zeros((num_seqs, self.D)).astype('float64')
        w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')

        # With peepholes the bias tensor also carries the 3*D peephole
        # weights after the 4*D gate bias.
        bias_width = 7 * self.D if self.use_peepholes else 4 * self.D
        b = np.random.normal(size=(1, bias_width)).astype('float64')
        w_b = b[:, 0 : 4 * self.D]
        w_c = b[:, 4 * self.D :] if self.use_peepholes else None

        h, c = lstm(
            x,
            self.lod,
            h0,
            c0,
            w,
            w_b,
            w_c,
            self.is_reverse,
            ACTIVATION[self.act_gate],
            ACTIVATION[self.act_cell],
            ACTIVATION[self.act_cand],
        )

        self.inputs = {'Input': (x, self.lod), 'Weight': w}
        self.inputs['Bias'] = b
        if self.has_initial_state:
            self.inputs['H0'] = h0
            self.inputs['C0'] = c0

        self.outputs = {
            'Hidden': (h, self.lod),
            'Cell': (c, self.lod),
        }
        self.attrs = {
            'use_peepholes': self.use_peepholes,
            'is_reverse': self.is_reverse,
            'gate_activation': self.act_gate,
            'cell_activation': self.act_cell,
            'candidate_activation': self.act_cand,
            'is_test': self.is_test,
        }

    def test_check_output(self):
        # NOTE(review): check_dygraph=False — presumably because LoD inputs
        # are not supported in dygraph mode; confirm before changing.
        self.check_output(atol=1e-8, check_dygraph=False)

    def test_check_grad(self):
        # TODO(qingqing) remove the following lines after check_grad is
        # refined; the intermediate outputs only need placeholders here.
        num_seqs = len(self.lod[0])
        self.outputs['BatchGate'] = np.zeros((num_seqs, 4 * self.D)).astype(
            'float64'
        )
        self.outputs['BatchCellPreAct'] = np.zeros((num_seqs, self.D)).astype(
            'float64'
        )
        self.check_grad(
            ['Input', 'Weight', 'Bias'],
            ['Hidden'],
            max_relative_error=5e-4,
            check_dygraph=False,
        )


class TestLstmOpCase1(TestLstmOp):
    """LoD variant: the first sequence in the batch is empty (length 0)."""

    def set_lod(self):
        self.lod = [[0, 3, 2]]


class TestLstmOpCase2(TestLstmOp):
    """LoD variant: only the middle sequence is non-empty."""

    def set_lod(self):
        self.lod = [[0, 3, 0]]


class TestLstmOpCase3(TestLstmOp):
    """LoD variant: the middle sequence is empty (length 0)."""

    def set_lod(self):
        self.lod = [[2, 0, 4]]


class TestLstmOpInference(TestLstmOp):
    """Runs the forward check with is_test=True (inference mode)."""

    def set_is_test(self):
        self.is_test = True

    # avoid checking gradient: the backward pass is not exercised in
    # inference mode
    def test_check_grad(self):
        pass


# class TestLstmOpHasInitial(TestLstmOp):
#     def set_argument(self):
250
#         self.lod = [[2, 3, 2]]
251 252 253 254 255 256 257 258 259 260 261 262
#         self.D = 16

#         self.act_gate = 'sigmoid'
#         self.act_cell = 'tanh'
#         self.act_cand = 'tanh'

#         self.has_initial_state = True
#         self.is_reverse = True
#         self.use_peepholes = True

#     def test_check_grad(self):
#         # TODO(qingqing) remove folowing lines after the check_grad is refined.
263
#         N = len(self.lod[0])
264 265 266 267 268 269 270 271
#         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
#         self.outputs['BatchCellPreAct'] = np.zeros(
#             (N, self.D)).astype('float64')
#         self.check_grad(
#             ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
#             max_relative_error=5e-4)

#     def test_check_grad_ingore_bias(self):
272
#         N = len(self.lod[0])
273 274 275 276 277 278 279 280 281
#         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
#         self.outputs['BatchCellPreAct'] = np.zeros(
#             (N, self.D)).astype('float64')
#         self.check_grad(
#             ['Input', 'Weight'], ['Hidden'],
#             max_relative_error=5e-4,
#             no_grad_set=set('Bias'))

#     def test_check_grad_ingore_weight(self):
282
#         N = len(self.lod[0])
283 284 285 286 287 288 289 290 291
#         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
#         self.outputs['BatchCellPreAct'] = np.zeros(
#             (N, self.D)).astype('float64')
#         self.check_grad(
#             ['Input', 'Bias'], ['Hidden'],
#             max_relative_error=5e-4,
#             no_grad_set=set('Weight'))

#     def test_check_grad_ingore_input(self):
292
#         N = len(self.lod[0])
293 294 295 296 297 298 299 300 301
#         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
#         self.outputs['BatchCellPreAct'] = np.zeros(
#             (N, self.D)).astype('float64')
#         self.check_grad(
#             ['Weight', 'Bias'], ['Hidden'],
#             max_relative_error=5e-4,
#             no_grad_set=set('Input'))

#     def test_check_grad_ingore_h0(self):
302
#         N = len(self.lod[0])
303 304 305 306 307 308 309 310 311
#         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
#         self.outputs['BatchCellPreAct'] = np.zeros(
#             (N, self.D)).astype('float64')
#         self.check_grad(
#             ['Input', 'Weight', 'Bias', 'C0'], ['Hidden'],
#             max_relative_error=5e-4,
#             no_grad_set=set('H0'))

#     def test_check_grad_ingore_c0(self):
312
#         N = len(self.lod[0])
313 314 315 316 317 318 319 320 321 322
#         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
#         self.outputs['BatchCellPreAct'] = np.zeros(
#             (N, self.D)).astype('float64')
#         self.check_grad(
#             ['Input', 'Weight', 'Bias', 'H0'], ['Hidden'],
#             max_relative_error=5e-4,
#             no_grad_set=set('C0'))

# class TestLstmOpRerverse(TestLstmOp):
#     def set_argument(self):
323
#         self.lod = [[2, 3, 2]]
324 325 326 327 328 329 330 331 332 333 334 335
#         self.D = 16

#         self.act_gate = 'sigmoid'
#         self.act_cell = 'tanh'
#         self.act_cand = 'tanh'

#         self.has_initial_state = False
#         self.is_reverse = True
#         self.use_peepholes = True

# class TestLstmOpNotUsePeepholes(TestLstmOp):
#     def set_argument(self):
336
#         self.lod = [[2, 3, 2]]
337 338 339 340 341 342 343 344 345
#         self.D = 16

#         self.act_gate = 'sigmoid'
#         self.act_cell = 'tanh'
#         self.act_cand = 'tanh'

#         self.has_initial_state = False
#         self.is_reverse = True
#         self.use_peepholes = False
# Allow running this file directly with `python test_lstm_op.py`.
if __name__ == '__main__':
    unittest.main()